diff --git a/src/ATen/native/xpu/Copy.cpp b/src/ATen/native/xpu/Copy.cpp index b011baa80..60a379276 100644 --- a/src/ATen/native/xpu/Copy.cpp +++ b/src/ATen/native/xpu/Copy.cpp @@ -71,8 +71,10 @@ void memcpyAsync( Device dst_device = iter.device(0); Device src_device = iter.device(1); if (dst_device == src_device) { + std::cout << "zl_debug: go to same device and specialized kernel" << std::endl; copy_kernel(iter); } else { + std::cout << "zl_debug: go to sycl copy kernel" << std::endl; TORCH_INTERNAL_ASSERT(p2p_enabled == true); auto dst = (char*)iter.data_ptr(0); auto src = (char*)iter.data_ptr(1); diff --git a/src/xccl/IpcExchange.hpp b/src/xccl/IpcExchange.hpp new file mode 100644 index 000000000..e515cd6ce --- /dev/null +++ b/src/xccl/IpcExchange.hpp @@ -0,0 +1,400 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "xccl/ze_symbol.hpp" + +#include + +#include +#include +#include +#include +#include + +struct exchange_contents { + // first 4-byte is file descriptor for drmbuf or gem object + union { + ze_ipc_mem_handle_t ipc_handle; + int fd = -1; + }; + size_t offset = 0; + int pid = -1; +}; + +#define sysCheck(x) \ + if (x == -1) { \ + throw std::system_error(std::make_error_code(std::errc(errno))); \ + } + +// We can't inherit it from cmsghdr because flexible array member +struct exchange_fd { + char obscure[CMSG_LEN(sizeof(int)) - sizeof(int)]; + int fd; + + exchange_fd(int cmsg_level, int cmsg_type, int fd) : fd(fd) { + auto* cmsg = reinterpret_cast(obscure); + cmsg->cmsg_len = sizeof(exchange_fd); + cmsg->cmsg_level = cmsg_level; + cmsg->cmsg_type = cmsg_type; + } + + exchange_fd() : fd(-1) { + memset(obscure, 0, sizeof(obscure)); + }; +}; + +void un_send_fd(int sock, int fd, int rank, size_t offset) { + iovec iov[1]; + msghdr msg; + auto rank_offset = std::make_pair(rank, offset); + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg(SOL_SOCKET, SCM_RIGHTS, fd); + + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + sysCheck(sendmsg(sock, &msg, 0)); +} + +std::tuple un_recv_fd(int sock) { + iovec iov[1]; + msghdr msg; + std::pair rank_offset; + + iov[0].iov_base = &rank_offset; + iov[0].iov_len = sizeof(rank_offset); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_name = nullptr; + msg.msg_namelen = 0; + + exchange_fd cmsg; + msg.msg_control = &cmsg; + msg.msg_controllen = sizeof(exchange_fd); + int n_recv = recvmsg(sock, &msg, 0); + sysCheck(n_recv); + // assert(n_recv == sizeof(int)); + + return std::make_tuple(cmsg.fd, rank_offset.first, rank_offset.second); +} + +int prepare_socket(const char* sockname) { + sockaddr_un un; + memset(&un, 0, sizeof(un)); + un.sun_family = AF_UNIX; + strcpy(un.sun_path, sockname); + + auto sock = socket(AF_UNIX, SOCK_STREAM, 0); + sysCheck(sock); + + int on = 1; + sysCheck(ioctl(sock, FIONBIO, &on)); + + auto size = offsetof(sockaddr_un, sun_path) + strlen(un.sun_path); + sysCheck(bind(sock, (sockaddr*)&un, size)); + + return sock; +} + +int server_listen(const char* sockname) { + unlink(sockname); + auto sock = prepare_socket(sockname); + sysCheck(listen(sock, 10)); + + return sock; +} + +int serv_accept(int listen_sock) { + sockaddr_un un; + + socklen_t len = sizeof(un); + auto accept_sock = accept(listen_sock, (sockaddr*)&un, &len); + sysCheck(accept_sock); + + return accept_sock; +} + +bool wait_for_socket_file(const char* 
path, int max_seconds = 10) { + struct stat buffer; + for (int i = 0; i < max_seconds * 10; ++i) { + if (stat(path, &buffer) == 0) { + return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; +} + +int client_connect(const char* server, const char* client) { + if (!wait_for_socket_file(server, 10)) { + std::cerr << "Error: timeout waiting for server socket file: " << server + << std::endl; + exit(EXIT_FAILURE); + } + auto sock = prepare_socket(client); + sockaddr_un sun; + memset(&sun, 0, sizeof(sun)); + sun.sun_family = AF_UNIX; + strcpy(sun.sun_path, server); + auto len = offsetof(sockaddr_un, sun_path) + strlen(server); + const int max_retries = 50; + int retry = 0; + int ret = -1; + while (retry < max_retries) { + ret = connect(sock, (sockaddr*)&sun, len); + if (ret == 0) + break; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + retry++; + } + if (ret != 0) { + perror("connect failed"); + exit(EXIT_FAILURE); + } + + // sysCheck(connect(sock, (sockaddr*)&sun, len)); + return sock; +} + +void un_allgather( + exchange_contents* send_buf, + exchange_contents recv_buf[], + int rank, + int world) { + const char* servername_prefix = "/tmp/open-peer-ipc-mem-server-rank_"; + const char* clientname_prefix = "/tmp/open-peer-ipc-mem-client-rank_"; + char server_name[64]; + /* get username to make server_name unique */ + auto uid = getuid(); + auto pwd = getpwuid(uid); + snprintf( + server_name, + sizeof(server_name), + "%s%d_%s", + servername_prefix, + rank, + pwd->pw_name); + unlink(server_name); + auto s_listen = server_listen(server_name); + + pollfd fdarray[world]; + int recv_socks[world - 1]; + + for (auto& pollfd : fdarray) + pollfd.fd = -1; + std::fill(recv_socks, recv_socks + world - 1, -1); + + auto fd_guard = [&]() { + for (int i = 0, j = 0; i < world; ++i) { + if (i != rank && recv_socks[j] != -1) + sysCheck(close(recv_socks[j++])); + if (fdarray[i].fd != -1) + sysCheck(close(fdarray[i].fd)); + } + }; + + struct guard__ { + using F = decltype(fd_guard); + F f; + guard__(const F& f) : f(f) {} + ~guard__() { + f(); + } + } free_fd(fd_guard); + + // connect to all ranks + for (int i = 0; i < world; ++i) { + if (rank == i) { + fdarray[i].fd = s_listen; + fdarray[i].events = POLLIN; + fdarray[i].revents = 0; + } else { + char peer_name[64]; + char client_name[64]; + + snprintf( + client_name, + sizeof(client_name), + "%s%d-%d_%s", + clientname_prefix, + rank, + i, + pwd->pw_name); + unlink(client_name); + + snprintf( + peer_name, + sizeof(peer_name), + "%s%d_%s", + servername_prefix, + i, + pwd->pw_name); + fdarray[i].fd = client_connect(peer_name, client_name); + fdarray[i].events = POLLOUT; + fdarray[i].revents = 0; + } + } + + // std::future> future_fds[world -1]; + int slot = 0; + uint32_t send_progress = 1 << rank; + + while (slot < world - 1 || send_progress != (1 << world) - 1) { + sysCheck(ppoll(fdarray, world, nullptr, nullptr)); + + for (int i = 0; i < world; ++i) { + if (i == rank && (fdarray[i].revents & POLLIN)) { + // auto accept_sock = serv_accept(fdarray[i].fd); + // future_fds[slot ++] = std::async( + // std::launch::async, [=]() { + // struct sock_guard{ + // int sock; + // sock_guard(int sock) : sock(sock) {} + // ~guard_sock() {sysCheck(close(sock));} + // } release(accept_sock); + // auto ret = un_recv_fd(accept_sock); + // return ret;}); + recv_socks[slot++] = serv_accept(fdarray[i].fd); + } else if ( + (send_progress & (1 << i)) == 0 && fdarray[i].revents & POLLOUT) { + un_send_fd(fdarray[i].fd, send_buf->fd, 
rank, send_buf->offset); + send_progress |= 1 << i; + } + } + } + + for (int i = 0; i < world - 1; ++i) { + // future_fds[i].wait(); + // auto [fd, peer, offset] = future_fds[i].get(); + auto [fd, peer, offset] = un_recv_fd(recv_socks[i]); + recv_buf[peer].fd = fd; + recv_buf[peer].offset = offset; + } + + recv_buf[rank] = *send_buf; +} + +class IpcChannel { + public: + IpcChannel() { + initialized = false; + } + void init(sycl::queue& queue, uint32_t rank_in, uint32_t world_in) { + if (initialized) + return; + + if (!load_level_zero_library()) { + throw std::runtime_error("Failed to initialize Level Zero"); + } + + zeCheck_dynamic(zeInit_dynamic(0)); + int tmp_rank, tmp_world; + + tmp_world = world_in; + tmp_rank = rank_in; + + rank = tmp_rank; + world = tmp_world; + initialized = true; + } + void release(sycl::queue& queue) { + if (!initialized) + return; + try { + auto l0_ctx = sycl::get_native( + queue.get_context()); + for (int i = 0; i < world; i++) { + if (i != rank) { + zeCheck_dynamic(zeMemCloseIpcHandle_dynamic( + l0_ctx, (char*)buffers[i] - offsets[i])); + } + } + } catch (const std::exception& e) { + std::cerr << "Warning: Level Zero cleanup failed: " << e.what() + << std::endl; + } + sycl::free(buffers[rank], queue); + initialized = false; + } + + // buffer_size as element size + void exchange_peer_ipc_mem( + sycl::queue& queue, + void* ptr, + uint32_t rank_in, + uint32_t world_in) { + if (!initialized) + init(queue, rank_in, world_in); + if (!load_level_zero_library()) { + throw std::runtime_error("Level Zero not available"); + } + + // Step 1: Get base address of the pointer + sycl::context ctx = queue.get_context(); + auto l0_ctx = sycl::get_native(ctx); + + void* base_addr; + size_t base_size; + zeCheck_dynamic( + zeMemGetAddressRange_dynamic(l0_ctx, ptr, &base_addr, &base_size)); + + // Step 2: Get IPC mem handle from base address + alignas(64) exchange_contents send_buf; + alignas(64) exchange_contents recv_buf[world]; + + // fill in the exchange info + zeCheck_dynamic( + zeMemGetIpcHandle_dynamic(l0_ctx, base_addr, &send_buf.ipc_handle)); + send_buf.offset = (char*)ptr - (char*)base_addr; + + send_buf.pid = getpid(); + + // Step 3: Exchange the handles and offsets + memset(recv_buf, 0, sizeof(recv_buf)); + // Overkill if we don't really needs all peer's handles + un_allgather(&send_buf, recv_buf, rank, world); + for (uint32_t i = 0; i < world; i++) { + // Step 4: Prepare pid file descriptor of next process + auto* peer = recv_buf + i; + // Step 6: Open IPC handle of remote peer + auto l0_device = sycl::get_native( + queue.get_device()); + void* peer_base; + + zeCheck_dynamic(zeMemOpenIpcHandle_dynamic( + l0_ctx, + l0_device, + peer->ipc_handle, + ZE_IPC_MEMORY_FLAG_BIAS_CACHED, + &peer_base)); + + buffers[i] = (char*)peer_base + peer->offset; + offsets[i] = peer->offset; + ipc_handle[i] = send_buf.ipc_handle; + } + } + + bool initialized; + static constexpr uint32_t max_rank = 16; + void* buffers[max_rank]; + void* sync_buffer[max_rank]; + size_t offsets[max_rank]; + ze_ipc_mem_handle_t ipc_handle[max_rank]; + int rank, world; + int size_per_buffer; + int data_size_per_buffer; + int buffer_index; +}; diff --git a/src/xccl/XPUSymmetricMemory.cpp b/src/xccl/XPUSymmetricMemory.cpp new file mode 100644 index 000000000..d49d12612 --- /dev/null +++ b/src/xccl/XPUSymmetricMemory.cpp @@ -0,0 +1,460 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace c10d { +namespace symmetric_memory { + 
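+// Store-based exchange helper shared by rendezvous() below: it all-gathers
+// small POD requests (RendezvousRequest) across ranks and provides a simple
+// store barrier, keyed under the "XPUSymmetricMemory" prefix.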
+static StoreExchange storeExchange = StoreExchange("XPUSymmetricMemory"); + +AllocationRef::AllocationRef( + void* ptr, + HandleType handle, + size_t block_size, + int device_idx, + bool local_allocation) + : ptr(ptr), + handle(handle), + block_size(block_size), + device_idx(device_idx), + local_allocation(local_allocation){} + +AllocationRef::~AllocationRef() { + if (is_finalizing()) { + return; + } + // Currently, we cannot free virtual memory exchanged from other device. + if (!local_allocation) { + return; + } + c10::Device local_device(c10::DeviceType::XPU, device_idx); + c10::DeviceGuard guard(local_device); + c10::xpu::syncStreamsOnDevice(); + auto stream = at::xpu::getCurrentXPUStream(); + sycl::free(ptr, stream); +} + +XPUSymmetricMemory::XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size) + : alloc_refs_(std::move(alloc_refs)), + buffers_(std::move(buffers)), + signal_pads_(std::move(signal_pads)), + mc_handle_(mc_handle), + mc_addr_(mc_addr), + buffer_size_(buffer_size), + local_device_idx_(local_device_idx), + rank_(rank), + world_size_(world_size) { + const size_t arr_size = sizeof(void*) * world_size_; + buffers_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + signal_pads_dev_ = reinterpret_cast( + c10::xpu::XPUCachingAllocator::raw_alloc(arr_size)); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx); + c10::DeviceGuard guard(local_device); + + at::xpu::getCurrentXPUStream().queue().memcpy( + buffers_dev_, buffers_.data(), arr_size); + at::xpu::getCurrentXPUStream().queue().memcpy( + signal_pads_dev_, signal_pads_.data(), arr_size); +} + +std::vector XPUSymmetricMemory::get_buffer_ptrs() { + return buffers_; +} + +std::vector XPUSymmetricMemory::get_signal_pad_ptrs() { + return signal_pads_; +} + +void** XPUSymmetricMemory::get_buffer_ptrs_dev() { + return buffers_dev_; +} + +void** XPUSymmetricMemory::get_signal_pad_ptrs_dev() { + return signal_pads_dev_; +} + +size_t XPUSymmetricMemory::get_buffer_size() { + return buffer_size_; +} + +size_t XPUSymmetricMemory::get_signal_pad_size() { + return signal_pad_size; +} + +bool XPUSymmetricMemory::has_multicast_support() { + return false; +} + +void* XPUSymmetricMemory::get_multicast_ptr() { + return nullptr; +} + +at::Tensor XPUSymmetricMemory::get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset) { + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); + const auto element_size = c10::elementSize(dtype); + const auto req_size = (numel + storage_offset) * element_size; + TORCH_CHECK( + req_size <= buffer_size_, + "XPUSymmetricMemory::get_buffer: the requested size (", + req_size, + " bytes) exceeds the allocated size (", + buffer_size_, + " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; + // check the device of this device buffer + auto ptr_to_device_id = c10::xpu::get_device_idx_from_pointer(data_ptr); + auto device = c10::Device(c10::DeviceType::XPU, ptr_to_device_id); + auto options = at::TensorOptions().dtype(dtype).device(device); + + return at::for_blob(data_ptr, sizes) + .options(options) + .target_device(device) + .make_tensor(); +} + +void check_channel(int channel, int world_size) { + TORCH_CHECK( + channel >= 0, + "channel for barrier(), put_signal() and 
wait_signal() ", + "must be greater than 0 (got ", + channel, + ")"); + const size_t num_channels = signal_pad_size / sizeof(uint32_t) * world_size; + TORCH_CHECK( + static_cast(channel) < num_channels, + "The maximum supported channel for barrier(), put_signal() and wait_signal() is ", + num_channels - 1, + " (got ", + channel, + ")"); +} + +void XPUSymmetricMemory::barrier(int channel, size_t timeout_ms) { + check_channel(channel, world_size_); + + // Currently, we leverage oneCCL for barrier. Later, we may move to SYCL + // implementation. + auto group = c10d::resolve_process_group(group_name_); + if (group == nullptr) { + TORCH_WARN( + "Process group '", + group_name_, + "' not found, please init process group first before calling SymmetricMemory"); + throw std::runtime_error("Process group not found"); + } + auto* xcclPg = dynamic_cast( + group->getBackend(c10::DeviceType::XPU).get()); + + c10::Device local_device(c10::DeviceType::XPU, local_device_idx_); + c10::DeviceGuard guard(local_device); + + static thread_local at::Tensor barrier_tensor; + if (!barrier_tensor.defined() || barrier_tensor.device() != local_device) { + barrier_tensor = at::zeros( + {1}, at::TensorOptions().device(local_device).dtype(at::kFloat)); + } else { + barrier_tensor.zero_(); + } + + c10d::AllreduceOptions arOpts; + arOpts.asyncOp = false; + auto work = + xcclPg->allreduce_impl(barrier_tensor, "xccl:symm_mem_barrier", arOpts); + + if (work) { + bool success = work->wait(std::chrono::milliseconds(timeout_ms)); + TORCH_CHECK( + success, + "Barrier timeout after ", + timeout_ms, + " ms for group '", + group_name_, + "'"); + } +} + +void XPUSymmetricMemory::put_signal( + int dst_rank, + int channel, + size_t timeout_ms) { + LOG(ERROR) << "XPUSymmetricMemory::put_signal not supported"; +} + +void XPUSymmetricMemory::wait_signal( + int src_rank, + int channel, + size_t timeout_ms) { + LOG(ERROR) << "XPUSymmetricMemory::wait_signal not supported"; +} + +int XPUSymmetricMemory::get_rank() { + return rank_; +} + +int XPUSymmetricMemory::get_world_size() { + return world_size_; +} + +c10::Device XPUSymmetricMemory::get_device() { + return c10::Device(c10::DeviceType::XPU, local_device_idx_); +} + +Block::Block( + c10::intrusive_ptr alloc_ref, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::optional& group_name) + : alloc_ref(std::move(alloc_ref)), + device_idx(device_idx), + block_size(block_size), + buffer_size(buffer_size), + signal_pad_offset(signal_pad_offset), + default_group_name(std::move(group_name)) {} + +void* XPUSymmetricMemoryAllocator::alloc( + size_t size, + int device_idx, + const std::optional& group_name) { + size_t signal_pad_offset = at::round_up(size, 16UL); + size_t block_size = signal_pad_offset + signal_pad_size; + + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + void* ptr = sycl::malloc_device(block_size, current_queue); + current_queue.memset(ptr, 0, block_size); + auto alloc_ref = + c10::make_intrusive(ptr, ptr, block_size, device_idx, true); + auto block = c10::make_intrusive( + std::move(alloc_ref), + device_idx, + block_size, + size, + signal_pad_offset, + group_name); + + { + std::unique_lock lock(mutex_); + ptr_to_block_.emplace(ptr, std::move(block)); + } + return ptr; +} + +void XPUSymmetricMemoryAllocator::free(void* ptr) { + std::unique_lock lock(mutex_); + ptr_to_block_.erase(ptr); +} + +size_t XPUSymmetricMemoryAllocator::get_alloc_size(void* ptr) { + auto block = find_block(ptr); + TORCH_CHECK( + block 
!= nullptr, + "XPUSymmetricMemoryAllocator::get_alloc_size: input must be allocated ", + "via XPUSymmetricMemoryAllocator::alloc"); + return block->buffer_size; +} + +struct RendezvousRequest { + int device_idx; + int pid; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + bool has_multicast_support; +}; + +void validate_rendezvous_requests( + const std::vector& reqs, + int world_size) { + TORCH_CHECK(reqs.size() == (size_t)world_size); + + std::unordered_set device_indices; + device_indices.reserve(world_size); + for (auto req : reqs) { + device_indices.insert(req.device_idx); + } + + for (int r = 1; r < world_size; ++r) { + TORCH_CHECK(reqs[r].block_size == reqs[0].block_size); + TORCH_CHECK(reqs[r].buffer_size == reqs[0].buffer_size); + TORCH_CHECK(reqs[r].signal_pad_offset == reqs[0].signal_pad_offset); + } +} + +c10::intrusive_ptr XPUSymmetricMemoryAllocator::rendezvous( + void* ptr, + const std::optional& group_name) { + auto block = find_block(ptr); + if (block == nullptr) { + return nullptr; + } + + // The group_name passed to rendezvous() takes precedence over + // the default group_name specified during allocation. + std::string group_name_; + // Treat empty string and std::nullopt the same as empty string seems to be + // implicitly used that way + if (group_name.has_value() && group_name != "") { + group_name_ = *group_name; + } else { + if (!block->default_group_name.has_value()) { + TORCH_CHECK( + false, + "XPUSymmetricMemory::rendezvous: `group_name` is neither " + "specified during allocation nor passed to rendezvous()."); + } + group_name_ = *block->default_group_name; + } + + auto it = block->symm_mems.find(group_name_); + if (it != block->symm_mems.end()) { + return it->second; + } + + c10::Device local_device(c10::DeviceType::XPU, block->device_idx); + c10::DeviceGuard guard(local_device); + + // IpcChannel is used to do inter-process communication + IpcChannel ipc_channel; + auto group_info = get_group_info(group_name_); + auto store = group_info.store; + int rank = group_info.rank; + int world_size = group_info.world_size; + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + + auto local_req = RendezvousRequest{ + .device_idx = block->device_idx, + .pid = getpid(), + .block_size = block->block_size, + .buffer_size = block->buffer_size, + .signal_pad_offset = block->signal_pad_offset, + .has_multicast_support = false}; + auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); + validate_rendezvous_requests(reqs, world_size); + + std::vector pids(world_size); + for (int r = 0; r < world_size; ++r) { + pids[r] = reqs[r].pid; + } + + // do IPC exchange for all peer ranks + ipc_channel.exchange_peer_ipc_mem(current_queue, ptr, rank, world_size); + + // no physical memory handle, so handles and buffers are both for virtual + // address + std::vector handles(world_size); + std::vector buffers(world_size, nullptr); + std::vector signal_pads(world_size, nullptr); + + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + handles[r] = block->alloc_ref->handle; + buffers[r] = ptr; + signal_pads[r] = (void*)((uintptr_t)ptr + block->signal_pad_offset); + continue; + } else { + buffers[r] = ipc_channel.buffers[r]; + handles[r] = ipc_channel.buffers[r]; + signal_pads[r] = + (void*)((uintptr_t)buffers[r] + block->signal_pad_offset); + } + } + storeExchange.barrier(store, rank, world_size); + + HandleType mc_handle{}; + void* mc_addr = nullptr; + + std::vector> alloc_refs; + for (int r = 0; r < world_size; ++r) { + if (r 
== rank) { + alloc_refs.emplace_back(block->alloc_ref); + continue; + } + alloc_refs.push_back(c10::make_intrusive( + buffers[r], handles[r], block->block_size, block->device_idx, false)); + } + + auto symm_mem = c10::make_intrusive( + std::move(alloc_refs), + std::move(buffers), + std::move(signal_pads), + mc_handle, + mc_addr, + block->buffer_size, + block->device_idx, + group_info.rank, + group_info.world_size); + symm_mem->set_group_name(group_name_); + block->symm_mems[group_name_] = symm_mem; + return symm_mem; +} + +bool XPUSymmetricMemoryAllocator::has_multicast_support(int device_idx) { + return false; +} + +c10::DeviceType XPUSymmetricMemoryAllocator::supported_device_type() { + return c10::DeviceType::XPU; +} + +std::string XPUSymmetricMemoryAllocator::name() { + return "XPU"; +} + +c10::intrusive_ptr XPUSymmetricMemoryAllocator::find_block(void* ptr) { + std::shared_lock lock(mutex_); + auto it = ptr_to_block_.find(ptr); + if (it == ptr_to_block_.end()) { + return nullptr; + } + return it->second; +} + +struct RegisterXPUSymmetricMemoryAllocator { + RegisterXPUSymmetricMemoryAllocator() { + auto allocator = c10::make_intrusive(); + // Query backend used for XPU + if (getSymmMemBackendXPU() == "XPU") { + // Direct set (static registration) + register_allocator(c10::DeviceType::XPU, allocator); + } else { + // Register availability in case `set_backend` is called dynamically + register_availability("XPU", allocator); + } + } +}; +static RegisterXPUSymmetricMemoryAllocator register_allocator_; + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/XPUSymmetricMemory.hpp b/src/xccl/XPUSymmetricMemory.hpp new file mode 100644 index 000000000..2daac1114 --- /dev/null +++ b/src/xccl/XPUSymmetricMemory.hpp @@ -0,0 +1,130 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace c10d::symmetric_memory { + +// Resource wrapper that owns a (vaddr, allocation handle) pair. Upon +// destruction, it unmaps the vaddr and releases the allocation handle. 
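+// In this XPU backend, only locally allocated buffers (local_allocation ==
+// true) are freed on destruction; buffers imported from peer ranks over IPC
+// are skipped here.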
+struct AllocationRef : public c10::intrusive_ptr_target { + void* ptr; + HandleType handle; + size_t block_size; + int device_idx; + bool local_allocation; + + AllocationRef( + void* ptr, + HandleType handle, + size_t block_size, + int device_idx, + bool local_allocation); + + ~AllocationRef(); +}; + +class XPUSymmetricMemory : public SymmetricMemory { + public: + XPUSymmetricMemory( + std::vector> alloc_refs, + std::vector buffers, + std::vector signal_pads, + HandleType mc_handle, + void* mc_addr, + size_t buffer_size, + int local_device_idx, + int rank, + int world_size); + + ~XPUSymmetricMemory() override{}; + + std::vector get_buffer_ptrs() override; + std::vector get_signal_pad_ptrs() override; + void** get_buffer_ptrs_dev() override; + void** get_signal_pad_ptrs_dev() override; + size_t get_buffer_size() override; + size_t get_signal_pad_size() override; + + bool has_multicast_support() override; + void* get_multicast_ptr() override; + + at::Tensor get_buffer( + int rank, + c10::IntArrayRef sizes, + c10::ScalarType dtype, + int64_t storage_offset); + + void barrier(int channel, size_t timeout_ms) override; + void put_signal(int dst_rank, int channel, size_t timeout_ms) override; + void wait_signal(int src_rank, int channel, size_t timeout_ms) override; + + int get_rank() override; + int get_world_size() override; + c10::Device get_device() override; + + void set_group_name(const std::string& group_name) { + group_name_ = group_name; + } + + private: + std::vector> alloc_refs_; + std::vector buffers_; + std::vector signal_pads_; + HandleType mc_handle_; + void* mc_addr_; + size_t buffer_size_; + int local_device_idx_; + int rank_; + int world_size_; + void** buffers_dev_; + void** signal_pads_dev_; + std::string group_name_; +}; + +struct Block : public c10::intrusive_ptr_target { + c10::intrusive_ptr alloc_ref; + int device_idx; + size_t block_size; + size_t buffer_size; + size_t signal_pad_offset; + std::optional default_group_name; + std::map> symm_mems; + + Block( + c10::intrusive_ptr alloc_ref, + int device_idx, + size_t block_size, + size_t buffer_size, + size_t signal_pad_offset, + const std::optional& group_name); +}; + +class XPUSymmetricMemoryAllocator : public SymmetricMemoryAllocator { + public: + void* alloc( + size_t size, + int device_idx, + const std::optional& group_name) override; + + void free(void* ptr) override; + size_t get_alloc_size(void* ptr) override; + c10::intrusive_ptr rendezvous( + void* ptr, + const std::optional& group_name) override; + bool has_multicast_support(int device_idx) override; + // void exchange_peer_ipc_mem(sycl::queue& queue, void* ptr); + c10::DeviceType supported_device_type() override; + std::string name() override; + + private: + c10::intrusive_ptr find_block(void* ptr); + + std::shared_mutex mutex_; + std::unordered_map> ptr_to_block_; +}; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryTypes.hpp b/src/xccl/XPUSymmetricMemoryTypes.hpp new file mode 100644 index 000000000..4cab3b81f --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryTypes.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace c10d::symmetric_memory { + +constexpr size_t signal_pad_size = 2048; +using HandleType = void*; + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.cpp b/src/xccl/XPUSymmetricMemoryUtils.cpp new file mode 100644 index 000000000..7130fe7b6 --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include + +#include + +#include +#include 
+#include +#include +#include + +namespace c10d::symmetric_memory { + +std::string getSymmMemBackendXPU() { + static auto val = c10::utils::get_env("TORCH_SYMMMEM"); + if (val.has_value()) { + TORCH_CHECK( + val.value() == "XPU", + "TORCH_SYMMMEM environment variable must be 'XPU'."); + return val.value(); + } + return "XPU"; +} + +bool device_has_multicast_support(int device_idx) { + return false; +} + +bool allow_overlapping_devices() { + return false; +} + +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx) { + sycl::queue current_queue = at::xpu::getCurrentXPUStream().queue(); + sycl::context sycl_ctx = current_queue.get_context(); + ze_context_handle_t ze_context = + sycl::get_native(sycl_ctx); + // 1. Reserve virtual address space + void* virtual_ptr = nullptr; + ze_result_t status = zeVirtualMemReserve( + ze_context, // context + nullptr, // let L0 pick virtual address + size, // size + &virtual_ptr // out: reserved address + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemReserve failed"); + + // 2. Map physical memory to virtual address + status = zeVirtualMemMap( + ze_context, + virtual_ptr, // virtual memory to map to + size, + handle, // physical memory handle + 0, // flags + ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE // ze_memory_access_attribute_t + ); + TORCH_CHECK(status == ZE_RESULT_SUCCESS, "zeVirtualMemMap failed"); + + // 3. Set access attributes + ze_memory_access_attribute_t access = ZE_MEMORY_ACCESS_ATTRIBUTE_READWRITE; + status = + zeVirtualMemSetAccessAttribute(ze_context, virtual_ptr, size, access); + TORCH_CHECK( + status == ZE_RESULT_SUCCESS, "zeVirtualMemSetAccessAttribute failed"); + + // 4. Return pointer + *ptr = virtual_ptr; +} + +} // namespace c10d::symmetric_memory diff --git a/src/xccl/XPUSymmetricMemoryUtils.hpp b/src/xccl/XPUSymmetricMemoryUtils.hpp new file mode 100644 index 000000000..69189f45c --- /dev/null +++ b/src/xccl/XPUSymmetricMemoryUtils.hpp @@ -0,0 +1,89 @@ +#pragma once +#include +#include +#include + +namespace c10d { +namespace symmetric_memory { + +std::string getSymmMemBackendXPU(); + +bool device_has_multicast_support(int device_idx); + +bool allow_overlapping_devices(); + +// A set of store-based exchange methods with a preset prefix typically type of +// the SymmetricMemory. Most used as static instances at respective +// SymmetricMemory implementation files. +class StoreExchange { + public: + StoreExchange(const std::string& store_prefix) + : store_prefix_(store_prefix) {} + + // Put template function in header file so that compiler can easily access it. 
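+  // all_gather publishes one trivially copyable value per rank under a
+  // sequence-numbered store key, then reads back every peer's value in rank
+  // order.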
+ template + std::vector all_gather( + const c10::intrusive_ptr& store, + int rank, + int world_size, + T val) { + static_assert(std::is_trivially_copyable_v); + + std::vector peer_keys; + peer_keys.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + std::ostringstream oss; + oss << store_prefix_ << "/" << seq_id_ << "/" << r; + peer_keys.push_back(oss.str()); + } + ++seq_id_; + + { + std::vector payload( + reinterpret_cast(&val), + reinterpret_cast(&val) + sizeof(T)); + store->set(peer_keys[rank], payload); + } + + std::vector peer_vals; + peer_vals.reserve(world_size); + for (int r = 0; r < world_size; ++r) { + if (r == rank) { + peer_vals.push_back(val); + continue; + } + store->wait({peer_keys[r]}); + auto payload = store->get(peer_keys[r]); + TORCH_CHECK(payload.size() == sizeof(T)); + T peer_val{}; + std::memcpy(&peer_val, payload.data(), sizeof(T)); + peer_vals.push_back(peer_val); + } + return peer_vals; + } + + void barrier( + const c10::intrusive_ptr& store, + int rank, + int world_size) { + // TODO: implement an efficient one? + all_gather(store, rank, world_size, 0); + } + + private: + const std::string store_prefix_; + size_t seq_id_ = 0; +}; + +// Returns a pointer of virtual address that is mapped to the physical memory +// held by the handle. +// todo: will follow such physical memory handle map with virtual address, +// when L0 provides physical handle exchange API and we have multicast support. +void map_block( + void** ptr, + ze_physical_mem_handle_t handle, + size_t size, + int device_idx); + +} // namespace symmetric_memory +} // namespace c10d diff --git a/src/xccl/ze_symbol.hpp b/src/xccl/ze_symbol.hpp new file mode 100644 index 000000000..20af66681 --- /dev/null +++ b/src/xccl/ze_symbol.hpp @@ -0,0 +1,254 @@ +#pragma once + +#include +#include +#include +#include + +#define zeVirtualMemMap zeVirtualMemMap_original +#define zeVirtualMemReserve zeVirtualMemReserve_original +#define zeVirtualMemSetAccessAttribute zeVirtualMemSetAccessAttribute_original + +#include + +#undef zeVirtualMemMap +#undef zeVirtualMemReserve +#undef zeVirtualMemSetAccessAttribute + +typedef ze_result_t (*zeInit_t)(ze_init_flags_t flags); +typedef ze_result_t (*zeMemGetAddressRange_t)( + ze_context_handle_t hContext, + const void* ptr, + void** pBase, + size_t* pSize); +typedef ze_result_t (*zeMemGetIpcHandle_t)( + ze_context_handle_t hContext, + const void* ptr, + ze_ipc_mem_handle_t* pIpcHandle); +typedef ze_result_t (*zeMemOpenIpcHandle_t)( + ze_context_handle_t hContext, + ze_device_handle_t hDevice, + ze_ipc_mem_handle_t handle, + ze_ipc_memory_flags_t flags, + void** pptr); +typedef ze_result_t ( + *zeMemCloseIpcHandle_t)(ze_context_handle_t hContext, const void* ptr); +typedef ze_result_t (*zeVirtualMemMap_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access); +typedef ze_result_t (*zeVirtualMemReserve_t)( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + void** pptr); +typedef ze_result_t (*zeVirtualMemSetAccessAttribute_t)( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access); + +bool load_level_zero_library(); +void unload_level_zero_library(); + +#define zeCheck_dynamic(x) \ + do { \ + if (!load_level_zero_library()) { \ + throw std::runtime_error("Level Zero library not available"); \ + } \ + ze_result_t result = (x); \ + if (result != ZE_RESULT_SUCCESS) { \ + auto e = 
zeException(result); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } \ + } while (0) + +#define zeInit_dynamic(flags) zeInit_ptr(flags) +#define zeMemGetAddressRange_dynamic(ctx, ptr, base, size) \ + zeMemGetAddressRange_ptr(ctx, ptr, base, size) +#define zeMemGetIpcHandle_dynamic(ctx, ptr, handle) \ + zeMemGetIpcHandle_ptr(ctx, ptr, handle) +#define zeMemOpenIpcHandle_dynamic(ctx, dev, handle, flags, ptr) \ + zeMemOpenIpcHandle_ptr(ctx, dev, handle, flags, ptr) +#define zeMemCloseIpcHandle_dynamic(ctx, ptr) zeMemCloseIpcHandle_ptr(ctx, ptr) +#define zeVirtualMemMap_dynamic(ctx, ptr, size, phys_mem, offset, access) \ + zeVirtualMemMap_ptr(ctx, ptr, size, phys_mem, offset, access) +#define zeVirtualMemReserve_dynamic(ctx, start, size, ptr) \ + zeVirtualMemReserve_ptr(ctx, start, size, ptr) +#define zeVirtualMemSetAccessAttribute_dynamic(ctx, ptr, size, access) \ + zeVirtualMemSetAccessAttribute_ptr(ctx, ptr, size, access) + +// Exception handling class +class zeException : std::exception { + const char* zeResultToString(ze_result_t status) const { + static const std::unordered_map zeResultToStringMap{ + {ZE_RESULT_SUCCESS, "[Core] success"}, + {ZE_RESULT_NOT_READY, "[Core] synchronization primitive not signaled"}, + {ZE_RESULT_ERROR_UNINITIALIZED, + "[Validation] driver is not initialized"}, + {ZE_RESULT_ERROR_INVALID_NULL_POINTER, + "[Validation] pointer argument may not be nullptr"}, + {ZE_RESULT_ERROR_INVALID_NULL_HANDLE, + "[Validation] handle argument is not valid"}, + {ZE_RESULT_ERROR_INVALID_ENUMERATION, + "[Validation] enumerator argument is not valid"}, + {ZE_RESULT_ERROR_INVALID_SIZE, "[Validation] size argument is invalid"}, + {ZE_RESULT_ERROR_UNSUPPORTED_SIZE, + "[Validation] size argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_ALIGNMENT, + "[Validation] alignment argument is not supported by the device"}, + {ZE_RESULT_ERROR_UNSUPPORTED_FEATURE, + "[Validation] generic error code for unsupported features"}, + {ZE_RESULT_ERROR_INVALID_NATIVE_BINARY, + "[Validation] native binary is not supported by the device"}, + {ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY, + "[Core] insufficient host memory to satisfy call"}, + {ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY, + "[Core] insufficient device memory to satisfy call"}, + {ZE_RESULT_ERROR_DEVICE_LOST, + "[Core] device hung, reset, was removed, or driver update occurred"}, + {ZE_RESULT_ERROR_MODULE_BUILD_FAILURE, + "[Core] error occurred when building module, see build log for details"}, + {ZE_RESULT_ERROR_HANDLE_OBJECT_IN_USE, + "[Validation] object pointed to by handle still in-use by device"}, + }; + auto it = zeResultToStringMap.find(status); + if (it != zeResultToStringMap.end()) + return it->second; + else + return "Unknown Reason"; + } + + public: + zeException(ze_result_t ret) : result_(ret) {} + + ze_result_t result_; + + const char* what() const noexcept override { + return zeResultToString(result_); + } +}; + +#define zeCheck(x) \ + if (x != ZE_RESULT_SUCCESS) { \ + auto e = zeException(x); \ + std::cout << "Throw " << e.what() << std::endl; \ + throw e; \ + } + +static zeInit_t zeInit_ptr = nullptr; +static zeMemGetAddressRange_t zeMemGetAddressRange_ptr = nullptr; +static zeMemGetIpcHandle_t zeMemGetIpcHandle_ptr = nullptr; +static zeMemOpenIpcHandle_t zeMemOpenIpcHandle_ptr = nullptr; +static zeMemCloseIpcHandle_t zeMemCloseIpcHandle_ptr = nullptr; +static zeVirtualMemMap_t zeVirtualMemMap_ptr = nullptr; +static zeVirtualMemReserve_t zeVirtualMemReserve_ptr = nullptr; +static 
zeVirtualMemSetAccessAttribute_t zeVirtualMemSetAccessAttribute_ptr = + nullptr; + +static void* ze_handle = nullptr; + +inline bool load_level_zero_library() { + if (ze_handle != nullptr) { + return true; + } + const char* lib_names[] = {"libze_loader.so"}; + + for (const char* lib_name : lib_names) { + ze_handle = dlopen(lib_name, RTLD_LAZY); + if (ze_handle != nullptr) { + break; + } + } + + if (ze_handle == nullptr) { + std::cerr << "Failed to load Level Zero library: " << dlerror() + << std::endl; + return false; + } + + zeInit_ptr = (zeInit_t)dlsym(ze_handle, "zeInit"); + zeMemGetAddressRange_ptr = + (zeMemGetAddressRange_t)dlsym(ze_handle, "zeMemGetAddressRange"); + zeMemGetIpcHandle_ptr = + (zeMemGetIpcHandle_t)dlsym(ze_handle, "zeMemGetIpcHandle"); + zeMemOpenIpcHandle_ptr = + (zeMemOpenIpcHandle_t)dlsym(ze_handle, "zeMemOpenIpcHandle"); + zeMemCloseIpcHandle_ptr = + (zeMemCloseIpcHandle_t)dlsym(ze_handle, "zeMemCloseIpcHandle"); + zeVirtualMemMap_ptr = (zeVirtualMemMap_t)dlsym(ze_handle, "zeVirtualMemMap"); + zeVirtualMemReserve_ptr = + (zeVirtualMemReserve_t)dlsym(ze_handle, "zeVirtualMemReserve"); + zeVirtualMemSetAccessAttribute_ptr = (zeVirtualMemSetAccessAttribute_t)dlsym( + ze_handle, "zeVirtualMemSetAccessAttribute"); + + if (!zeInit_ptr || !zeMemGetAddressRange_ptr || !zeMemGetIpcHandle_ptr || + !zeMemOpenIpcHandle_ptr || !zeMemCloseIpcHandle_ptr || + !zeVirtualMemMap_ptr || !zeVirtualMemReserve_ptr || + !zeVirtualMemSetAccessAttribute_ptr) { + std::cerr << "Failed to load Level Zero API functions" << std::endl; + dlclose(ze_handle); + ze_handle = nullptr; + return false; + } + + return true; +} + +inline void unload_level_zero_library() { + if (ze_handle != nullptr) { + dlclose(ze_handle); + ze_handle = nullptr; + zeInit_ptr = nullptr; + zeMemGetAddressRange_ptr = nullptr; + zeMemGetIpcHandle_ptr = nullptr; + zeMemOpenIpcHandle_ptr = nullptr; + zeMemCloseIpcHandle_ptr = nullptr; + zeVirtualMemMap_ptr = nullptr; + zeVirtualMemReserve_ptr = nullptr; + zeVirtualMemSetAccessAttribute_ptr = nullptr; + } +} + +extern "C" { + +__attribute__((weak)) ze_result_t zeVirtualMemMap( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_physical_mem_handle_t hPhysicalMemory, + size_t offset, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemMap_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemMap_ptr( + hContext, ptr, size, hPhysicalMemory, offset, access); +} + +__attribute__((weak)) ze_result_t zeVirtualMemReserve( + ze_context_handle_t hContext, + const void* pStart, + size_t size, + void** pptr) { + if (!load_level_zero_library() || !zeVirtualMemReserve_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemReserve_ptr(hContext, pStart, size, pptr); +} + +__attribute__((weak)) ze_result_t zeVirtualMemSetAccessAttribute( + ze_context_handle_t hContext, + const void* ptr, + size_t size, + ze_memory_access_attribute_t access) { + if (!load_level_zero_library() || !zeVirtualMemSetAccessAttribute_ptr) { + return ZE_RESULT_ERROR_UNINITIALIZED; + } + return zeVirtualMemSetAccessAttribute_ptr(hContext, ptr, size, access); +} +} diff --git a/test/xpu/distributed/test_symmetric_memory_xccl.py b/test/xpu/distributed/test_symmetric_memory_xccl.py new file mode 100644 index 000000000..37f5d3e6d --- /dev/null +++ b/test/xpu/distributed/test_symmetric_memory_xccl.py @@ -0,0 +1,85 @@ +import torch +import torch.distributed as dist +from test_c10d_xccl import init_multigpu_helper, requires_xccl +from 
torch.distributed._symmetric_memory import ( + _fused_all_gather_matmul_fallback, + _fused_matmul_reduce_scatter_fallback, +) + +from torch.testing._internal.common_distributed import MultiProcContinuousTest +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests +) + +@instantiate_parametrized_tests +class AsyncTPTest(MultiProcContinuousTest): + @property + def device(self) -> torch.device: + return torch.device("xpu", self.rank) + + def _init_process(self): + torch.xpu.set_device(self.device) + torch.manual_seed(42 + self.rank) + torch.use_deterministic_algorithms(True) + torch.set_deterministic_debug_mode("warn") + torch.utils.deterministic.fill_uninitialized_memory = True + + @requires_xccl() + @parametrize("gather_dim", [0, 1]) + def test_fused_all_gather_matmul(self, gather_dim: int) -> None: + self._init_process() + BATCH = 8 + M = 64 + N = 16 + K = 32 + group = dist.group.WORLD + rank = self.rank + + torch.manual_seed(42 + rank) + A_shard = torch.rand(BATCH, M // self.world_size, K, device="xpu") + Bs = [torch.rand(K, N, device="xpu") for _ in range(3)] + + ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name + ) + + self.assertEqual(ag_output_0, ag_output_1) + self.assertEqual(ag_output_0.stride(), ag_output_1.stride()) + for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1): + self.assertEqual(mm_output_0, mm_output_1) + self.assertEqual(mm_output_0.stride(), mm_output_1.stride()) + + @requires_xccl() + @parametrize("scatter_dim", [0, 1]) + def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: + self._init_process() + + BATCH = 8 + M = 64 + N = 16 + K = 32 + group = dist.group.WORLD + rank = self.rank + + torch.manual_seed(42 + rank) + A = torch.rand(BATCH, M, K, device="xpu") + B = torch.rand(K, N, device="xpu") + + output_0 = _fused_matmul_reduce_scatter_fallback( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter( + A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name + ) + + self.assertEqual(output_0, output_1) + self.assertEqual(output_0.stride(), output_1.stride()) + + +if __name__ == "__main__": + run_tests() diff --git a/test/xpu/run_distributed.py b/test/xpu/run_distributed.py index ddde5f8c8..496540616 100644 --- a/test/xpu/run_distributed.py +++ b/test/xpu/run_distributed.py @@ -26,6 +26,10 @@ def run(test_command): test_command = ["python", "distributed/test_c10d_ops_xccl.py"] res += run(test_command) +test_command = ["python", "distributed/test_c10d_xccl.py"] +res += run(test_command) +test_command = ["python", "distributed/test_symmetric_memory_xccl.py"] +res += run(test_command) # run pytest with skiplist for key in skip_dict: