Skip to content

Commit 692e1d6

Browse files
pzhan9facebook-github-bot
authored andcommitted
Use v0's cast implementation for v1 ActorMesh (#1187)
Summary: As title. Reviewed By: mariusae Differential Revision: D82251703
1 parent 9f3f506 commit 692e1d6

File tree

4 files changed

+134
-21
lines changed

4 files changed

+134
-21
lines changed

hyperactor_mesh/src/v1.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ pub enum Error {
7777

7878
#[error("error while sending message to actor {0}: {1}")]
7979
SendingError(ActorId, Box<MailboxSenderError>),
80+
81+
#[error("error while casting message to {0}: {1}")]
82+
CastingError(Name, anyhow::Error),
8083
}
8184

8285
impl From<crate::alloc::AllocatorError> for Error {

hyperactor_mesh/src/v1/actor_mesh.rs

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@ use hyperactor::RemoteMessage;
1515
use hyperactor::actor::RemoteActor;
1616
use hyperactor::context;
1717
use hyperactor::message::Castable;
18+
use hyperactor::message::IndexedErasedUnbound;
19+
use hyperactor_mesh_macros::sel;
20+
use ndslice::Selection;
21+
use ndslice::Shape;
22+
use ndslice::ViewExt;
1823
use ndslice::view;
1924
use ndslice::view::Region;
2025
use ndslice::view::View;
21-
use ndslice::view::ViewExt;
2226
use serde::Deserialize;
2327
use serde::Serialize;
2428

29+
use crate::CommActor;
30+
use crate::actor_mesh as v0_actor_mesh;
31+
use crate::reference::ActorMeshId;
2532
use crate::v1;
2633
use crate::v1::Error;
2734
use crate::v1::Name;
@@ -82,16 +89,41 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
8289
/// Cast a message to all actors in this mesh.
8390
pub fn cast<M>(&self, cx: &impl context::Actor, message: M) -> v1::Result<()>
8491
where
85-
M: Castable + RemoteMessage + Clone,
86-
A: RemoteHandles<M>,
92+
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
93+
M: Castable + RemoteMessage,
8794
{
88-
// todo: headers, binding/unbinding/accumulation
89-
for actor_ref in self.values() {
90-
actor_ref
91-
.send(cx, message.clone())
92-
.map_err(|e| Error::SendingError(actor_ref.actor_id().clone(), Box::new(e)))?;
95+
self.region();
96+
let cast_mesh_shape = to_shape(view::Ranked::region(self));
97+
let comm_actor_ref = self
98+
.proc_mesh
99+
.root_mesh_rank_0
100+
.attest::<CommActor>(self.proc_mesh.comm_actor_name());
101+
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
102+
match &self.proc_mesh.root_region {
103+
Some(root_region) => {
104+
let root_mesh_shape = to_shape(root_region);
105+
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
106+
cx,
107+
actor_mesh_id,
108+
&comm_actor_ref,
109+
&sel!(*),
110+
message,
111+
&cast_mesh_shape,
112+
&root_mesh_shape,
113+
)
114+
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
115+
}
116+
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
117+
cx,
118+
actor_mesh_id,
119+
&comm_actor_ref,
120+
sel!(*),
121+
&cast_mesh_shape,
122+
&cast_mesh_shape,
123+
message,
124+
)
125+
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
93126
}
94-
Ok(())
95127
}
96128
}
97129

@@ -120,3 +152,8 @@ impl<A: RemoteActor> view::RankedSliceable for ActorMeshRef<A> {
120152
ActorMeshRef::new(self.name.clone(), proc_mesh, actor_refs)
121153
}
122154
}
155+
156+
fn to_shape(region: &Region) -> Shape {
157+
Shape::new(region.labels().to_vec(), region.slice().clone())
158+
.expect("Shape::new should not fail because a Region by definition is a valid Shape")
159+
}

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 75 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ use ndslice::view::Region;
3232
use serde::Deserialize;
3333
use serde::Serialize;
3434

35+
use crate::CommActor;
3536
use crate::alloc::Alloc;
3637
use crate::alloc::AllocExt;
3738
use crate::alloc::AllocatedProc;
@@ -90,6 +91,7 @@ impl ProcRef {
9091
pub struct ProcMesh {
9192
name: Name,
9293
allocation: ProcMeshAllocation,
94+
comm_actor_name: Name,
9395
}
9496

9597
impl ProcMesh {
@@ -99,12 +101,21 @@ impl ProcMesh {
99101
let region = self.allocation.extent().clone().into();
100102
match &self.allocation {
101103
ProcMeshAllocation::Allocated { ranks, .. } => {
102-
ProcMeshRef::new(self.name.clone(), region, Arc::clone(ranks)).unwrap()
104+
let root_mesh_rank_0 = ranks.first().expect("root mesh cannot be empty").clone();
105+
ProcMeshRef::new(
106+
self.name.clone(),
107+
region,
108+
Arc::clone(ranks),
109+
self.comm_actor_name.clone(),
110+
None, // this is the root mesh
111+
root_mesh_rank_0,
112+
)
113+
.unwrap()
103114
}
104115
}
105116
}
106117

107-
/// Allocate a new ProcMeshRef from the provided alloc.
118+
/// Allocate a new ProcMesh from the provided alloc.
108119
/// Allocate does not require an owning actor because references are not owned.
109120
/// Allocate a new ProcMesh from the provided alloc.
110121
pub async fn allocate(
@@ -173,7 +184,7 @@ impl ProcMesh {
173184
}
174185
}
175186

176-
let ranks = running
187+
let ranks: Vec<_> = running
177188
.into_iter()
178189
.enumerate()
179190
.map(|(create_rank, allocated)| ProcRef {
@@ -183,13 +194,23 @@ impl ProcMesh {
183194
})
184195
.collect();
185196

186-
Ok(Self {
197+
let proc_mesh = Self {
187198
name: Name::new(name),
188199
allocation: ProcMeshAllocation::Allocated {
189200
alloc: Box::new(alloc),
190201
ranks: Arc::new(ranks),
191202
},
192-
})
203+
comm_actor_name: Name::new("comm"),
204+
};
205+
// Spawn a comm actor on each proc, so that they can be used to perform
206+
// tree distribution and accumulation.
207+
let comm_actor_name = proc_mesh.comm_actor_name.clone();
208+
proc_mesh
209+
.freeze()
210+
.spawn_with_name::<CommActor>(cx, comm_actor_name, &Default::default())
211+
.await?;
212+
213+
Ok(proc_mesh)
193214
}
194215
}
195216

@@ -237,11 +258,28 @@ pub struct ProcMeshRef {
237258
name: Name,
238259
region: Region,
239260
ranks: Arc<Vec<ProcRef>>,
261+
comm_actor_name: Name,
262+
// Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
263+
// should be removed after we remove the v0 code.
264+
// The root region of this mesh. None means this mesh itself is the root.
265+
pub(crate) root_region: Option<Region>,
266+
// Temporary: used to fit v1 ActorMesh with v0's casting implementation. This
267+
// should be removed after we remove the v0 code.
268+
// v0 casting requires root mesh rank 0 as the 1st hop, so we need to provide
269+
// it here. For v1, this can be removed since v1 can use any rank.
270+
pub(crate) root_mesh_rank_0: ProcRef,
240271
}
241272

242273
impl ProcMeshRef {
243-
/// Create a new ProcMeshRef from the given name, region, and ranks.
244-
fn new(name: Name, region: Region, ranks: Arc<Vec<ProcRef>>) -> v1::Result<Self> {
274+
/// Create a new ProcMeshRef from the given name, region, ranks, and so on.
275+
fn new(
276+
name: Name,
277+
region: Region,
278+
ranks: Arc<Vec<ProcRef>>,
279+
comm_actor_name: Name,
280+
root_region: Option<Region>,
281+
root_mesh_rank_0: ProcRef,
282+
) -> v1::Result<Self> {
245283
if region.num_ranks() != ranks.len() {
246284
return Err(v1::Error::InvalidRankCardinality {
247285
expected: region.num_ranks(),
@@ -252,9 +290,16 @@ impl ProcMeshRef {
252290
name,
253291
region,
254292
ranks,
293+
comm_actor_name,
294+
root_region,
295+
root_mesh_rank_0,
255296
})
256297
}
257298

299+
pub(crate) fn comm_actor_name(&self) -> &Name {
300+
&self.comm_actor_name
301+
}
302+
258303
/// The current statuses of procs in this mesh.
259304
#[allow(dead_code)]
260305
async fn status(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<bool>> {
@@ -273,6 +318,19 @@ impl ProcMeshRef {
273318
name: &str,
274319
params: &A::Params,
275320
) -> v1::Result<ActorMesh<A>>
321+
where
322+
A::Params: RemoteMessage,
323+
{
324+
self.spawn_with_name(cx, Name::new(name), params).await
325+
}
326+
327+
#[allow(dead_code)]
328+
async fn spawn_with_name<A: Actor + RemoteActor>(
329+
&self,
330+
cx: &impl context::Actor,
331+
name: Name,
332+
params: &A::Params,
333+
) -> v1::Result<ActorMesh<A>>
276334
where
277335
A::Params: RemoteMessage,
278336
{
@@ -282,7 +340,6 @@ impl ProcMeshRef {
282340
.ok_or(Error::ActorTypeNotRegistered(type_name::<A>().to_string()))?
283341
.to_string();
284342

285-
let name = Name::new(name);
286343
let serialized_params = bincode::serialize(params)?;
287344

288345
let (completed_handle, mut completed_receiver) = cx.mailbox().open_port();
@@ -292,7 +349,7 @@ impl ProcMeshRef {
292349
.gspawn(
293350
cx,
294351
actor_type.clone(),
295-
name.clone().to_string(),
352+
name.to_string(),
296353
serialized_params.clone(),
297354
completed_handle.bind(),
298355
)
@@ -352,7 +409,15 @@ impl view::RankedSliceable for ProcMeshRef {
352409
.unwrap()
353410
.map(|index| self.get(index).unwrap().clone())
354411
.collect();
355-
Self::new(self.name.clone(), region, Arc::new(ranks)).unwrap()
412+
Self::new(
413+
self.name.clone(),
414+
region,
415+
Arc::new(ranks),
416+
self.comm_actor_name.clone(),
417+
Some(self.root_region.as_ref().unwrap_or(&self.region).clone()),
418+
self.root_mesh_rank_0.clone(),
419+
)
420+
.unwrap()
356421
}
357422
}
358423

hyperactor_mesh/src/v1/testactor.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
19
//! This module defines a test actor. It is defined in a separate module
210
//! (outside of [`crate::v1::testing`]) to ensure that it is compiled into
311
//! the bootstrap binary, which is not built in test mode (and anyway, test mode
@@ -20,14 +28,14 @@ use serde::Serialize;
2028
#[hyperactor::export(
2129
spawn = true,
2230
handlers = [
23-
GetActorId,
31+
GetActorId { cast = true },
2432
]
2533
)]
2634
pub struct TestActor;
2735

2836
/// A message that returns the recipient actor's id.
2937
#[derive(Debug, Clone, Named, Bind, Unbind, Serialize, Deserialize)]
30-
pub struct GetActorId(pub PortRef<ActorId>);
38+
pub struct GetActorId(#[binding(include)] pub PortRef<ActorId>);
3139

3240
#[async_trait]
3341
impl Handler<GetActorId> for TestActor {

0 commit comments

Comments
 (0)