From 64bd481cc434873b3862df6a1a40dd2625243675 Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Mon, 22 Sep 2025 10:23:25 -0700 Subject: [PATCH] Add casting support to resource::GetState messages (#1262) Summary: Pull Request resolved: https://github.com/meta-pytorch/monarch/pull/1262 Part of: https://github.com/meta-pytorch/monarch/issues/1209 Use casting to implement the `supervision_events` API instead of iterating over all ProcMeshAgents. This will scale better as the size of the ProcMesh increases. This requires making some trait bound changes to `resource::GetState` so that it can be used with casting. Unfortunately, the Actor name for the mesh agent is not compatible with the v1::Name struct due to the missing uuid. Make `v1::Name` an enum to allow reserved names to be used for things like ActorMeshes. Also, a minor improvement: make `ActorMeshRef::supervision_events` not take a `Name`; it now always uses the current mesh's own name. Differential Revision: D82687236 --- hyperactor_mesh/src/proc_mesh/mesh_agent.rs | 6 +- hyperactor_mesh/src/resource.rs | 37 ++++++++++++ hyperactor_mesh/src/v1.rs | 67 +++++++++++++++------ hyperactor_mesh/src/v1/actor_mesh.rs | 13 ++-- hyperactor_mesh/src/v1/proc_mesh.rs | 43 ++++++++++--- 5 files changed, 132 insertions(+), 34 deletions(-) diff --git a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs index e7ce961de..008c6828d 100644 --- a/hyperactor_mesh/src/proc_mesh/mesh_agent.rs +++ b/hyperactor_mesh/src/proc_mesh/mesh_agent.rs @@ -20,6 +20,7 @@ use enum_as_inner::EnumAsInner; use hyperactor::Actor; use hyperactor::ActorHandle; use hyperactor::ActorId; +use hyperactor::Bind; use hyperactor::Context; use hyperactor::Data; use hyperactor::HandleClient; @@ -31,6 +32,7 @@ use hyperactor::PortHandle; use hyperactor::PortRef; use hyperactor::ProcId; use hyperactor::RefClient; +use hyperactor::Unbind; use hyperactor::actor::ActorStatus; use hyperactor::actor::remote::Remote; use hyperactor::channel; 
@@ -167,7 +169,7 @@ impl State { handlers=[ MeshAgentMessage, resource::CreateOrUpdate, - resource::GetState + resource::GetState { cast = true }, ] )] pub struct ProcMeshAgent { @@ -425,7 +427,7 @@ pub struct ActorSpec { } /// Actor state. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)] pub struct ActorState { /// The actor's ID. pub actor_id: ActorId, diff --git a/hyperactor_mesh/src/resource.rs b/hyperactor_mesh/src/resource.rs index 2b5fa8737..151068eca 100644 --- a/hyperactor_mesh/src/resource.rs +++ b/hyperactor_mesh/src/resource.rs @@ -16,6 +16,10 @@ use hyperactor::Handler; use hyperactor::Named; use hyperactor::PortRef; use hyperactor::RefClient; +use hyperactor::RemoteMessage; +use hyperactor::message::Bind; +use hyperactor::message::Bindings; +use hyperactor::message::Unbind; use serde::Deserialize; use serde::Serialize; @@ -70,3 +74,36 @@ pub struct GetState { #[reply] pub reply: PortRef>, } + +// Cannot derive Bind and Unbind for this generic, implement manually. +impl Unbind for GetState +where + S: RemoteMessage, + S: Unbind, +{ + fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> { + self.reply.unbind(bindings) + } +} + +impl Bind for GetState +where + S: RemoteMessage, + S: Bind, +{ + fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> { + self.reply.bind(bindings) + } +} + +impl Clone for GetState +where + S: RemoteMessage, +{ + fn clone(&self) -> Self { + Self { + name: self.name.clone(), + reply: self.reply.clone(), + } + } +} diff --git a/hyperactor_mesh/src/v1.rs b/hyperactor_mesh/src/v1.rs index c3f04611b..a71b381f7 100644 --- a/hyperactor_mesh/src/v1.rs +++ b/hyperactor_mesh/src/v1.rs @@ -167,27 +167,51 @@ pub type Result = std::result::Result; Serialize, Deserialize )] -pub struct Name(pub String, pub ShortUuid); +pub enum Name { + /// Normal names for most actors. 
+ Suffixed(String, ShortUuid), + /// Reserved names for system actors without UUIDs. + Reserved(String), +} impl Name { /// Create a new `Name` from a user-provided base name. pub fn new(name: impl Into) -> Self { + Self::new_with_uuid(name, Some(ShortUuid::generate())) + } + + /// Create a Reserved `Name` with no uuid. Only for use by system actors. + pub(crate) fn new_reserved(name: impl Into) -> Self { + Self::new_with_uuid(name, None) + } + + fn new_with_uuid(name: impl Into, uuid: Option) -> Self { let mut name = name.into(); if name.is_empty() { name = "unnamed".to_string(); } - let uuid = ShortUuid::generate(); - Self(name, uuid) + if let Some(uuid) = uuid { + Self::Suffixed(name, uuid) + } else { + Self::Reserved(name) + } } /// The name portion of this `Name`. pub fn name(&self) -> &str { - &self.0 + match self { + Self::Suffixed(n, _) => n, + Self::Reserved(n) => n, + } } /// The UUID portion of this `Name`. + /// Only valid for Name::Suffixed, if called on Name::Reserved it'll panic. 
pub fn uuid(&self) -> &ShortUuid { - &self.1 + match self { + Self::Suffixed(_, uuid) => uuid, + Self::Reserved(_) => panic!("Reserved name has no UUID"), + } } } @@ -211,24 +235,33 @@ impl FromStr for Name { type Err = NameParseError; fn from_str(s: &str) -> std::result::Result { - let (name, uuid) = s.split_once('-').ok_or(NameParseError::MissingSeparator)?; - if name.is_empty() { - return Err(NameParseError::MissingName); + if let Some((name, uuid)) = s.split_once('-') { + if name.is_empty() { + return Err(NameParseError::MissingName); + } + if uuid.is_empty() { + return Err(NameParseError::MissingName); + } + + Ok(Name::new_with_uuid(name.to_string(), Some(uuid.parse()?))) + } else { + if s.is_empty() { + return Err(NameParseError::MissingName); + } + Ok(Name::new_reserved(s)) } - if uuid.is_empty() { - return Err(NameParseError::MissingName); - } - - let name = name.to_string(); - let uuid = uuid.parse()?; - Ok(Name(name, uuid)) } } impl std::fmt::Display for Name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}-", self.name())?; - self.uuid().format(f, true /*raw*/) + match self { + Self::Suffixed(n, uuid) => { + write!(f, "{}-", n)?; + uuid.format(f, true /*raw*/) + } + Self::Reserved(n) => write!(f, "{}", n), + } } } diff --git a/hyperactor_mesh/src/v1/actor_mesh.rs b/hyperactor_mesh/src/v1/actor_mesh.rs index 5cce80815..cb654eb07 100644 --- a/hyperactor_mesh/src/v1/actor_mesh.rs +++ b/hyperactor_mesh/src/v1/actor_mesh.rs @@ -182,9 +182,10 @@ impl ActorMeshRef { pub async fn supervision_events( &self, cx: &impl context::Actor, - name: Name, ) -> v1::Result>> { - self.proc_mesh.supervision_events(cx, name).await + self.proc_mesh + .supervision_events(cx, self.name.clone()) + .await } } @@ -455,7 +456,7 @@ mod tests { } #[async_timed_test(timeout_secs = 30)] - async fn test_status() { + async fn test_supervision_events() { hyperactor_telemetry::initialize_logging_for_test(); let instance = testing::instance().await; @@ 
-489,12 +490,8 @@ mod tests { // Now that all ranks have completed, set up a continuous poll of the // status such that when a process switches to unhealthy it sets a // supervision event. - let child_name_clone = child_name.clone(); let supervision_task = tokio::spawn(async move { - match actor_mesh - .supervision_events(instance, child_name_clone) - .await - { + match actor_mesh.supervision_events(&instance).await { Ok(events) => { for event_list in events.values() { assert!(!event_list.is_empty()); diff --git a/hyperactor_mesh/src/v1/proc_mesh.rs b/hyperactor_mesh/src/v1/proc_mesh.rs index 0d6613bea..6996c7753 100644 --- a/hyperactor_mesh/src/v1/proc_mesh.rs +++ b/hyperactor_mesh/src/v1/proc_mesh.rs @@ -29,6 +29,7 @@ use hyperactor::supervision::ActorSupervisionEvent; use ndslice::Extent; use ndslice::ViewExt as _; use ndslice::view; +use ndslice::view::CollectMeshExt; use ndslice::view::MapIntoExt; use ndslice::view::Ranked; use ndslice::view::Region; @@ -48,6 +49,7 @@ use crate::proc_mesh::mesh_agent::ProcMeshAgent; use crate::resource; use crate::v1; use crate::v1::ActorMesh; +use crate::v1::ActorMeshRef; use crate::v1::Error; use crate::v1::HostMeshRef; use crate::v1::Name; @@ -93,6 +95,7 @@ impl ProcRef { } /// Get the supervision events for one actor with the given name. + #[allow(dead_code)] async fn supervision_events( &self, cx: &impl context::Actor, @@ -453,18 +456,45 @@ impl ProcMeshRef { vm.join().await.transpose() } + fn agent_mesh(&self) -> ActorMeshRef { + let agent_name = self.ranks.first().unwrap().agent.actor_id().name(); + // This name must match the ProcMeshAgent name, which can change depending on the allocator. + ActorMeshRef::new(Name::new_reserved(agent_name), self.clone()) + } + /// The supervision events of procs in this mesh. 
pub async fn supervision_events( &self, cx: &impl context::Actor, name: Name, ) -> v1::Result>> { - let vm: ValueMesh<_> = self.map_into(|proc_ref| { - let proc_ref = proc_ref.clone(); - let name = name.clone(); - async move { proc_ref.supervision_events(cx, name).await } - }); - vm.join().await.transpose() + let agent_mesh = self.agent_mesh(); + let (port, mut rx) = cx.mailbox().open_port::>(); + agent_mesh.cast( + cx, + resource::GetState:: { + name: name.clone(), + reply: port.bind(), + }, + )?; + let expected = self.ranks.len(); + let mut states = Vec::with_capacity(expected); + for _ in 0..expected { + let state = rx.recv().await?; + states.push(state); + } + let vm = states + .into_iter() + .map(|state| { + if let Some(state) = state.state { + state.supervision_events + } else { + // Empty vec for ranks with no supervision events. + Vec::new() + } + }) + .collect_mesh::>>(self.region.clone())?; + Ok(vm) } /// Spawn an actor on all of the procs in this mesh, returning a new ActorMesh. @@ -600,7 +630,6 @@ mod tests { use timed_test::async_timed_test; use crate::v1::ActorMesh; - use crate::v1::ActorMeshRef; use crate::v1::testactor; use crate::v1::testing;