@@ -32,6 +32,7 @@ use ndslice::view::Region;
32
32
use serde:: Deserialize ;
33
33
use serde:: Serialize ;
34
34
35
+ use crate :: CommActor ;
35
36
use crate :: alloc:: Alloc ;
36
37
use crate :: alloc:: AllocExt ;
37
38
use crate :: alloc:: AllocatedProc ;
@@ -90,6 +91,7 @@ impl ProcRef {
90
91
pub struct ProcMesh {
91
92
// Name of this mesh; cloned into the ProcMeshRef built from it.
name : Name ,
92
93
// The allocation backing this mesh. The `Allocated` variant carries the
// alloc together with the shared list of proc ranks.
allocation : ProcMeshAllocation ,
94
// Name under which the CommActor is spawned on the procs of this mesh;
// it is also forwarded to the ProcMeshRef built from this mesh.
+ comm_actor_name : Name ,
93
95
}
94
96
95
97
impl ProcMesh {
@@ -99,12 +101,21 @@ impl ProcMesh {
99
101
let region = self . allocation . extent ( ) . clone ( ) . into ( ) ;
100
102
match & self . allocation {
101
103
ProcMeshAllocation :: Allocated { ranks, .. } => {
102
- ProcMeshRef :: new ( self . name . clone ( ) , region, Arc :: clone ( ranks) ) . unwrap ( )
104
+ let root_mesh_rank_0 = ranks. first ( ) . expect ( "root mesh cannot be empty" ) . clone ( ) ;
105
+ ProcMeshRef :: new (
106
+ self . name . clone ( ) ,
107
+ region,
108
+ Arc :: clone ( ranks) ,
109
+ self . comm_actor_name . clone ( ) ,
110
+ None , // this is the root mesh
111
+ root_mesh_rank_0,
112
+ )
113
+ . unwrap ( )
103
114
}
104
115
}
105
116
}
106
117
107
- /// Allocate a new ProcMeshRef from the provided alloc.
118
+ /// Allocate a new ProcMesh from the provided alloc.
108
119
/// Allocate does not require an owning actor because references are not owned.
109
120
/// Allocate a new ProcMesh from the provided alloc.
110
121
pub async fn allocate (
@@ -173,7 +184,7 @@ impl ProcMesh {
173
184
}
174
185
}
175
186
176
- let ranks = running
187
+ let ranks: Vec < _ > = running
177
188
. into_iter ( )
178
189
. enumerate ( )
179
190
. map ( |( create_rank, allocated) | ProcRef {
@@ -183,13 +194,23 @@ impl ProcMesh {
183
194
} )
184
195
. collect ( ) ;
185
196
186
- Ok ( Self {
197
+ let proc_mesh = Self {
187
198
name : Name :: new ( name) ,
188
199
allocation : ProcMeshAllocation :: Allocated {
189
200
alloc : Box :: new ( alloc) ,
190
201
ranks : Arc :: new ( ranks) ,
191
202
} ,
192
- } )
203
+ comm_actor_name : Name :: new ( "comm" ) ,
204
+ } ;
205
+ // Spawn a comm actor on each proc, so that they can be used to perform
206
+ // tree distribution and accumulation.
207
+ let comm_actor_name = proc_mesh. comm_actor_name . clone ( ) ;
208
+ proc_mesh
209
+ . freeze ( )
210
+ . spawn_with_name :: < CommActor > ( cx, comm_actor_name, & Default :: default ( ) )
211
+ . await ?;
212
+
213
+ Ok ( proc_mesh)
193
214
}
194
215
}
195
216
@@ -237,11 +258,28 @@ pub struct ProcMeshRef {
237
258
name : Name ,
238
259
region : Region ,
239
260
ranks : Arc < Vec < ProcRef > > ,
261
+ comm_actor_name : Name ,
262
+ // Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
263
+ // should be removed after we remove the v0 code.
264
+ // The root region of this mesh. None means this mesh itself is the root.
265
+ pub ( crate ) root_region : Option < Region > ,
266
+ // Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
267
+ // should be removed after we remove the v0 code.
268
+ // v0 casting requires root mesh rank 0 as the 1st hop, so we need to provide
269
+ // it here. For v1, this can be removed since v1 can use any rank.
270
+ pub ( crate ) root_mesh_rank_0 : ProcRef ,
240
271
}
241
272
242
273
impl ProcMeshRef {
243
- /// Create a new ProcMeshRef from the given name, region, and ranks.
244
- fn new ( name : Name , region : Region , ranks : Arc < Vec < ProcRef > > ) -> v1:: Result < Self > {
274
+ /// Create a new ProcMeshRef from the given name, region, ranks, comm actor
+ /// name, optional root region (`None` when this mesh is the root), and root mesh rank 0.
275
+ fn new (
276
+ name : Name ,
277
+ region : Region ,
278
+ ranks : Arc < Vec < ProcRef > > ,
279
+ comm_actor_name : Name ,
280
+ root_region : Option < Region > ,
281
+ root_mesh_rank_0 : ProcRef ,
282
+ ) -> v1:: Result < Self > {
245
283
if region. num_ranks ( ) != ranks. len ( ) {
246
284
return Err ( v1:: Error :: InvalidRankCardinality {
247
285
expected : region. num_ranks ( ) ,
@@ -252,9 +290,16 @@ impl ProcMeshRef {
252
290
name,
253
291
region,
254
292
ranks,
293
+ comm_actor_name,
294
+ root_region,
295
+ root_mesh_rank_0,
255
296
} )
256
297
}
257
298
299
/// The name of the comm actor recorded on this mesh reference.
+ pub ( crate ) fn comm_actor_name ( & self ) -> & Name {
300
+ & self . comm_actor_name
301
+ }
302
+
258
303
/// The current statuses of procs in this mesh.
259
304
#[ allow( dead_code) ]
260
305
async fn status ( & self , cx : & impl context:: Actor ) -> v1:: Result < ValueMesh < bool > > {
@@ -273,6 +318,19 @@ impl ProcMeshRef {
273
318
name : & str ,
274
319
params : & A :: Params ,
275
320
) -> v1:: Result < ActorMesh < A > >
321
+ where
322
+ A :: Params : RemoteMessage ,
323
+ {
324
+ self . spawn_with_name ( cx, Name :: new ( name) , params) . await
325
+ }
326
+
327
+ #[ allow( dead_code) ]
328
+ async fn spawn_with_name < A : Actor + RemoteActor > (
329
+ & self ,
330
+ cx : & impl context:: Actor ,
331
+ name : Name ,
332
+ params : & A :: Params ,
333
+ ) -> v1:: Result < ActorMesh < A > >
276
334
where
277
335
A :: Params : RemoteMessage ,
278
336
{
@@ -282,7 +340,6 @@ impl ProcMeshRef {
282
340
. ok_or ( Error :: ActorTypeNotRegistered ( type_name :: < A > ( ) . to_string ( ) ) ) ?
283
341
. to_string ( ) ;
284
342
285
- let name = Name :: new ( name) ;
286
343
let serialized_params = bincode:: serialize ( params) ?;
287
344
288
345
let ( completed_handle, mut completed_receiver) = cx. mailbox ( ) . open_port ( ) ;
@@ -292,7 +349,7 @@ impl ProcMeshRef {
292
349
. gspawn (
293
350
cx,
294
351
actor_type. clone ( ) ,
295
- name. clone ( ) . to_string ( ) ,
352
+ name. to_string ( ) ,
296
353
serialized_params. clone ( ) ,
297
354
completed_handle. bind ( ) ,
298
355
)
@@ -352,7 +409,15 @@ impl view::RankedSliceable for ProcMeshRef {
352
409
. unwrap ( )
353
410
. map ( |index| self . get ( index) . unwrap ( ) . clone ( ) )
354
411
. collect ( ) ;
355
- Self :: new ( self . name . clone ( ) , region, Arc :: new ( ranks) ) . unwrap ( )
412
+ Self :: new (
413
+ self . name . clone ( ) ,
414
+ region,
415
+ Arc :: new ( ranks) ,
416
+ self . comm_actor_name . clone ( ) ,
417
+ Some ( self . root_region . as_ref ( ) . unwrap_or ( & self . region ) . clone ( ) ) ,
418
+ self . root_mesh_rank_0 . clone ( ) ,
419
+ )
420
+ . unwrap ( )
356
421
}
357
422
}
358
423
0 commit comments