Commit 9b03dda

dstaay-fb authored and facebook-github-bot committed
Enable CUDA path in Python FE, update grpo example (#1224)
Summary: Enable device='cuda' for RDMABuffer

Reviewed By: zdevito

Differential Revision: D82331433
1 parent 350deb6 commit 9b03dda
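
This change allows an RDMABuffer to be backed directly by a CUDA tensor, provided PyTorch's CUDA caching allocator was initialized with expandable segments. A minimal sketch of the flow this enables (the import path is inferred from the files touched below; running it needs a CUDA device plus the monarch RDMA stack, so treat it as illustrative):

    import os

    # Must be set before importing torch so the CUDA caching allocator
    # comes up with expandable segments (required for CUDA RDMA).
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    import torch
    from monarch._src.tensor_engine.rdma import RDMABuffer

    weights = torch.randn(1024, device="cuda")
    # Registers the RDMA memory region over the tensor's own storage;
    # before this commit, passing a non-CPU tensor raised ValueError.
    buf = RDMABuffer(weights.view(torch.uint8).flatten())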

7 files changed (+90, -51 lines)

docs/source/examples/grpo_actor.py — 3 additions, 10 deletions

@@ -292,11 +292,10 @@ async def weights_handle(self) -> Dict[str, Tuple[torch.Tensor, RDMABuffer]]:
         Returns:
             Dictionary mapping parameter names to RDMA buffers
         """
-        cpu_tensors = {
-            k: v.cpu().view(torch.uint8).flatten()
+        self._weights_handle = {
+            k: (v, RDMABuffer(v.view(torch.uint8).flatten()))
             for k, v in self.model.state_dict().items()
         }
-        self._weights_handle = {k: (v, RDMABuffer(v)) for k, v in cpu_tensors.items()}
         return self._weights_handle

     def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
@@ -372,11 +371,6 @@ def _apply_policy_update(
         self.optim.step()
         self.policy_version += 1

-        # update buffers
-        sd = self.model.state_dict()
-        for n, (t, _) in self._weights_handle.items():
-            t.copy_(sd[n].view(torch.uint8).flatten())
-
         # Return loss value
         return loss.detach()

@@ -486,9 +480,8 @@ async def update(self, version: int) -> None:
         async with self.cond:
            # Copy weights from RDMA buffers
            sd = self.model.state_dict()
-           cpu_sd = {k: torch.zeros_like(v, device="cpu") for k, v in sd.items()}
            for n, (_, b) in self.weight_buffers.items():
-               await b.read_into(cpu_sd[n].view(torch.uint8).flatten())
+               await b.read_into(sd[n].view(torch.uint8).flatten())
            self.model.load_state_dict(sd)
            # Update version and state
            self.policy_version = version
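
The copy-back loop removed from _apply_policy_update is redundant in the new scheme: each RDMABuffer is now registered over the parameter's own storage, so in-place optimizer updates are visible to remote readers with no extra copy. A hypothetical illustration of that invariant, with a plain tensor standing in for a model parameter:

    import torch

    param = torch.zeros(4, device="cuda")       # stands in for a model parameter
    view = param.view(torch.uint8).flatten()    # the bytes an RDMABuffer would register
    param.add_(1.0)                             # in-place update, as optim.step() performs
    assert view.data_ptr() == param.data_ptr()  # same storage, so no copy-back is needed
    # (device="cpu" demonstrates the same aliasing if no GPU is present)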

monarch_rdma/extension/lib.rs — 5 additions, 0 deletions

@@ -116,6 +116,11 @@ impl PyRdmaBuffer {
         ibverbs_supported()
     }

+    #[classmethod]
+    fn pt_cuda_allocator_compatibility<'py>(_cls: &Bound<'_, PyType>, _py: Python<'py>) -> bool {
+        monarch_rdma::pt_cuda_allocator_compatibility()
+    }
+
     #[pyo3(name = "__repr__")]
     fn repr(&self) -> String {
         format!("<RdmaBuffer'{:?}'>", self.buffer)

monarch_rdma/src/rdma_components.rs — 15 additions, 0 deletions

@@ -1245,6 +1245,21 @@ pub fn get_registered_cuda_segments() -> Vec<rdmaxcel_sys::rdma_segment_info_t>
     }
 }

+/// Check if PyTorch CUDA caching allocator has expandable segments enabled.
+///
+/// This function calls the C++ implementation that directly accesses the
+/// PyTorch C10 CUDA allocator configuration to check if expandable segments
+/// are enabled, which is required for RDMA operations with CUDA tensors.
+///
+/// # Returns
+///
+/// `true` if both the CUDA caching allocator is enabled AND expandable segments are enabled,
+/// `false` otherwise.
+pub fn pt_cuda_allocator_compatibility() -> bool {
+    // SAFETY: We are calling a C++ function from rdmaxcel that accesses PyTorch C10 APIs.
+    unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

monarch_rdma/src/rdma_manager_actor.rs — 1 addition, 1 deletion

@@ -275,7 +275,7 @@ impl Actor for RdmaManagerActor {
     async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
         let mut config = _params;

-        let pt_cuda_alloc = unsafe { rdmaxcel_sys::pt_cuda_allocator_compatibility() };
+        let pt_cuda_alloc = crate::rdma_components::pt_cuda_allocator_compatibility();

         // check config and hardware support align
         if config.use_gpu_direct {

python/monarch/_rust_bindings/rdma.pyi — 2 additions, 0 deletions

@@ -59,3 +59,5 @@ class _RdmaBuffer:
     def new_from_json(json: str) -> _RdmaBuffer: ...
     @classmethod
     def rdma_supported(cls) -> bool: ...
+    @classmethod
+    def pt_cuda_allocator_compatibility(cls) -> bool: ...
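
The new stub makes the allocator check callable from Python. A small sketch of how a caller might gate device selection on it (assuming the extension module is importable at the stub's path):

    from monarch._rust_bindings.rdma import _RdmaBuffer

    if _RdmaBuffer.pt_cuda_allocator_compatibility():
        device = "cuda"  # expandable segments enabled; CUDA RDMA is usable
    else:
        device = "cpu"   # fall back rather than fail memory registration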

python/monarch/_src/tensor_engine/rdma.py — 58 additions, 35 deletions

@@ -138,6 +138,61 @@ async def init_rdma_on_mesh(self, proc_mesh: ProcMesh) -> None:
         )


+# Cached so that we don't have to call out to the root client every time,
+# which may be on a different host.
+@functools.cache
+def _ensure_init_rdma_manager() -> Shared[None]:
+    async def task() -> None:
+        await (
+            await get_or_spawn_controller("rdma_controller", RdmaController)
+        ).init_rdma_on_mesh.call_one(none_throws(context().actor_instance.proc_mesh))
+
+    return PythonTask.from_coroutine(task()).spawn()
+
+
+@functools.cache
+def _check_cuda_expandable_segments_enabled() -> bool:
+    """
+    Check if the PyTorch CUDA caching allocator is using expandable segments.
+
+    Uses the Rust extension, which calls the C++ implementation from rdmaxcel-sys
+    that directly accesses the PyTorch C10 CUDA allocator configuration.
+
+    Returns:
+        bool: True if expandable segments are enabled, False otherwise
+
+    Raises:
+        RuntimeError: If expandable segments are not enabled but required for RDMA
+    """
+    try:
+        # Use the new Rust utility function that calls the C++ pt_cuda_allocator_compatibility()
+        pt_cuda_compat = _RdmaBuffer.pt_cuda_allocator_compatibility()
+
+        if not pt_cuda_compat:
+            raise RuntimeError(
+                "CUDA caching allocator is not using expandable segments.\n"
+                "This is required for RDMA to work correctly with CUDA tensors.\n\n"
+                "To fix this, set the environment variable BEFORE importing PyTorch:\n"
+                "1. In shell:\n"
+                '   export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"\n'
+                "2. Or in Python script (BEFORE any PyTorch imports):\n"
+                "   import os\n"
+                '   os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"\n'
+                "   import torch  # Must come after setting the env var\n\n"
+                "Note: This setting must be configured before PyTorch's CUDA allocator is initialized."
+            )
+        return True
+
+    except Exception as e:
+        logging.error(f"Failed to check CUDA allocator configuration: {e}")
+        raise RuntimeError(
+            "Unable to verify CUDA allocator configuration.\n"
+            "Please ensure expandable segments are enabled:\n"
+            '    export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"\n'
+            "Set this environment variable before importing PyTorch."
+        )
+
+
 class RDMABuffer:
     def __init__(
         self,
@@ -159,15 +214,9 @@ def __init__(

         TODO: Create TensorBuffer, which will be main user API supporting non-contiguous tensors
         """
-        if isinstance(data, torch.Tensor) and data.device.type != "cpu":
-            # TODO - CUDA support for RDMABuffer exists at the Rust layer, but
-            # runs into issues with MR creation. For now, only support CPU tensors.
-            # Remove this once GPU support is added.
-            raise ValueError(
-                "RDMABuffer currently only supports CPU tensors (got device {})".format(
-                    data.device
-                )
-            )
+        if isinstance(data, torch.Tensor) and data.device.type == "cuda":
+            # Check if CUDA caching allocator is using expandable segments
+            _check_cuda_expandable_segments_enabled()

         assert (
             is_available()
@@ -221,16 +270,6 @@ def read_into(
         Currently only CPU tensors are fully supported. GPU tensors will be temporarily
         copied to CPU, which may impact performance.
         """
-        dst_gpu = None
-        if isinstance(dst, torch.Tensor) and dst.device.type != "cpu":
-            warnings.warn(
-                "note: read_into only supports CPU tensors, so `dst` is being copied to CPU.",
-                RDMAReadTransferWarning,
-                stacklevel=2,
-            )
-            dst_gpu = dst
-            dst = dst.cpu()
-
         dst_addr, dst_size = _get_addr_and_size(dst)

         if self.size() > dst_size:
@@ -251,9 +290,6 @@ async def read_into_nonblocking() -> Optional[int]:
                 client=client,
                 timeout=timeout,
             )
-            # TODO - remove this once GPU support is added.
-            if dst_gpu is not None:
-                dst_gpu.copy_(dst)
             return res

         return Future(coro=read_into_nonblocking())
@@ -285,16 +321,6 @@ def write_from(
         Currently only CPU tensors are fully supported. GPU tensors will be temporarily
         copied to CPU, which may impact performance.
         """
-        src_gpu = None
-        if isinstance(src, torch.Tensor) and src.device.type != "cpu":
-            # TODO - remove this once GPU support is added.
-            warnings.warn(
-                "note: write_from only supports CPU tensors, so we will write to CPU first, then transfer to `src` in place.",
-                RDMAWriteTransferWarning,
-                stacklevel=2,
-            )
-            src_gpu = src  # Save the original GPU tensor reference
-            src = src.cpu()  # Convert to CPU for RDMA operation

         src_addr, src_size = _get_addr_and_size(src)

@@ -315,9 +341,6 @@ async def write_from_nonblocking() -> None:
                 client=client,
                 timeout=timeout,
             )
-            # TODO - remove this once GPU support is added.
-            if src_gpu is not None:
-                src_gpu.copy_(src)
             return res

         return Future(coro=write_from_nonblocking())
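
With the CPU staging paths deleted, read_into and write_from now take CUDA tensors directly. A sketch of the resulting call pattern inside an async context (buf stands for an existing RDMABuffer, and size() returning a byte count follows the comparison against dst_size above; names are illustrative):

    import torch

    async def pull(buf) -> torch.Tensor:
        # The destination stays on the GPU; previously this path emitted
        # RDMAReadTransferWarning and bounced the data through a CPU copy.
        dst = torch.empty(buf.size(), dtype=torch.uint8, device="cuda")
        await buf.read_into(dst)
        return dst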

python/tests/test_rdma.py — 6 additions, 5 deletions

@@ -5,6 +5,10 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe
+import os
+
+# required to enable RDMA support
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import pytest
 import torch
@@ -105,23 +109,20 @@ async def test_proc_mesh_rdma():
     x = await client_gpu.get_buffer.call_one()
     buffer_gpu = x.view(torch.float32).view(10, 10)
     assert torch.sum(buffer_gpu) == 0
-    # copying a tensor across hosts moves it to CPU
-    assert buffer_gpu.device.type == "cpu"

     # Modify server state again
     await server.update.call_one()
     await client_gpu.download.call_one()
     x = await client_gpu.get_buffer.call_one()
     buffer_gpu = x.view(torch.float32).view(10, 10)
     remote_grad = await server.get_grad_buffer.call_one()
-    assert torch.allclose(buffer_gpu.cpu(), remote_grad)
+    assert torch.allclose(buffer_gpu.cpu(), remote_grad.cpu())


 class TrainerActor(Actor):
     def __init__(self):
         super().__init__()
-        # TODO - switch to CUDA once GPU support is added
-        self.trainer = torch.nn.Linear(10, 10).to("cpu")
+        self.trainer = torch.nn.Linear(10, 10).to("cuda")
         self.trainer.weight.data.zero_()

     @endpoint
