Skip to content

Conversation

chivatam
Copy link

@chivatam chivatam commented Sep 7, 2025

Description

added rocshmem dependencies to the dockerfile

@msaroufim

@chivatam chivatam marked this pull request as draft September 7, 2025 14:29
@chivatam chivatam marked this pull request as ready for review September 7, 2025 14:29
@msaroufim
Copy link
Member

Could you share a toy user submission as well using rocshmem. Just wanna get a sense of what things will look like e2e

@msaroufim
Copy link
Member

Also @saienduri to sanity check

@chivatam
Copy link
Author

chivatam commented Sep 7, 2025

Could you share a toy user submission as well using rocshmem. Just wanna get a sense of what things will look like e2e

import os
from typing import Any

from torch.utils.cpp_extension import load_inline


ROCSHMEM_INSTALL_DIR = os.environ.get("ROCSHMEM_INSTALL_DIR", "/opt/rocshmem")
OMPI_INSTALL_DIR = os.environ.get("OMPI_INSTALL_DIR", "/opt/openmpi")


EXT_NAME = "rocshmem_all2all_ext"

CUDA_SRC = r"""
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cstdlib>
#include <vector>

#include <hip/hip_runtime.h>
#include <roc_shmem.hpp>

namespace py = pybind11;

__global__ void all2all_kernel(int* symm, int npes) {
    if (threadIdx.x == 0) {
        int me = roc_shmem_my_pe();

        // Initialize local symmetric buffer
        for (int i = 0; i < npes; ++i) symm[i] = -1;
        roc_shmem_barrier_all();

        // Put my rank into every PE's symmetric buffer at index 'me'
        for (int dst = 0; dst < npes; ++dst) {
            roc_shmem_int_p(symm + me, me, dst);
        }
        roc_shmem_barrier_all();
    }
}

static void hip_check(hipError_t err, const char* where) {
    if (err != hipSuccess) {
        throw std::runtime_error(std::string("HIP error at ") + where + ": " + hipGetErrorString(err));
    }
}

void bind_and_init() {
    // Bind device based on rank
    int dev_count = 0;
    hip_check(hipGetDeviceCount(&dev_count), "hipGetDeviceCount");

    int rank = 0;
    if (const char* s = std::getenv("OMPI_COMM_WORLD_RANK")) {
        rank = std::atoi(s);
    }
    hip_check(hipSetDevice(dev_count == 0 ? 0 : (rank % dev_count)), "hipSetDevice");

    // Initialize rocSHMEM after device selection
    roc_shmem_init();
}

std::vector<int> run_all2all() {
    int me   = roc_shmem_my_pe();
    int npes = roc_shmem_n_pes();

    int* symm = (int*)roc_shmem_malloc(sizeof(int) * npes);
    if (!symm) throw std::runtime_error("roc_shmem_malloc failed");

    // Launch one-thread kernel to do the collective
    all2all_kernel<<<1, 1>>>(symm, npes);
    hip_check(hipDeviceSynchronize(), "hipDeviceSynchronize");

    // Copy local symmetric buffer back to host
    std::vector<int> out(npes, -1);
    hip_check(hipMemcpy(out.data(), symm, sizeof(int) * npes, hipMemcpyDeviceToHost), "hipMemcpy D2H");

    roc_shmem_free(symm);
    return out;
}

void finalize() {
    roc_shmem_finalize();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("bind_and_init", &bind_and_init);
    m.def("run_all2all",  &run_all2all);
    m.def("finalize",     &finalize);
}
"""


def _build_ext():
    return load_inline(
        name=EXT_NAME,
        cuda_sources=[CUDA_SRC],
        functions=["bind_and_init", "run_all2all", "finalize"],
        with_cuda=True,
        extra_cflags=["-std=c++17"],
        extra_cuda_cflags=["-std=c++17"],
        extra_include_paths=[f"{ROCSHMEM_INSTALL_DIR}/include"],
        extra_ldflags=[
            f"-L{ROCSHMEM_INSTALL_DIR}/lib", "-lrocshmem",
            f"-L{OMPI_INSTALL_DIR}/lib", "-lmpi",
            f"-Wl,-rpath,{ROCSHMEM_INSTALL_DIR}/lib:{OMPI_INSTALL_DIR}/lib",
        ],
        verbose=True,
    )


# --- Optional: type-compatible stub for the Python leaderboard pattern ---
def custom_kernel(data: Any):  # input_t -> output_t, toy no-op to fit signature
    return data


def _rank_and_world():
    r = int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))
    w = int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    return r, w


if __name__ == "__main__":
    rank, world = _rank_and_world()
    ext = _build_ext()
    ext.bind_and_init()
    out = ext.run_all2all()
    print(f"Rank {rank}/{world} all2all -> {out}")
    ext.finalize()

Vibe coded this but is gonna look similar to HIP kernels in python

@msaroufim

@saienduri
Copy link
Collaborator

Looks good to me. Starting a test docker build here to check status: https://github.com/gpu-mode/discord-cluster-manager/actions/runs/17545534459.

@chivatam
Copy link
Author

chivatam commented Sep 8, 2025

ooo! looks like there is some issue with UCX. I ll debug it today!

@chivatam
Copy link
Author

chivatam commented Sep 9, 2025

@saienduri I made some changes but not sure if it works, is there a way to test the workflow without approval? I don't have MI300X to test 😅

@msaroufim msaroufim requested a review from saienduri September 9, 2025 18:17
@saienduri
Copy link
Collaborator

saienduri commented Sep 13, 2025

Thanks, trying a build here now: https://github.com/gpu-mode/discord-cluster-manager/actions/runs/17701378282. You can locally try building the docker just to see if the build passes.

@saienduri
Copy link
Collaborator

Cool, the build passed and a sanity test passed here: https://github.com/gpu-mode/discord-cluster-manager/actions/runs/17702258708
I was using this test payload: https://github.com/gpu-mode/discord-cluster-manager/blob/saienduri/fix-payload/scripts/github_test_payload.json
Can you also share a small payload for testing if rocshmem works before we merge this PR?

@chivatam
Copy link
Author

@saienduri added one, lmk if it works!

@saienduri
Copy link
Collaborator

Hmm getting ValueError: Invalid language cpp (https://github.com/gpu-mode/discord-cluster-manager/actions/runs/17714138325)

@msaroufim
Copy link
Member

msaroufim commented Sep 14, 2025

You want the example working with load_inline in PyTorch

@chivatam
Copy link
Author

done but idk if it works 😬

@msaroufim
Copy link
Member

@saienduri can we test the provided payload example on the server directly? If it's fine then we should be good to merge

@saienduri
Copy link
Collaborator

saienduri commented Sep 17, 2025

ok running the payload in github actions yielded the following (https://github.com/gpu-mode/discord-cluster-manager/actions/runs/17790562194):

"stdout": "=== ROCshmem PyTorch Inline Test ===\nROCshmem test failed: module 'torch.utils' has no attribute 'cpp_extension'\n"

I'll try on the server itself, but pretty sure it will be the same error.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants