diff --git a/CMakeLists.txt b/CMakeLists.txt index 3521f01a..c9a0be3f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.19) set(CMAKE_CXX_FLAGS_DEBUG "-g -O0") set(CMAKE_CXX_FLAGS_RELEASE "-O3") +set(CMAKE_CXX_FLAGS_O3DEBUG "-ggdb -O3") set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) diff --git a/examples/dist_rdma_ops/dist_write.cpp b/examples/dist_rdma_ops/dist_write.cpp index 101dc197..dd5adc86 100644 --- a/examples/dist_rdma_ops/dist_write.cpp +++ b/examples/dist_rdma_ops/dist_write.cpp @@ -59,11 +59,10 @@ void VerifyBuffer(void* buffer, size_t maxSize, char expected) { template inline __device__ void QuiteSerial(RdmaEndpoint* endpoint) { - if (threadIdx.x != 0 || blockIdx.x != 0) return; - + if (GetActiveLaneNum() != 0) return; CompletionQueueHandle& cq = endpoint->cqHandle; WorkQueueHandle& wq = endpoint->wqHandle; - // AcquireLock(&cq.pollCqLock); + if (!AcquireLockOnce(&cq.pollCqLock)) return; while (true) { bool done{false}; uint32_t quiet_amount{0}; @@ -74,6 +73,7 @@ inline __device__ void QuiteSerial(RdmaEndpoint* endpoint) { uint32_t doneIdx = __hip_atomic_load(&wq.doneIdx, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT); // printf("dbTouchIdx: %u, doneIdx: %u\n", dbTouchIdx, doneIdx); if (dbTouchIdx == doneIdx) { + ReleaseLock(&cq.pollCqLock); return; } @@ -105,7 +105,7 @@ inline __device__ void QuiteSerial(RdmaEndpoint* endpoint) { __atomic_signal_fence(__ATOMIC_SEQ_CST); __hip_atomic_fetch_max(&wq.doneIdx, wqe_id, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } - // ReleaseLock(&cq.pollCqLock); + ReleaseLock(&cq.pollCqLock); } template @@ -122,12 +122,6 @@ __device__ void Quite(RdmaEndpoint* endpoint) { bool is_last{my_logical_lane_id == num_active_lanes - 1}; const uint64_t leader_phys_lane_id = GetFirstActiveLaneID(activemask); CompletionQueueHandle* cqHandle = &endpoint->cqHandle; - uint8_t num_slot_per_wqe; - if constexpr (P == ProviderType::MLX5) { - num_slot_per_wqe = 1; - } else if constexpr (P == ProviderType::BNXT) { - num_slot_per_wqe = 3; - } while (true) { bool done{false}; @@ -144,9 +138,8 @@ __device__ void Quite(RdmaEndpoint* endpoint) { if (!(posted - completed)) { return; } - uint32_t quiet_val = posted - active; - - if (!quiet_val) { + int32_t quiet_val = posted - active; + if (quiet_val <= 0) { continue; } quiet_amount = min(num_active_lanes, quiet_val); @@ -155,7 +148,8 @@ __device__ void Quite(RdmaEndpoint* endpoint) { active + quiet_amount, __ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); if (done) { - warp_cq_consumer = active; + warp_cq_consumer = __hip_atomic_fetch_add(&cqHandle->cq_consumer, quiet_amount, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } } done = __shfl(done, leader_phys_lane_id); @@ -167,14 +161,11 @@ __device__ void Quite(RdmaEndpoint* endpoint) { if (my_logical_lane_id < quiet_amount) { uint16_t wqe_counter; PollCq
<P>
(cqHandle->cqAddr, cqHandle->cqeNum, &my_cq_consumer, &wqe_counter); - if constexpr (P == ProviderType::BNXT) { - wqe_counter = (num_slot_per_wqe * (wqe_counter + endpoint->wqHandle.sqWqeNum - 1) % - endpoint->wqHandle.sqWqeNum); - } __threadfence_system(); wqe_id = endpoint->wqHandle.outstandingWqe[wqe_counter]; __hip_atomic_fetch_max(&wqe_broadcast[warp_id], wqe_id, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); + __atomic_signal_fence(__ATOMIC_SEQ_CST); } if (is_leader) { uint64_t completed{0}; @@ -194,118 +185,139 @@ __device__ void Quite(RdmaEndpoint* endpoint) { } } -inline __device__ void atomic_add_packed_msn_and_psn(uint64_t* msnPack, uint32_t incSlot, - uint32_t incPsn, uint32_t* oldSlot, - uint32_t* oldPsn) { - uint64_t expected = __hip_atomic_load(msnPack, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - while (true) { - uint32_t curSlot = static_cast(expected & 0xFFFFFFFF); - uint32_t curPsn = static_cast((expected >> 32) & 0xFFFFFFFF); - - uint32_t newSlot = curSlot + incSlot; - uint32_t newPsn = curPsn + incPsn; - - uint64_t desired = (static_cast(newPsn) << 32) | static_cast(newSlot); - - if (__hip_atomic_compare_exchange_strong(msnPack, &expected, desired, __ATOMIC_RELAXED, - __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT)) { - if (oldSlot) *oldSlot = curSlot; - if (oldPsn) *oldPsn = curPsn; - break; - } - } -} - template -__global__ void Write(RdmaEndpoint* endpoint, RdmaMemoryRegion localMr, RdmaMemoryRegion remoteMr, - size_t msg_size, int iters, uint32_t* blockSync) { - for (int i = 0; i < iters; i++) { - - uint64_t activemask = GetActiveLaneMask(); - uint8_t num_active_lanes = GetActiveLaneCount(activemask); - uint8_t my_logical_lane_id = GetActiveLaneNum(activemask); - bool is_leader{my_logical_lane_id == num_active_lanes - 1}; - const uint64_t leader_phys_lane_id = GetLastActiveLaneID(activemask); +__device__ void Write(RdmaEndpoint* endpoint, RdmaMemoryRegion localMr, RdmaMemoryRegion remoteMr, + size_t msg_size) { + if (msg_size == 0) return; + uint64_t activemask = GetActiveLaneMask(); + uint8_t num_active_lanes = GetActiveLaneCount(activemask); + uint8_t my_logical_lane_id = GetActiveLaneNum(activemask); + bool is_leader{my_logical_lane_id == num_active_lanes - 1}; + const uint64_t leader_phys_lane_id = GetLastActiveLaneID(activemask); - uint8_t num_wqes = num_active_lanes; + uint8_t num_wqes{num_active_lanes}; + uint32_t warp_sq_counter{0}; + uint32_t warp_msntbl_counter{0}, warp_psn_counter{0}; + uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0}; - uint64_t warp_sq_counter{0}; - uint32_t warp_msntbl_counter{0}, warp_psn_counter{0}; - WorkQueueHandle* wqHandle = &endpoint->wqHandle; - CompletionQueueHandle* cqHandle = &endpoint->cqHandle; - uint32_t psnCnt = (msg_size + wqHandle->mtuSize - 1) / wqHandle->mtuSize; - if (is_leader) { + WorkQueueHandle* wqHandle = &endpoint->wqHandle; + CompletionQueueHandle* cqHandle = &endpoint->cqHandle; + uint32_t psnCnt; + if constexpr (P == core::ProviderType::BNXT) { + psnCnt = (msg_size + wqHandle->mtuSize - 1) / wqHandle->mtuSize; + } + if (is_leader) { + if constexpr (P == core::ProviderType::MLX5) { + warp_sq_counter = __hip_atomic_fetch_add(&wqHandle->postIdx, num_wqes, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } else if constexpr (P == core::ProviderType::BNXT) { core::atomic_add_packed_msn_and_psn(&wqHandle->msnPack, num_wqes, psnCnt * num_wqes, &warp_msntbl_counter, &warp_psn_counter); - // TODO: if warp_msntbl_counter overflow 32bit, sq_slot's caculation will be wrong warp_sq_counter = 
warp_msntbl_counter; __hip_atomic_fetch_max(&wqHandle->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } else { + assert(false); } - warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); + } + warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); + if constexpr (P == core::ProviderType::MLX5) { + my_sq_counter = warp_sq_counter + my_logical_lane_id; + } else if constexpr (P == core::ProviderType::BNXT) { warp_msntbl_counter = __shfl(warp_msntbl_counter, leader_phys_lane_id); warp_psn_counter = __shfl(warp_psn_counter, leader_phys_lane_id); - uint64_t my_sq_counter = warp_sq_counter + my_logical_lane_id; - uint64_t my_msntbl_counter = warp_msntbl_counter + my_logical_lane_id; - uint64_t my_psn_counter = warp_psn_counter + my_logical_lane_id * psnCnt; - while (true) { - uint64_t db_touched = - __hip_atomic_load(&wqHandle->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint64_t db_done = - __hip_atomic_load(&wqHandle->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint64_t num_active_sq_entries = db_touched - db_done; - uint64_t num_free_entries = wqHandle->sqWqeNum - num_active_sq_entries; - uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_active_lanes - db_touched; - if (num_free_entries > num_entries_until_warp_last_entry) { - break; - } - if constexpr (P == ProviderType::MLX5) { - Quite
<P>
(endpoint); - } else if constexpr (P == ProviderType::BNXT) { - QuiteSerial
<P>
(endpoint); - } + my_sq_counter = warp_sq_counter + my_logical_lane_id; + my_msntbl_counter = warp_msntbl_counter + my_logical_lane_id; + my_psn_counter = warp_psn_counter + psnCnt * my_logical_lane_id; + } else { + assert(false); + } + + while (true) { + uint64_t db_touched = + __hip_atomic_load(&wqHandle->dbTouchIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); + uint64_t db_done = + __hip_atomic_load(&wqHandle->doneIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); + uint64_t num_active_sq_entries = db_touched - db_done; + uint64_t num_free_entries = wqHandle->sqWqeNum - num_active_sq_entries; + uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_active_lanes - db_touched; + if (num_free_entries > num_entries_until_warp_last_entry) { + break; } if constexpr (P == ProviderType::MLX5) { - wqHandle->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; + Quite
<P>
(endpoint); } else if constexpr (P == ProviderType::BNXT) { - wqHandle->outstandingWqe[my_sq_counter % wqHandle->sqWqeNum] = my_sq_counter; + QuiteSerial
<P>
(endpoint); } - uintptr_t srcAddr = localMr.addr + FlatThreadId() * msg_size; - uintptr_t dstAddr = remoteMr.addr + FlatThreadId() * msg_size; - uint64_t dbr_val = + } + uintptr_t srcAddr = localMr.addr + FlatThreadId() * msg_size; + uintptr_t dstAddr = remoteMr.addr + FlatThreadId() * msg_size; + uint64_t dbr_val; + if constexpr (P == ProviderType::MLX5) { + wqHandle->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; + dbr_val = + PostWrite
<P>
(*wqHandle, my_sq_counter, my_sq_counter, my_sq_counter, is_leader, + endpoint->handle.qpn, srcAddr, localMr.lkey, dstAddr, remoteMr.rkey, msg_size); + } else if constexpr (P == ProviderType::BNXT) { + wqHandle->outstandingWqe[my_sq_counter % wqHandle->sqWqeNum] = my_sq_counter; + dbr_val = PostWrite
<P>
(*wqHandle, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader, endpoint->handle.qpn, srcAddr, localMr.lkey, dstAddr, remoteMr.rkey, msg_size); + } else { + assert(false); + } - if (is_leader) { - uint64_t db_touched{0}; - do { - db_touched = - __hip_atomic_load(&wqHandle->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } while (db_touched != warp_sq_counter); - - // uint8_t* base_ptr = reinterpret_cast(wqHandle->sqAddr); - // uint64_t* ctrl_wqe_8B_for_db = reinterpret_cast( - // &base_ptr[64 * ((warp_sq_counter + num_wqes - 1) % wqHandle->sqWqeNum)]); - UpdateSendDbrRecord
<P>
(wqHandle->dbrRecAddr, warp_sq_counter + num_wqes); - // __threadfence_system(); - RingDoorbell
<P>
(wqHandle->dbrAddr, dbr_val); - - __hip_atomic_fetch_add(&cqHandle->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&wqHandle->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, - __HIP_MEMORY_SCOPE_AGENT); + if (is_leader) { + uint64_t db_touched{0}; + do { + db_touched = + __hip_atomic_load(&wqHandle->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } while (db_touched != warp_sq_counter); + + UpdateSendDbrRecord
<P>
(wqHandle->dbrRecAddr, warp_sq_counter + num_wqes); + RingDoorbell
<P>
(wqHandle->dbrAddr, dbr_val); + + __hip_atomic_fetch_add(&cqHandle->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + __hip_atomic_store(&wqHandle->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } +} + +template +__global__ void MultiQpWrite(RdmaEndpoint* endpoints, RdmaMemoryRegion localMr, + RdmaMemoryRegion remoteMr, size_t msg_size, int iters, + uint32_t* blockSync, int num_qp) { + int thdId = threadIdx.x; + int thdNum = blockDim.x; + + int laneId = threadIdx.x & (warpSize - 1); + int warpId = thdId / warpSize; + int warpNum = (blockDim.x + warpSize - 1) / warpSize; + + int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; + int globalThdNum = gridDim.x * blockDim.x; + int globalWarpId = blockIdx.x * warpNum + warpId; + int globalWarpNum = gridDim.x * warpNum; + int qp_id = globalWarpId % num_qp; + for (int i = 0; i < iters; i++) { + for (int qp_offset = qp_id; qp_offset < num_qp; qp_offset += globalWarpNum) { + Write
<P>
(endpoints + qp_id, localMr, remoteMr, msg_size); } if constexpr (P == ProviderType::MLX5) { - Quite
<P>
(endpoint); + Quite
<P>
(endpoints + qp_id); } else if constexpr (P == ProviderType::BNXT) { - QuiteSerial
<P>
(endpoint); + for (int t = globalWarpId; t < num_qp; t += globalWarpNum) { + // printf("qp_offset:%d\n",qp_offset); + QuiteSerial
<P>
(endpoints + t); + } } + __syncthreads(); if (threadIdx.x == 0) { atomicAdd(blockSync + i, 1); } while (atomicAdd(blockSync + i, 0) < gridDim.x) { ; } - __syncthreads(); } } @@ -333,6 +345,7 @@ void distRdmaOps(int argc, char* argv[]) { hipEvent_t start, end; HIP_RUNTIME_CHECK(hipEventCreate(&start)); HIP_RUNTIME_CHECK(hipEventCreate(&end)); + int num_qp = args.getNumQp(); // RDMA initialization // 1 Create device @@ -348,29 +361,42 @@ void distRdmaOps(int argc, char* argv[]) { RdmaEndpointConfig config; config.portId = activeDevicePortList[local_rank % activeDevicePortList.size()].second; config.gidIdx = 3; - config.maxMsgsNum = 1536 - 1; + config.maxMsgsNum = 8092; +#ifdef ENABLE_BNXT + config.maxCqeNum = 1; +#else config.maxCqeNum = 4096; +#endif config.alignment = 4096; config.onGpu = on_gpu; - RdmaEndpoint endpoint = device_context->CreateRdmaEndpoint(config); + std::vector endpoints; + for (int i = 0; i < num_qp; ++i) { + endpoints.push_back(device_context->CreateRdmaEndpoint(config)); + } // 3 Allgather global endpoint and connect - std::vector global_rdma_ep_handles(world_size); - bootNet.Allgather(&endpoint.handle, global_rdma_ep_handles.data(), sizeof(RdmaEndpointHandle)); + std::vector global_rdma_ep_handles(world_size * num_qp); + for (int i = 0; i < num_qp; ++i) { + bootNet.Allgather(&endpoints[i].handle, global_rdma_ep_handles.data() + i * world_size, + sizeof(RdmaEndpointHandle)); + } - std::cout << "Local rank " << local_rank << " " << endpoint.handle << std::endl; + std::cout << "Local rank " << local_rank << " " << endpoints[0].handle << std::endl; for (int i = 0; i < world_size; i++) { if (i == local_rank) continue; - device_context->ConnectEndpoint(endpoint.handle, global_rdma_ep_handles[i]); - std::cout << "Local rank " << local_rank << " received " << global_rdma_ep_handles[i] - << std::endl; + for (int qp = 0; qp < num_qp; ++qp) { + device_context->ConnectEndpoint(endpoints[qp].handle, + global_rdma_ep_handles[i + qp * world_size], qp); + std::cout << "Local rank " << local_rank << " connected to rank " << i << " qp " << qp + << " with handle " << global_rdma_ep_handles[i + qp * world_size] << std::endl; + } } // 4 Register buffer and block sync memory void* buffer; size_t totalSize = maxSize * blocks * threads; - assert(totalSize <= 0x800000000ULL && "Error: totalSize cannot exceed 32GB!"); + assert(totalSize <= 0x1000000000ULL && "Error: totalSize cannot exceed 64GB!"); HIP_RUNTIME_CHECK(hipMalloc(&buffer, totalSize)); HIP_RUNTIME_CHECK(hipMemset(buffer, local_rank, totalSize)); uint32_t* blockSync; @@ -384,9 +410,10 @@ void distRdmaOps(int argc, char* argv[]) { std::vector global_mr_handles(world_size); bootNet.Allgather(&mr_handle, global_mr_handles.data(), sizeof(mr_handle)); global_mr_handles[local_rank] = mr_handle; - RdmaEndpoint* devEndpoint; - HIP_RUNTIME_CHECK(hipMalloc(&devEndpoint, sizeof(RdmaEndpoint))); - HIP_RUNTIME_CHECK(hipMemcpy(devEndpoint, &endpoint, sizeof(RdmaEndpoint), hipMemcpyHostToDevice)); + RdmaEndpoint* devEndpoints; + HIP_RUNTIME_CHECK(hipMalloc(&devEndpoints, num_qp * sizeof(RdmaEndpoint))); + HIP_RUNTIME_CHECK(hipMemcpy(devEndpoints, endpoints.data(), num_qp * sizeof(RdmaEndpoint), + hipMemcpyHostToDevice)); double* bwTable; uint64_t* sizeTable; @@ -402,15 +429,15 @@ void distRdmaOps(int argc, char* argv[]) { for (size_t size = minSize; size <= maxSize; size *= stepFactor) { if (local_rank == 0) { - switch (endpoint.GetProviderType()) { + switch (endpoints[0].GetProviderType()) { case ProviderType::MLX5: - Write<<>>(devEndpoint, 
global_mr_handles[0], - global_mr_handles[1], size, 1, blockSync); + MultiQpWrite<<>>( + devEndpoints, global_mr_handles[0], global_mr_handles[1], size, 1, blockSync, num_qp); break; #ifdef ENABLE_BNXT case ProviderType::BNXT: - Write<<>>(devEndpoint, global_mr_handles[0], - global_mr_handles[1], size, 1, blockSync); + MultiQpWrite<<>>( + devEndpoints, global_mr_handles[0], global_mr_handles[1], size, 1, blockSync, num_qp); break; #endif default: @@ -430,17 +457,17 @@ void distRdmaOps(int argc, char* argv[]) { if (local_rank == 0) { for (size_t size = minSize; size <= maxSize; size *= stepFactor) { // warmup - switch (endpoint.GetProviderType()) { + switch (endpoints[0].GetProviderType()) { case ProviderType::MLX5: - Write<<>>(devEndpoint, global_mr_handles[0], - global_mr_handles[1], size, warmupIters, - blockSync + 1); + MultiQpWrite<<>>(devEndpoints, global_mr_handles[0], + global_mr_handles[1], size, + warmupIters, blockSync + 1, num_qp); break; #ifdef ENABLE_BNXT case ProviderType::BNXT: - Write<<>>(devEndpoint, global_mr_handles[0], - global_mr_handles[1], size, warmupIters, - blockSync + 1); + MultiQpWrite<<>>(devEndpoints, global_mr_handles[0], + global_mr_handles[1], size, + warmupIters, blockSync + 1, num_qp); break; #endif default: @@ -450,17 +477,17 @@ void distRdmaOps(int argc, char* argv[]) { // test and record HIP_RUNTIME_CHECK(hipEventRecord(start)); - switch (endpoint.GetProviderType()) { + switch (endpoints[0].GetProviderType()) { case ProviderType::MLX5: - Write<<>>(devEndpoint, global_mr_handles[0], - global_mr_handles[1], size, iters, - blockSync + 1 + warmupIters); + MultiQpWrite + <<>>(devEndpoints, global_mr_handles[0], global_mr_handles[1], size, + iters, blockSync + 1 + warmupIters, num_qp); break; #ifdef ENABLE_BNXT case ProviderType::BNXT: - Write<<>>(devEndpoint, global_mr_handles[0], - global_mr_handles[1], size, iters, - blockSync + 1 + warmupIters); + MultiQpWrite + <<>>(devEndpoints, global_mr_handles[0], global_mr_handles[1], size, + iters, blockSync + 1 + warmupIters, num_qp); break; #endif default: @@ -483,19 +510,20 @@ void distRdmaOps(int argc, char* argv[]) { if (local_rank == 0) { printf("\nIBGDA White benchmark:\n"); - printf("Blocks: %zu, Threads: %zu, Iterations: %zu\n", blocks, threads, iters); - printf("%-8s %-12s %-12s %-12s %-12s\n", "Index", "Size(B)", "bw(GB)", "Time(ms)", "Rate(pps)"); + printf("Blocks: %zu, Threads: %zu, Iterations: %zu, QPs:%d \n", blocks, threads, iters, num_qp); + printf("%-8s %-12s %-12s %-12s %-12s\n", "Index", "Size(B)", "bw(GB)", "Time(ms)", + "Rate(Mpps)"); for (size_t i = 0; i < validSizeLog; ++i) { - double rate_pps = (blocks * threads * iters) / (times[i] * MS_TO_S); + double rate_mpps = (blocks * threads * iters) / (times[i] / MS_TO_S) / 1000000.0; printf("%-8zu %-12lu %-12.4f %-12.4f %-12.4f\n", i + 1, sizeTable[i], bwTable[i], times[i], - rate_pps); + rate_mpps); } } bootNet.Finalize(); HIP_RUNTIME_CHECK(hipFree(buffer)); - HIP_RUNTIME_CHECK(hipFree(devEndpoint)); + HIP_RUNTIME_CHECK(hipFree(devEndpoints)); HIP_RUNTIME_CHECK(hipHostFree(bwTable)); HIP_RUNTIME_CHECK(hipHostFree(sizeTable)); HIP_RUNTIME_CHECK(hipHostFree(times)); diff --git a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py index 55d4ef29..202849ef 100644 --- a/examples/ops/dispatch_combine/test_dispatch_combine_internode.py +++ b/examples/ops/dispatch_combine/test_dispatch_combine_internode.py @@ -254,8 +254,8 @@ def run_test_once(self, op, test_data, 
error_round, round): ) # TODO: test output scales - if self.config.rank == 0: - print("Dispatch Pass") + if self.rank % self.gpu_per_node == 0: + print(f"Node {self.rank // self.gpu_per_node} Dispatch Pass") dist.barrier() @@ -311,8 +311,8 @@ def run_test_once(self, op, test_data, error_round, round): ) assert weight_match, f"Weight assertion failed for token {i}" - if self.config.rank == 0: - print("Combine Pass") + if self.rank % self.gpu_per_node == 0: + print(f"Node {self.rank // self.gpu_per_node} Combine Pass") def test_dispatch_combine(self): op = mori.ops.EpDispatchCombineOp(self.config) diff --git a/examples/shmem/atomic_fetch_thread.cpp b/examples/shmem/atomic_fetch_thread.cpp index 7f8f5550..c931ddbf 100644 --- a/examples/shmem/atomic_fetch_thread.cpp +++ b/examples/shmem/atomic_fetch_thread.cpp @@ -53,16 +53,16 @@ __global__ void AtomicFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) { int threadOffset = globalTid * sizeof(T); if (myPe == sendPe) { - RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(sendPe); - - ShmemAtomicTypeFetchThread(memObj, threadOffset, source, threadOffset, sendPe, recvPe, - recvPe, AMO_FETCH_AND); + T ret = ShmemAtomicTypeFetchThread(memObj, 2 * sizeof(T), 1, 0, AMO_FETCH_ADD, recvPe); __threadfence_system(); + if (ret == gridDim.x * blockDim.x) { + printf("globalTid: %d ret = %lu atomic fetch is ok!~\n", globalTid, (uint64_t)ret); + } - ShmemQuietThread(); // __syncthreads(); } else { - while (atomicAdd(reinterpret_cast(memObj->localPtr) + globalTid, 0) != sendPe) { + while (AtomicLoadRelaxed(reinterpret_cast(memObj->localPtr) + 2) != + gridDim.x * blockDim.x + 1) { } if (globalTid == 0) { printf("atomic fetch is ok!~\n"); @@ -97,24 +97,67 @@ void testAtomicFetchThread() { SymmMemObjPtr buffObj = ShmemQueryMemObjPtr(buff); assert(buffObj.IsValid()); - // Run uint64 atomic nonfetch - AtomicFetchThreadKernel<<>>(myPe, buffObj); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - MPI_Barrier(MPI_COMM_WORLD); - printf("after rank[%d] %lu %lu\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); + for (int iteration = 0; iteration < 10; iteration++) { + if (myPe == 0) { + printf("========== Iteration %d ==========\n", iteration + 1); + } - buffSize = numEle * sizeof(uint32_t); - HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - printf("before rank[%d] %u %u\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); - // Run uint32 atomic nonfetch - AtomicFetchThreadKernel<<>>(myPe, buffObj); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - MPI_Barrier(MPI_COMM_WORLD); - printf("after rank[%d] %u %u\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); + // Run uint64 atomic nonfetch + myHipMemsetD64(buff, myPe, numEle); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + AtomicFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test int64_t atomic nonfetch + buffSize = numEle * sizeof(int64_t); + myHipMemsetD64(buff, myPe, numEle); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run int64 atomic nonfetch + 
AtomicFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test uint32_t atomic nonfetch + buffSize = numEle * sizeof(uint32_t); + HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run uint32 atomic nonfetch + AtomicFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test int32_t atomic nonfetch + buffSize = numEle * sizeof(int32_t); + HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run int32 atomic nonfetch + AtomicFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + MPI_Barrier(MPI_COMM_WORLD); // Ensure all processes complete this iteration before next + if (myPe == 0) { + printf("Iteration %d completed\n", iteration + 1); + } + sleep(1); + } // Finalize ShmemFree(buff); diff --git a/examples/shmem/atomic_nonfetch_thread.cpp b/examples/shmem/atomic_nonfetch_thread.cpp index e3e09c26..8d20c41e 100644 --- a/examples/shmem/atomic_nonfetch_thread.cpp +++ b/examples/shmem/atomic_nonfetch_thread.cpp @@ -50,20 +50,25 @@ __global__ void AtomicNonFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) constexpr int recvPe = 1; int globalTid = blockIdx.x * blockDim.x + threadIdx.x; + int globalWarpId = globalTid / warpSize; int threadOffset = globalTid * sizeof(T); if (myPe == sendPe) { RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(sendPe); - // ShmemAtomicUint64NonFetchThread(memObj, threadOffset, sendPe, recvPe, AMO_SET); - ShmemAtomicTypeNonFetchThread(memObj, threadOffset, source, threadOffset, sendPe, recvPe, - AMO_SET); + // ShmemAtomicUint64NonFetchThread(memObj, threadOffset, sendPe, AMO_SET, recvPe); + if (globalWarpId % 2 == 0) { + ShmemAtomicTypeNonFetchThread(memObj, 2 * sizeof(T), 1, AMO_ADD, recvPe); + } else { + ShmemAtomicTypeNonFetchThread(memObj, 2 * sizeof(T), 1, AMO_ADD, recvPe, 1); + } __threadfence_system(); ShmemQuietThread(); // __syncthreads(); } else { - while (atomicAdd(reinterpret_cast(memObj->localPtr) + globalTid, 0) != sendPe) { + while (AtomicLoadRelaxed(reinterpret_cast(memObj->localPtr) + 2) != + gridDim.x * blockDim.x + 1) { } if (globalTid == 0) { printf("atomic nonfetch is ok!~\n"); @@ -90,32 +95,75 @@ void testAtomicNonFetchThread() { int numEle = threadNum * blockNum; int buffSize = numEle * sizeof(uint64_t); - void* buff = ShmemMalloc(buffSize); + void* buff = ShmemExtMallocWithFlags(buffSize, hipDeviceMallocUncached); myHipMemsetD64(buff, myPe, numEle); HIP_RUNTIME_CHECK(hipDeviceSynchronize()); printf("before rank[%d] %lu %lu\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); + *(reinterpret_cast(buff))); SymmMemObjPtr buffObj = ShmemQueryMemObjPtr(buff); assert(buffObj.IsValid()); - // Run uint64 atomic nonfetch - 
AtomicNonFetchThreadKernel<<>>(myPe, buffObj); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - MPI_Barrier(MPI_COMM_WORLD); - printf("after rank[%d] %lu %lu\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); + // Run atomic operations 10 times + for (int iteration = 0; iteration < 10; iteration++) { + if (myPe == 0) { + printf("========== Iteration %d ==========\n", iteration + 1); + } - buffSize = numEle * sizeof(uint32_t); - HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - printf("before rank[%d] %u %u\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); - // Run uint32 atomic nonfetch - AtomicNonFetchThreadKernel<<>>(myPe, buffObj); - HIP_RUNTIME_CHECK(hipDeviceSynchronize()); - MPI_Barrier(MPI_COMM_WORLD); - printf("after rank[%d] %u %u\n", myPe, *(reinterpret_cast(buff)), - *(reinterpret_cast(buff) + numEle - 1)); + // Run uint64 atomic nonfetch + myHipMemsetD64(buff, myPe, numEle); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + AtomicNonFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test int64_t atomic nonfetch + buffSize = numEle * sizeof(int64_t); + myHipMemsetD64(buff, myPe, numEle); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run int64 atomic nonfetch + AtomicNonFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test uint32_t atomic nonfetch + buffSize = numEle * sizeof(uint32_t); + HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run uint32 atomic nonfetch + AtomicNonFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + + // Test int32_t atomic nonfetch + buffSize = numEle * sizeof(int32_t); + HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast(buff), myPe, numEle)); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + printf("before rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff))); + // Run int32 atomic nonfetch + AtomicNonFetchThreadKernel<<>>(myPe, buffObj); + HIP_RUNTIME_CHECK(hipDeviceSynchronize()); + MPI_Barrier(MPI_COMM_WORLD); + printf("after rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast(buff)), + *(reinterpret_cast(buff) + 2)); + MPI_Barrier(MPI_COMM_WORLD); // Ensure all processes complete this iteration before next + if (myPe == 0) { + printf("Iteration %d completed\n", iteration + 1); + } + sleep(1); + } // Finalize ShmemFree(buff); diff --git a/examples/shmem/concurrent_put_thread.cpp b/examples/shmem/concurrent_put_thread.cpp index 4e3d7038..2a7cf67b 100644 --- a/examples/shmem/concurrent_put_thread.cpp +++ b/examples/shmem/concurrent_put_thread.cpp @@ -40,7 +40,7 @@ __global__ void 
ConcurrentPutThreadKernel(int myPe, const SymmMemObjPtr memObj) if (myPe == sendPe) { RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(myPe); - ShmemPutMemNbiThread(memObj, threadOffset, source, threadOffset, sizeof(uint32_t), recvPe); + ShmemPutMemNbiThread(memObj, threadOffset, source, threadOffset, sizeof(uint32_t), recvPe, 1); __threadfence_system(); if (blockIdx.x == 0) diff --git a/examples/utils/args_parser.cpp b/examples/utils/args_parser.cpp index d8927f98..5857ba7a 100644 --- a/examples/utils/args_parser.cpp +++ b/examples/utils/args_parser.cpp @@ -161,10 +161,11 @@ void BenchmarkConfig::readArgs(int argc, char** argv) { {"scope", required_argument, 0, 's'}, {"atomic_op", required_argument, 0, 'a'}, {"stride", required_argument, 0, 'i'}, + {"num_qp", required_argument, 0, 'q'}, {0, 0, 0, 0}}; int opt_idx = 0; - while ((c = getopt_long(argc, argv, "hb:e:f:n:w:c:t:d:o:s:a:i:", long_opts, &opt_idx)) != -1) { + while ((c = getopt_long(argc, argv, "hb:e:f:n:w:c:t:d:o:s:a:i:q:", long_opts, &opt_idx)) != -1) { switch (c) { case 'h': printf( @@ -244,6 +245,9 @@ void BenchmarkConfig::readArgs(int argc, char** argv) { case 'a': atomicOpParse(optarg); break; + case 'q': + atolScaled(optarg, &num_qp); + break; case '?': if (optopt == 'c') { fprintf(stderr, "Option -%c requires an argument.\n", optopt); @@ -269,13 +273,13 @@ void BenchmarkConfig::readArgs(int argc, char** argv) { printf("Runtime options after parsing command line arguments \n"); printf( - "min_size: %zu, max_size: %zu, step_factor: %zu, iterations: %zu, " + "min_size: %zu, max_size: %zu, num_qp: %zu, step_factor: %zu, iterations: %zu, " "warmup iterations: %zu, number of ctas: %zu, threads per cta: %zu " "stride: %zu, datatype: %s, reduce_op: %s, threadgroup_scope: %s, " "atomic_op: %s, dir: %s, report_msgrate: %d, bidirectional: %d, " "putget_issue: %s\n", - min_size, max_size, step_factor, iters, warmup_iters, num_blocks, threads_per_block, stride, - datatype.name.c_str(), reduce_op.name.c_str(), threadgroup_scope.name.c_str(), + min_size, max_size, num_qp, step_factor, iters, warmup_iters, num_blocks, threads_per_block, + stride, datatype.name.c_str(), reduce_op.name.c_str(), threadgroup_scope.name.c_str(), test_amo.name.c_str(), dir.name.c_str(), report_msgrate, bidirectional, putget_issue.name.c_str()); printf( diff --git a/examples/utils/args_parser.hpp b/examples/utils/args_parser.hpp index f70f1230..d67fb6d8 100644 --- a/examples/utils/args_parser.hpp +++ b/examples/utils/args_parser.hpp @@ -109,6 +109,7 @@ class BenchmarkConfig { size_t getStepFactor() const { return step_factor; } size_t getMaxSizeLog() const { return max_size_log; } size_t getStride() const { return stride; } + size_t getNumQp() const { return num_qp; } bool isBidirectional() const { return bidirectional; } bool isReportMsgrate() const { return report_msgrate; } @@ -130,6 +131,7 @@ class BenchmarkConfig { size_t step_factor = 2; size_t max_size_log = 1; size_t stride = 1; + size_t num_qp = 1; bool bidirectional = false; bool report_msgrate = false; diff --git a/include/mori/application/context/context.hpp b/include/mori/application/context/context.hpp index 5cfe8720..793556d0 100644 --- a/include/mori/application/context/context.hpp +++ b/include/mori/application/context/context.hpp @@ -43,6 +43,7 @@ class Context { TransportType GetTransportType(int destRank) const { return transportTypes[destRank]; } std::vector GetTransportTypes() const { return transportTypes; } + int GetNumQpPerPe() const { return numQpPerPe; } RdmaContext* 
GetRdmaContext() const { return rdmaContext.get(); } RdmaDeviceContext* GetRdmaDeviceContext() const { return rdmaDeviceContext.get(); } @@ -57,6 +58,7 @@ class Context { private: BootstrapNetwork& bootNet; int rankInNode{-1}; + int numQpPerPe{4}; std::vector hostnames; std::vector transportTypes; diff --git a/include/mori/application/transport/rdma/providers/bnxt/bnxt.hpp b/include/mori/application/transport/rdma/providers/bnxt/bnxt.hpp index ae60fa6e..b49c9d46 100644 --- a/include/mori/application/transport/rdma/providers/bnxt/bnxt.hpp +++ b/include/mori/application/transport/rdma/providers/bnxt/bnxt.hpp @@ -39,6 +39,9 @@ namespace mori { namespace application { #ifdef ENABLE_BNXT +// BNXT UDP sport configuration constants +static constexpr uint32_t BNXT_UDP_SPORT_ARRAY_SIZE = 4; + /* ---------------------------------------------------------------------------------------------- */ /* Device Attributes */ /* ---------------------------------------------------------------------------------------------- */ @@ -73,26 +76,31 @@ class BnxtCqContainer { ibv_cq* cq{nullptr}; }; +class BnxtDeviceContext; // Forward declaration + class BnxtQpContainer { public: - BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig& config, ibv_cq* cq, ibv_pd* pd); + BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig& config, ibv_cq* cq, ibv_pd* pd, BnxtDeviceContext* device_context); ~BnxtQpContainer(); void ModifyRst2Init(); void ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, const ibv_port_attr& portAttr, - const ibv_device_attr_ex& deviceAttr); + const ibv_device_attr_ex& deviceAttr, uint32_t qpId = 0); void ModifyRtr2Rts(const RdmaEndpointHandle& local_handle, const RdmaEndpointHandle& remote_handle); void* GetSqAddress(); void* GetMsntblAddress(); void* GetRqAddress(); + + BnxtDeviceContext* GetDeviceContext() { return device_context; } private: void DestroyQueuePair(); public: ibv_context* context; + BnxtDeviceContext* device_context; public: RdmaEndpointConfig config; @@ -110,6 +118,11 @@ class BnxtQpContainer { void* qpUar{nullptr}; void* qpUarPtr{nullptr}; ibv_qp* qp{nullptr}; + + // Atomic internal buffer fields + void* atomicIbufAddr{nullptr}; + size_t atomicIbufSize{0}; + ibv_mr* atomicIbufMr{nullptr}; }; /* ---------------------------------------------------------------------------------------------- */ @@ -122,7 +135,7 @@ class BnxtDeviceContext : public RdmaDeviceContext { virtual RdmaEndpoint CreateRdmaEndpoint(const RdmaEndpointConfig&) override; virtual void ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) override; + const RdmaEndpointHandle& remote, uint32_t qpId = 0) override; private: uint32_t pdn; diff --git a/include/mori/application/transport/rdma/providers/ibverbs/ibverbs.hpp b/include/mori/application/transport/rdma/providers/ibverbs/ibverbs.hpp index d2ca0ec1..a985a4b7 100644 --- a/include/mori/application/transport/rdma/providers/ibverbs/ibverbs.hpp +++ b/include/mori/application/transport/rdma/providers/ibverbs/ibverbs.hpp @@ -34,7 +34,7 @@ class IBVerbsDeviceContext : public RdmaDeviceContext { virtual RdmaEndpoint CreateRdmaEndpoint(const RdmaEndpointConfig&) override; virtual void ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) override; + const RdmaEndpointHandle& remote, uint32_t qpId = 0) override; private: std::unordered_map cqPool; diff --git a/include/mori/application/transport/rdma/providers/mlx5/mlx5.hpp b/include/mori/application/transport/rdma/providers/mlx5/mlx5.hpp 
index 248e8b0e..5527f4c0 100644 --- a/include/mori/application/transport/rdma/providers/mlx5/mlx5.hpp +++ b/include/mori/application/transport/rdma/providers/mlx5/mlx5.hpp @@ -74,18 +74,22 @@ class Mlx5CqContainer { mlx5dv_devx_obj* cq{nullptr}; }; +class Mlx5DeviceContext; // Forward declaration + class Mlx5QpContainer { public: Mlx5QpContainer(ibv_context* context, const RdmaEndpointConfig& config, uint32_t cqn, - uint32_t pdn); + uint32_t pdn, Mlx5DeviceContext* device_context); ~Mlx5QpContainer(); void ModifyRst2Init(); - void ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, const ibv_port_attr& portAttr); + void ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, const ibv_port_attr& portAttr, uint32_t qpId = 0); void ModifyRtr2Rts(const RdmaEndpointHandle& local_handle); void* GetSqAddress(); void* GetRqAddress(); + + Mlx5DeviceContext* GetDeviceContext() { return device_context; } private: void ComputeQueueAttrs(const RdmaEndpointConfig& config); @@ -94,6 +98,7 @@ class Mlx5QpContainer { public: ibv_context* context; + Mlx5DeviceContext* device_context; public: RdmaEndpointConfig config; @@ -110,6 +115,11 @@ class Mlx5QpContainer { mlx5dv_devx_uar* qpUar{nullptr}; void* qpUarPtr{nullptr}; mlx5dv_devx_obj* qp{nullptr}; + + // Atomic internal buffer fields + void* atomicIbufAddr{nullptr}; + size_t atomicIbufSize{0}; + ibv_mr* atomicIbufMr{nullptr}; }; /* ---------------------------------------------------------------------------------------------- */ @@ -122,7 +132,7 @@ class Mlx5DeviceContext : public RdmaDeviceContext { virtual RdmaEndpoint CreateRdmaEndpoint(const RdmaEndpointConfig&) override; virtual void ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) override; + const RdmaEndpointHandle& remote, uint32_t qpId = 0) override; private: uint32_t pdn; diff --git a/include/mori/application/transport/rdma/rdma.hpp b/include/mori/application/transport/rdma/rdma.hpp index c5a443fa..694bcf66 100644 --- a/include/mori/application/transport/rdma/rdma.hpp +++ b/include/mori/application/transport/rdma/rdma.hpp @@ -67,6 +67,12 @@ RdmaDeviceVendorId ToRdmaDeviceVendorId(T v) { #define PAGESIZE uint32_t(sysconf(_SC_PAGE_SIZE)) // #define OUTSTANDING_TABLE_SIZE (65536) +// UDP sport configuration constants for multi-provider support +static constexpr uint32_t RDMA_UDP_SPORT_ARRAY_SIZE = 4; + +// Atomic internal buffer configuration +static constexpr size_t ATOMIC_IBUF_SLOT_SIZE = 8; // Each atomic ibuf slot is 8 bytes + /* -------------------------------------------------------------------------- */ /* Rdma Data Structure */ /* -------------------------------------------------------------------------- */ @@ -80,6 +86,7 @@ struct RdmaEndpointConfig { bool onGpu{false}; bool withCompChannel{false}; bool enableSrq{false}; + uint32_t atomicIbufSlots{512}; // Number of atomic internal buffer slots, each slot is 8B }; struct InfiniBandEndpointHandle { @@ -125,6 +132,13 @@ struct WorkQueueAttrs { uint32_t offset{0}; }; +struct RdmaMemoryRegion { + uintptr_t addr{0}; + uint32_t lkey{0}; + uint32_t rkey{0}; + size_t length{0}; +}; + struct RdmaEndpoint { RdmaDeviceVendorId vendorId{RdmaDeviceVendorId::Unknown}; RdmaEndpointHandle handle; @@ -135,6 +149,9 @@ struct RdmaEndpoint { core::CompletionQueueHandle cqHandle; core::IBVerbsHandle ibvHandle; + // Atomic internal buffer (ibuf) - independent MR for atomic operations + core::IbufHandle atomicIbuf; + __device__ __host__ core::ProviderType GetProviderType() { if (vendorId == RdmaDeviceVendorId::Mellanox) { 
return core::ProviderType::MLX5; @@ -150,13 +167,6 @@ struct RdmaEndpoint { class RdmaDevice; -struct RdmaMemoryRegion { - uintptr_t addr{0}; - uint32_t lkey{0}; - uint32_t rkey{0}; - size_t length{0}; -}; - /* -------------------------------------------------------------------------- */ /* RdmaDeviceContext */ /* -------------------------------------------------------------------------- */ @@ -173,10 +183,11 @@ class RdmaDeviceContext { virtual RdmaEndpoint CreateRdmaEndpoint(const RdmaEndpointConfig&) { assert(false && "not implemented"); } - void ConnectEndpoint(const RdmaEndpoint& local, const RdmaEndpoint& remote) { - ConnectEndpoint(local.handle, remote.handle); + void ConnectEndpoint(const RdmaEndpoint& local, const RdmaEndpoint& remote, uint32_t qpId = 0) { + ConnectEndpoint(local.handle, remote.handle, qpId); } - virtual void ConnectEndpoint(const RdmaEndpointHandle& local, const RdmaEndpointHandle& remote) { + virtual void ConnectEndpoint(const RdmaEndpointHandle& local, const RdmaEndpointHandle& remote, + uint32_t qpId = 0) { assert(false && "not implemented"); } @@ -184,12 +195,21 @@ class RdmaDeviceContext { RdmaDevice* GetRdmaDevice(); ibv_context* GetIbvContext(); + ibv_pd* GetIbvPd() { return pd; } ibv_srq* GetIbvSrq() { return srq; } + uint16_t GetUdpSport(uint32_t qpId) const; + protected: ibv_pd* pd{nullptr}; ibv_srq* srq{nullptr}; + // Shared UDP sport configuration for all RDMA providers + uint16_t udp_sport_setting[RDMA_UDP_SPORT_ARRAY_SIZE]; + + // Initialize UDP sport configuration from environment variables + void InitializeUdpSportConfiguration(); + private: RdmaDevice* device; std::unordered_map mrPool; diff --git a/include/mori/core/transport/rdma/primitives.hpp b/include/mori/core/transport/rdma/primitives.hpp index 722cf2eb..0a3ceb91 100644 --- a/include/mori/core/transport/rdma/primitives.hpp +++ b/include/mori/core/transport/rdma/primitives.hpp @@ -23,6 +23,7 @@ #include #include + #include "infiniband/verbs.h" namespace mori { @@ -72,12 +73,12 @@ struct WorkQueueHandle { union { struct { uint32_t msntblSlotIdx; - uint32_t psnIdx; // for bnxt msn psn index calculate + uint32_t psnIdx; // for bnxt msn psn index calculate }; uint64_t msnPack{0}; }; void* sqAddr{nullptr}; - void* msntblAddr{nullptr}; // for bnxt + void* msntblAddr{nullptr}; // for bnxt void* rqAddr{nullptr}; void* dbrRecAddr{nullptr}; void* dbrAddr{nullptr}; @@ -93,9 +94,9 @@ struct CompletionQueueHandle { void* cqAddr{nullptr}; void* dbrRecAddr{nullptr}; void* dbrAddr{nullptr}; - uint32_t consIdx{0}; // numbers of cqe that have been completed - uint32_t needConsIdx{0}; // numbers of cqe that should be consumed - uint32_t activeIdx{0}; // numbers of cqe that under processing but not completed + uint32_t consIdx{0}; // numbers of cqe that have been completed + uint32_t needConsIdx{0}; // numbers of cqe that should be consumed + uint32_t activeIdx{0}; // numbers of cqe that under processing but not completed uint32_t cq_consumer{0}; uint32_t cqeNum{0}; uint32_t cqeSize{0}; @@ -109,6 +110,15 @@ struct IBVerbsHandle { ibv_comp_channel* compCh{nullptr}; }; +struct IbufHandle { + uintptr_t addr{0}; + uint32_t lkey{0}; + uint32_t rkey{0}; + uint32_t nslots{0}; + uint32_t head{0}; + uint32_t tail{0}; +}; + /* ---------------------------------------------------------------------------------------------- */ /* Utility Functions */ /* ---------------------------------------------------------------------------------------------- */ diff --git 
a/include/mori/core/transport/rdma/providers/bnxt/bnxt_device_primitives.hpp b/include/mori/core/transport/rdma/providers/bnxt/bnxt_device_primitives.hpp index 399966c6..71f3bd3b 100644 --- a/include/mori/core/transport/rdma/providers/bnxt/bnxt_device_primitives.hpp +++ b/include/mori/core/transport/rdma/providers/bnxt/bnxt_device_primitives.hpp @@ -66,7 +66,7 @@ inline __device__ uint64_t bnxt_re_init_db_hdr(int32_t indx, uint32_t toggle, ui inline __device__ void atomic_add_packed_msn_and_psn(uint64_t* msnPack, uint32_t incSlot, uint32_t incPsn, uint32_t* oldSlot, uint32_t* oldPsn) { - uint64_t expected = __hip_atomic_load(msnPack, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint64_t expected = __hip_atomic_load(msnPack, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); while (true) { uint32_t curSlot = static_cast(expected & 0xFFFFFFFF); uint32_t curPsn = static_cast((expected >> 32) & 0xFFFFFFFF); @@ -493,6 +493,7 @@ inline __device__ uint64_t BnxtPrepareAtomicWqe(WorkQueueHandle& wq, uint32_t cu uint32_t opcode = BNXT_RE_WR_OPCD_ATOMIC_FA; uint64_t data = val_1 ? *static_cast(val_1) : 0; uint64_t cmp = val_2 ? *static_cast(val_2) : 0; + // printf("BNXT atomic values: data=0x%lx, cmp=0x%lx\n", data, cmp); switch (amo_op) { case AMO_FETCH_INC: @@ -501,15 +502,6 @@ inline __device__ uint64_t BnxtPrepareAtomicWqe(WorkQueueHandle& wq, uint32_t cu data = 1; break; } - // TODO: dont have opmod, is set will work? - case AMO_SIGNAL: - case AMO_SIGNAL_SET: - case AMO_SWAP: - case AMO_SET: { - opcode = BNXT_RE_WR_OPCD_ATOMIC_CS; - cmp = 0; - break; - } case AMO_FETCH_ADD: case AMO_SIGNAL_ADD: case AMO_ADD: { @@ -544,7 +536,7 @@ inline __device__ uint64_t BnxtPrepareAtomicWqe(WorkQueueHandle& wq, uint32_t cu sge.pa = (uint64_t)laddr; sge.lkey = lkey & 0xffffffff; - sge.length = bytes; + sge.length = 8; char* base = reinterpret_cast(queueBuffAddr) + slotIdx * BNXT_RE_SLOT_SIZE; ThreadCopy(base + 0 * BNXT_RE_SLOT_SIZE, reinterpret_cast(&hdr), sizeof(hdr)); @@ -662,38 +654,68 @@ inline __device__ void UpdateDbrAndRingDbRecv(void* dbrRecAd /* ---------------------------------------------------------------------------------------------- */ /* Completion Queue */ /* ---------------------------------------------------------------------------------------------- */ +inline __device__ int PollSingleCqe(volatile char* cqe, uint32_t consIdx, uint32_t* wqeIdx) { + // Extract completion index using HIP atomic load + const uint32_t con_indx = __hip_atomic_load( + reinterpret_cast(const_cast(cqe) + offsetof(bnxt_re_req_cqe, con_indx)), + __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM); + + if (wqeIdx) { + *wqeIdx = con_indx & 0xFFFF; + } + + // Check completion status using HIP atomic load + volatile char* flgSrc = cqe + sizeof(struct bnxt_re_req_cqe); + const uint32_t flg_val = __hip_atomic_load(reinterpret_cast(const_cast(flgSrc)), + __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM); + const uint8_t status = (flg_val >> BNXT_RE_BCQE_STATUS_SHIFT) & BNXT_RE_BCQE_STATUS_MASK; + + if (status == BNXT_RE_REQ_ST_OK) { + return BNXT_RE_REQ_ST_OK; + } + + return status; +} + template <> inline __device__ int PollCqOnce(void* cqeAddr, uint32_t cqeNum, uint32_t consIdx, uint32_t* wqeIdx) { - uint32_t cqeIdx = consIdx % cqeNum; + // Fast path for single CQE (most common case) - eliminate all branching + if (cqeNum == 1) { + return PollSingleCqe(static_cast(cqeAddr), consIdx, wqeIdx); + } + // Slower path for multiple CQEs + const uint32_t cqeIdx = consIdx % cqeNum; volatile char* cqe = static_cast(cqeAddr) + 2 * BNXT_RE_SLOT_SIZE 
* cqeIdx; volatile char* flgSrc = cqe + sizeof(struct bnxt_re_req_cqe); - uint32_t phase = BNXT_RE_QUEUE_START_PHASE ^ ((consIdx / cqeNum) & 0x1); - uint32_t flg_val = *reinterpret_cast(flgSrc); - uint32_t con_indx = + const uint32_t flg_val = *reinterpret_cast(flgSrc); + const uint32_t expected_phase = BNXT_RE_QUEUE_START_PHASE ^ ((consIdx / cqeNum) & 0x1); + + if ((flg_val & BNXT_RE_BCQE_PH_MASK) != expected_phase) { + return -1; // CQE not ready yet + } + + // Extract completion index and check status + const uint32_t con_indx = *reinterpret_cast(cqe + offsetof(bnxt_re_req_cqe, con_indx)); - // printf("GPU flg_val = 0x%08X (%u), phase = 0x%08X (%u) consIdx %u, cqeNum %u\n", - // flg_val & BNXT_RE_BCQE_PH_MASK, flg_val & BNXT_RE_BCQE_PH_MASK, phase, phase, consIdx, - // cqeNum); - if (((flg_val)&BNXT_RE_BCQE_PH_MASK) == (phase)) { - uint8_t status = (flg_val >> BNXT_RE_BCQE_STATUS_SHIFT) & BNXT_RE_BCQE_STATUS_MASK; - - if (status != BNXT_RE_REQ_ST_OK) { - printf("CQ Error (%u)\n", status); - return status; - } - if (wqeIdx) { - *wqeIdx = con_indx & 0xFFFF; - } - return 0; + + if (wqeIdx) { + *wqeIdx = con_indx & 0xFFFF; } - return -1; + + const uint8_t status = (flg_val >> BNXT_RE_BCQE_STATUS_SHIFT) & BNXT_RE_BCQE_STATUS_MASK; + + if (__builtin_expect(status == BNXT_RE_REQ_ST_OK, 1)) { + return BNXT_RE_REQ_ST_OK; + } + + return status; } template <> inline __device__ int PollCq(void* cqAddr, uint32_t cqeNum, uint32_t* consIdx) { - uint32_t curConsIdx = atomicAdd(consIdx, 1); + const uint32_t curConsIdx = atomicAdd(consIdx, 1); int opcode = -1; do { opcode = PollCqOnce(cqAddr, cqeNum, curConsIdx, nullptr); @@ -701,32 +723,36 @@ inline __device__ int PollCq(void* cqAddr, uint32_t cqeNum, asm volatile("" ::: "memory"); } while (opcode < 0); + // Handle error cases if (opcode != BNXT_RE_REQ_ST_OK) { auto error = BnxtHandleErrorCqe(opcode); - printf("(%s:%d) CQE error: %s\n", __FILE__, __LINE__, IbvWcStatusString(error)); + printf("[BNXT PollCq] CQE error: %s (opcode: %d) at %s:%d\n", IbvWcStatusString(error), opcode, + __FILE__, __LINE__); return opcode; } - return opcode; + + return BNXT_RE_REQ_ST_OK; } template <> inline __device__ int PollCq(void* cqAddr, uint32_t cqeNum, uint32_t* consIdx, uint16_t* wqeCounter) { - uint32_t curConsIdx = *consIdx; + const uint32_t curConsIdx = *consIdx; int opcode = -1; uint32_t wqeIdx; do { opcode = PollCqOnce(cqAddr, cqeNum, curConsIdx, &wqeIdx); asm volatile("" ::: "memory"); } while (opcode < 0); - + *wqeCounter = (uint16_t)(wqeIdx & 0xFFFF); if (opcode != BNXT_RE_REQ_ST_OK) { auto error = BnxtHandleErrorCqe(opcode); - printf("(%s:%d) CQE error: %s\n", __FILE__, __LINE__, IbvWcStatusString(error)); + printf("[BNXT PollCq] CQE error: %s (opcode: %d), wqeCounter: %u at %s:%d\n", + IbvWcStatusString(error), opcode, *wqeCounter, __FILE__, __LINE__); return opcode; } - *wqeCounter = (uint16_t)(wqeIdx & 0xFFFF); - return opcode; + + return BNXT_RE_REQ_ST_OK; } template <> diff --git a/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp b/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp index 070ca0d3..2e79a1b8 100644 --- a/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp +++ b/include/mori/core/transport/rdma/providers/mlx5/mlx5_device_primitives.hpp @@ -258,14 +258,14 @@ inline __device__ uint64_t PostWriteInline(WorkQueueHandle& /* Atomic APIs */ /* ---------------------------------------------------------------------------------------------- */ -inline __device__ uint64_t 
mlx5PrepareAtomicWqe(WorkQueueHandle& wq, uint32_t curPostIdx, - bool cqeSignal, uint32_t qpn, uintptr_t laddr, - uint64_t lkey, uintptr_t raddr, uint64_t rkey, - void* val_1, void* val_2, uint32_t bytes, - atomicType amo_op) { +inline __device__ uint64_t mlx5PrepareAtomicWqe_v1(WorkQueueHandle& wq, uint32_t curPostIdx, + bool cqeSignal, uint32_t qpn, uintptr_t laddr, + uint64_t lkey, uintptr_t raddr, uint64_t rkey, + void* val_1, void* val_2, uint32_t bytes, + atomicType amo_op) { uint8_t signalFlag = cqeSignal ? MLX5_WQE_CTRL_CQ_UPDATE : 0x00; - void* queueBuffAddr = wq.rqAddr; - uint32_t wqeNum = wq.rqWqeNum; + void* queueBuffAddr = wq.sqAddr; + uint32_t wqeNum = wq.sqWqeNum; uint32_t numWqesPerCmd = get_num_wqes_in_atomic(amo_op, bytes); uint32_t wqeIdx = curPostIdx & (wqeNum - 1); @@ -519,6 +519,80 @@ inline __device__ uint64_t mlx5PrepareAtomicWqe(WorkQueueHandle& wq, uint32_t cu return reinterpret_cast(wqeCtrlSeg)[0]; } +inline __device__ uint64_t mlx5PrepareAtomicWqe(WorkQueueHandle& wq, uint32_t curPostIdx, + bool cqeSignal, uint32_t qpn, uintptr_t laddr, + uint64_t lkey, uintptr_t raddr, uint64_t rkey, + void* val_1, void* val_2, uint32_t bytes, + atomicType amo_op) { + uint8_t signalFlag = cqeSignal ? MLX5_WQE_CTRL_CQ_UPDATE : 0x00; + void* queueBuffAddr = wq.sqAddr; + uint32_t wqeNum = wq.sqWqeNum; + + uint32_t wqeIdx = curPostIdx & (wqeNum - 1); + void* wqeAddr = reinterpret_cast(queueBuffAddr) + (wqeIdx << MLX5_SEND_WQE_SHIFT); + + constexpr int atomicWqeSize = + sizeof(mlx5_wqe_ctrl_seg) + sizeof(mlx5_wqe_raddr_seg) + 2 * sizeof(mlx5_wqe_atomic_seg); + constexpr int numOctoWords = CeilDiv(atomicWqeSize, 16); + assert(numOctoWords == 4); + + struct mlx5_wqe_ctrl_seg* wqeCtrlSeg = reinterpret_cast(wqeAddr); + struct mlx5_wqe_raddr_seg* wqeRaddrSeg = reinterpret_cast( + reinterpret_cast(wqeAddr) + sizeof(mlx5_wqe_ctrl_seg)); + struct mlx5_wqe_atomic_seg* wqeAtomicSeg = reinterpret_cast( + reinterpret_cast(wqeAddr) + sizeof(mlx5_wqe_ctrl_seg) + sizeof(mlx5_wqe_raddr_seg)); + struct mlx5_wqe_data_seg* wqeDataSeg = reinterpret_cast( + reinterpret_cast(wqeAddr) + sizeof(mlx5_wqe_ctrl_seg) + sizeof(mlx5_wqe_raddr_seg) + + sizeof(mlx5_wqe_atomic_seg)); + + uint64_t data = val_1 ? *static_cast(val_1) : 0; + uint64_t cmp = val_2 ? 
*static_cast(val_2) : 0; + + uint32_t opcode = MLX5_OPCODE_ATOMIC_FA; + switch (amo_op) { + case AMO_FETCH_INC: + case AMO_INC: { + opcode = MLX5_OPCODE_ATOMIC_FA; + data = 1; + break; + } + case AMO_FETCH_ADD: + case AMO_SIGNAL_ADD: + case AMO_ADD: { + opcode = MLX5_OPCODE_ATOMIC_FA; + break; + } + case AMO_FETCH: { + opcode = MLX5_OPCODE_ATOMIC_FA; + data = 0; + break; + } + case AMO_COMPARE_SWAP: { + opcode = MLX5_OPCODE_ATOMIC_CS; + break; + } + default: { + printf("Error: unsupported atomic type (%d)\n", amo_op); + assert(0); + } + } + wqeCtrlSeg->opmod_idx_opcode = HTOBE32(((curPostIdx & 0xffff) << 8) | opcode); + wqeCtrlSeg->qpn_ds = HTOBE32((qpn << 8) | numOctoWords); + wqeCtrlSeg->fm_ce_se = signalFlag; + + wqeRaddrSeg->raddr = HTOBE64(raddr); + wqeRaddrSeg->rkey = HTOBE32(rkey); + wqeRaddrSeg->reserved = 0; + + wqeAtomicSeg->swap_add = HTOBE64(data); + wqeAtomicSeg->compare = HTOBE64(cmp); + + wqeDataSeg->byte_count = HTOBE32(8); + wqeDataSeg->addr = HTOBE64(laddr); + wqeDataSeg->lkey = HTOBE32(lkey); + return reinterpret_cast(wqeCtrlSeg)[0]; +} + template <> inline __device__ uint64_t PostAtomic( WorkQueueHandle& wq, uint32_t curPostIdx, uint32_t curMsntblSlotIdx, uint32_t curPsnIdx, @@ -688,7 +762,8 @@ inline __device__ int PollCq(void* cqAddr, uint32_t cqeNum, } template <> -inline __device__ void UpdateCqDbrRecord(void* dbrRecAddr, uint32_t cons_idx, uint32_t cqeNum) { +inline __device__ void UpdateCqDbrRecord(void* dbrRecAddr, uint32_t cons_idx, + uint32_t cqeNum) { reinterpret_cast(dbrRecAddr)[MLX5_CQ_SET_CI] = HTOBE32(cons_idx & 0xffffff); } diff --git a/include/mori/core/transport/rdma/providers/utils.h b/include/mori/core/transport/rdma/providers/utils.h index ebcb013d..8b9b8dcc 100644 --- a/include/mori/core/transport/rdma/providers/utils.h +++ b/include/mori/core/transport/rdma/providers/utils.h @@ -154,22 +154,22 @@ static __device__ __host__ void DumpMlx5Wqe(void* wqeBaseAddr, uint32_t idx) { } static __device__ __host__ uint32_t get_num_wqes_in_atomic(atomicType amo_op, uint32_t bytes) { - if (bytes == 8) { - // RC - switch (amo_op) { - case AMO_SIGNAL: - case AMO_SIGNAL_SET: - case AMO_SWAP: - case AMO_SET: - case AMO_FETCH_AND: - case AMO_AND: - case AMO_FETCH_OR: - case AMO_OR: - return 2; - default: - break; - } - } + // if (bytes == 8) { + // // RC + // switch (amo_op) { + // case AMO_SIGNAL: + // case AMO_SIGNAL_SET: + // case AMO_SWAP: + // case AMO_SET: + // case AMO_FETCH_AND: + // case AMO_AND: + // case AMO_FETCH_OR: + // case AMO_OR: + // return 2; + // default: + // break; + // } + // } return 1; } diff --git a/include/mori/core/utils.hpp b/include/mori/core/utils.hpp index f8b59e8f..dbf20771 100644 --- a/include/mori/core/utils.hpp +++ b/include/mori/core/utils.hpp @@ -187,6 +187,11 @@ __device__ inline void AcquireLock(uint32_t* lockVar) { } } +__device__ inline bool AcquireLockOnce(uint32_t* lockVar) { + return atomicCAS(lockVar, 0, 1) == 0; +} + + __device__ inline void ReleaseLock(uint32_t* lockVar) { atomicExch(lockVar, 0); } } // namespace core diff --git a/include/mori/ops/dispatch_combine/dispatch_combine.hpp b/include/mori/ops/dispatch_combine/dispatch_combine.hpp index 3cdd773b..60b0502c 100644 --- a/include/mori/ops/dispatch_combine/dispatch_combine.hpp +++ b/include/mori/ops/dispatch_combine/dispatch_combine.hpp @@ -184,6 +184,7 @@ class EpDispatchCombineHandle { // Record number of tokens that will be received from other PE mori::application::SymmMemObjPtr recvTokenNumMemObj; mori::application::SymmMemObjPtr sendTokenNumMemObj; + 
mori::application::SymmMemObjPtr sendAtomicSignalMemObj; // Barrier for intra-grid synchronization uint32_t* dispatchGridBarrier{nullptr}; @@ -238,6 +239,7 @@ struct EpDispatchCombineArgs { mori::application::SymmMemObjPtr shmemOutIndicesMemObj; mori::application::SymmMemObjPtr recvTokenNumMemObj; mori::application::SymmMemObjPtr sendTokenNumMemObj; + mori::application::SymmMemObjPtr sendAtomicSignalMemObj; uint32_t* dispatchGridBarrier{nullptr}; uint32_t* combineGridBarrier{nullptr}; index_t* destPeTokenCounter{nullptr}; @@ -281,6 +283,7 @@ EpDispatchCombineArgs GetEpDispatchCombineArgs(const EpDispatchCombineHandle& args.shmemOutIndicesMemObj = handle.shmemOutIndicesMemObj; args.recvTokenNumMemObj = handle.recvTokenNumMemObj; args.sendTokenNumMemObj = handle.sendTokenNumMemObj; + args.sendAtomicSignalMemObj = handle.sendAtomicSignalMemObj; args.dispatchGridBarrier = handle.dispatchGridBarrier; args.combineGridBarrier = handle.combineGridBarrier; args.dispReceiverIdxMap = handle.dispReceiverIdxMap; diff --git a/include/mori/shmem/shmem_device_api.hpp b/include/mori/shmem/shmem_device_api.hpp index 4e9112db..e685f547 100644 --- a/include/mori/shmem/shmem_device_api.hpp +++ b/include/mori/shmem/shmem_device_api.hpp @@ -45,6 +45,20 @@ namespace shmem { assert(false); \ } +#define DISPATCH_TRANSPORT_DATA_TYPE_WITH_RETURN(func, pe, type, ...) \ + [&]() { \ + GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); \ + application::TransportType transportType = globalGpuStates->transportTypes[pe]; \ + if (transportType == application::TransportType::RDMA) { \ + return func(__VA_ARGS__); \ + } else if (transportType == application::TransportType::P2P) { \ + return func(__VA_ARGS__); \ + } else { \ + assert(false); \ + return type{}; \ + } \ + }() + /* ---------------------------------------------------------------------------------------------- */ /* Synchronization */ /* ---------------------------------------------------------------------------------------------- */ @@ -59,55 +73,61 @@ inline __device__ void ShmemQuietThread(int pe) { /* ---------------------------------------------------------------------------------------------- */ /* Point-to-Point */ /* ---------------------------------------------------------------------------------------------- */ -#define DEFINE_SHMEM_PUT_MEM_NBI_API_TEMPLATE(Scope) \ - inline __device__ void ShmemPutMemNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe) { \ - DISPATCH_TRANSPORT_TYPE(ShmemPutMemNbi##Scope##Kernel, pe, dest, destOffset, source, \ - sourceOffset, bytes, pe); \ - } \ - inline __device__ void ShmemPutMemNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::SymmMemObjPtr source, size_t sourceOffset, size_t bytes, int pe) { \ - int rank = GetGlobalGpuStatesPtr()->rank; \ - ShmemPutMemNbi##Scope(dest, destOffset, source->GetRdmaMemoryRegion(rank), sourceOffset, \ - bytes, pe); \ +#define DEFINE_SHMEM_PUT_MEM_NBI_API_TEMPLATE(Scope) \ + inline __device__ void ShmemPutMemNbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, \ + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe, \ + int qpId = 0) { \ + DISPATCH_TRANSPORT_TYPE(ShmemPutMemNbi##Scope##Kernel, pe, dest, destOffset, source, \ + sourceOffset, bytes, pe, qpId); \ + } \ + inline __device__ void ShmemPutMemNbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, \ 
+ const application::SymmMemObjPtr source, size_t sourceOffset, size_t bytes, int pe, \ + int qpId = 0) { \ + int rank = GetGlobalGpuStatesPtr()->rank; \ + ShmemPutMemNbi##Scope(dest, destOffset, source->GetRdmaMemoryRegion(rank), sourceOffset, \ + bytes, pe, qpId); \ } DEFINE_SHMEM_PUT_MEM_NBI_API_TEMPLATE(Thread) DEFINE_SHMEM_PUT_MEM_NBI_API_TEMPLATE(Warp) -#define DEFINE_SHMEM_PUT_TYPE_NBI_API_TEMPLATE(Scope) \ - template \ - inline __device__ void ShmemPutTypeNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destElmOffset, \ - const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, int pe) { \ - constexpr size_t typeSize = sizeof(T); \ - ShmemPutMemNbi##Scope(dest, destElmOffset * typeSize, source, srcElmOffset * typeSize, \ - nelems * typeSize, pe); \ - } \ - template \ - inline __device__ void ShmemPutTypeNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destElmOffset, \ - const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, int pe) { \ - int rank = GetGlobalGpuStatesPtr()->rank; \ - ShmemPutTypeNbi##Scope(dest, destElmOffset, source->GetRdmaMemoryRegion(rank), \ - srcElmOffset, nelems, pe); \ +#define DEFINE_SHMEM_PUT_TYPE_NBI_API_TEMPLATE(Scope) \ + template \ + inline __device__ void ShmemPutTypeNbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destElmOffset, \ + const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, int pe, \ + int qpId = 0) { \ + constexpr size_t typeSize = sizeof(T); \ + ShmemPutMemNbi##Scope(dest, destElmOffset * typeSize, source, srcElmOffset * typeSize, \ + nelems * typeSize, pe, qpId); \ + } \ + template \ + inline __device__ void ShmemPutTypeNbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destElmOffset, \ + const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, int pe, \ + int qpId = 0) { \ + int rank = GetGlobalGpuStatesPtr()->rank; \ + ShmemPutTypeNbi##Scope(dest, destElmOffset, source->GetRdmaMemoryRegion(rank), \ + srcElmOffset, nelems, pe, qpId); \ } DEFINE_SHMEM_PUT_TYPE_NBI_API_TEMPLATE(Thread) DEFINE_SHMEM_PUT_TYPE_NBI_API_TEMPLATE(Warp) -#define DEFINE_SHMEM_PUT_TYPE_NBI_API(TypeName, T, Scope) \ - inline __device__ void ShmemPut##TypeName##Nbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destElmOffset, \ - const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, int pe) { \ - ShmemPutTypeNbi##Scope(dest, destElmOffset, source, srcElmOffset, nelems, pe); \ - } \ - inline __device__ void ShmemPut##TypeName##Nbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destElmOffset, \ - const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, int pe) { \ - ShmemPutTypeNbi##Scope(dest, destElmOffset, source, srcElmOffset, nelems, pe); \ +#define DEFINE_SHMEM_PUT_TYPE_NBI_API(TypeName, T, Scope) \ + inline __device__ void ShmemPut##TypeName##Nbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destElmOffset, \ + const application::RdmaMemoryRegion& source, size_t srcElmOffset, size_t nelems, int pe, \ + int qpId = 0) { \ + ShmemPutTypeNbi##Scope(dest, destElmOffset, source, srcElmOffset, nelems, pe, qpId); \ + } \ + inline __device__ void ShmemPut##TypeName##Nbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destElmOffset, \ + const application::SymmMemObjPtr source, size_t srcElmOffset, size_t nelems, int pe, \ + int qpId = 0) { \ + ShmemPutTypeNbi##Scope(dest, destElmOffset, source, srcElmOffset, nelems, pe, qpId); \ } 
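// ----------------------------------------------------------------------------------------------
// Illustrative usage sketch (editorial aside, not part of the patch): the put APIs above gain a
// trailing `qpId` argument that selects one of the queue pairs created per destination PE; the
// kernels map it with `qpId % numQpPerPe`, so any non-negative value is valid and the default 0
// keeps the old single-QP behaviour. The kernel name, `numQpPerPe`, and the buffer symbols below
// are hypothetical placeholders, not identifiers from this patch.
__global__ void PutExampleKernel(mori::application::SymmMemObjPtr destObj,
                                 mori::application::SymmMemObjPtr srcObj, size_t nelems,
                                 int peerPe, int numQpPerPe) {
  // Spread warps across QPs so concurrent non-blocking puts do not serialize on one SQ/CQ pair.
  int qpId = (threadIdx.x / warpSize) % numQpPerPe;
  mori::shmem::ShmemPutTypeNbiThread<uint32_t>(destObj, /*destElmOffset=*/0, srcObj,
                                               /*srcElmOffset=*/0, nelems, peerPe, qpId);
  // The put is non-blocking; flush outstanding RDMA operations targeting peerPe before the
  // source buffer is reused.
  mori::shmem::ShmemQuietThread(peerPe);
}
// ----------------------------------------------------------------------------------------------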
DEFINE_SHMEM_PUT_TYPE_NBI_API(Uint8, uint8_t, Thread) @@ -133,31 +153,33 @@ DEFINE_SHMEM_PUT_TYPE_NBI_API(Float, float, Warp) DEFINE_SHMEM_PUT_TYPE_NBI_API(Double, double, Warp) // TODO: deal with bytes count limit -#define SHMEM_PUT_SIZE_IMM_NBI_API(Scope) \ - inline __device__ void ShmemPutSizeImmNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe) { \ - DISPATCH_TRANSPORT_TYPE(ShmemPutSizeImmNbi##Scope##Kernel, pe, dest, destOffset, val, bytes, \ - pe); \ +#define SHMEM_PUT_SIZE_IMM_NBI_API(Scope) \ + inline __device__ void ShmemPutSizeImmNbi##Scope(const application::SymmMemObjPtr dest, \ + size_t destOffset, void* val, size_t bytes, \ + int pe, int qpId = 0) { \ + DISPATCH_TRANSPORT_TYPE(ShmemPutSizeImmNbi##Scope##Kernel, pe, dest, destOffset, val, bytes, \ + pe, qpId); \ } SHMEM_PUT_SIZE_IMM_NBI_API(Thread) SHMEM_PUT_SIZE_IMM_NBI_API(Warp) -#define SHMEM_PUT_TYPE_IMM_NBI_API_TEMPLATE(Scope) \ - template \ - inline __device__ void ShmemPutTypeImmNbi##Scope(const application::SymmMemObjPtr dest, \ - size_t destOffset, T val, int pe) { \ - static_assert(sizeof(T) <= core::MaxInlineDataSizePerWqe); \ - ShmemPutSizeImmNbi##Scope(dest, destOffset, &val, sizeof(T), pe); \ +#define SHMEM_PUT_TYPE_IMM_NBI_API_TEMPLATE(Scope) \ + template \ + inline __device__ void ShmemPutTypeImmNbi##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, T val, int pe, int qpId = 0) { \ + static_assert(sizeof(T) <= core::MaxInlineDataSizePerWqe); \ + ShmemPutSizeImmNbi##Scope(dest, destOffset, &val, sizeof(T), pe, qpId); \ } SHMEM_PUT_TYPE_IMM_NBI_API_TEMPLATE(Thread) SHMEM_PUT_TYPE_IMM_NBI_API_TEMPLATE(Warp) -#define DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(TypeName, T, Scope) \ - inline __device__ void ShmemPut##TypeName##ImmNbi##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, uint32_t val, int pe) { \ - ShmemPutTypeImmNbi##Scope(dest, destOffset, val, pe); \ +#define DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(TypeName, T, Scope) \ + inline __device__ void ShmemPut##TypeName##ImmNbi##Scope(const application::SymmMemObjPtr dest, \ + size_t destOffset, uint32_t val, \ + int pe, int qpId = 0) { \ + ShmemPutTypeImmNbi##Scope(dest, destOffset, val, pe, qpId); \ } DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Uint8, uint8_t, Thread) @@ -178,26 +200,23 @@ DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Int32, int32_t, Warp) DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Uint64, uint64_t, Warp) DEFINE_SHMEM_PUT_TYPE_IMM_NBI_API(Int64, int64_t, Warp) -#define SHMEM_ATOMIC_SIZE_NONFETCH_API_TEMPLATE(Scope) \ - inline __device__ void ShmemAtomicSizeNonFetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, \ - int pe, core::atomicType amoType) { \ - DISPATCH_TRANSPORT_TYPE(ShmemAtomicSizeNonFetch##Scope##Kernel, pe, dest, destOffset, source, \ - sourceOffset, val, bytes, pe, amoType); \ +#define SHMEM_ATOMIC_SIZE_NONFETCH_API_TEMPLATE(Scope) \ + inline __device__ void ShmemAtomicSizeNonFetch##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, \ + core::atomicType amoType, int pe, int qpId = 0) { \ + DISPATCH_TRANSPORT_TYPE(ShmemAtomicSizeNonFetch##Scope##Kernel, pe, dest, destOffset, val, \ + bytes, amoType, pe, qpId); \ } SHMEM_ATOMIC_SIZE_NONFETCH_API_TEMPLATE(Thread) SHMEM_ATOMIC_SIZE_NONFETCH_API_TEMPLATE(Warp) -#define SHMEM_ATOMIC_TYPE_NONFETCH_API_TEMPLATE(Scope) \ - template \ - inline __device__ void 
ShmemAtomicTypeNonFetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, T val, int pe, \ - core::atomicType amoType) { \ - ShmemAtomicSizeNonFetch##Scope(dest, destOffset, source, sourceOffset, &val, sizeof(T), pe, \ - amoType); \ +#define SHMEM_ATOMIC_TYPE_NONFETCH_API_TEMPLATE(Scope) \ + template \ + inline __device__ void ShmemAtomicTypeNonFetch##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, T val, core::atomicType amoType, \ + int pe, int qpId = 0) { \ + ShmemAtomicSizeNonFetch##Scope(dest, destOffset, &val, sizeof(T), amoType, pe, qpId); \ } SHMEM_ATOMIC_TYPE_NONFETCH_API_TEMPLATE(Thread) @@ -205,10 +224,9 @@ SHMEM_ATOMIC_TYPE_NONFETCH_API_TEMPLATE(Warp) #define DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(TypeName, T, Scope) \ inline __device__ void ShmemAtomic##TypeName##NonFetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, T val, int pe, \ - core::atomicType amoType) { \ - ShmemAtomicTypeNonFetch##Scope(dest, destOffset, source, sourceOffset, val, pe, amoType); \ + const application::SymmMemObjPtr dest, size_t destOffset, T val, core::atomicType amoType, \ + int pe, int qpId = 0) { \ + ShmemAtomicTypeNonFetch##Scope(dest, destOffset, val, amoType, pe, qpId); \ } DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Uint32, uint32_t, Thread) @@ -221,40 +239,25 @@ DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Uint64, uint64_t, Warp) DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Int32, int32_t, Warp) DEFINE_SHMEM_ATOMIC_TYPE_NONFETCH_API(Int64, int64_t, Warp) -#define SHMEM_ATOMIC_SIZE_FETCH_API_TEMPLATE(Scope) \ - inline __device__ void ShmemAtomicSizeFetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, \ - size_t bytes, int pe, core::atomicType amoType) { \ - DISPATCH_TRANSPORT_TYPE(ShmemAtomicSizeFetch##Scope##Kernel, pe, dest, destOffset, source, \ - sourceOffset, val, compare, bytes, pe, amoType); \ - } - -SHMEM_ATOMIC_SIZE_FETCH_API_TEMPLATE(Thread) -SHMEM_ATOMIC_SIZE_FETCH_API_TEMPLATE(Warp) - -#define SHMEM_ATOMIC_TYPE_FETCH_API_TEMPLATE(Scope) \ - template \ - inline __device__ T ShmemAtomicTypeFetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const application::RdmaMemoryRegion& source, size_t sourceOffset, T val, T compare, int pe, \ - core::atomicType amoType) { \ - ShmemAtomicSizeFetch##Scope(dest, destOffset, source, sourceOffset, &val, &compare, sizeof(T), \ - pe, amoType); \ - uintptr_t fetchResultPtr = source.addr + sourceOffset; \ - return core::AtomicLoadRelaxedSystem(reinterpret_cast(fetchResultPtr)); \ +#define SHMEM_ATOMIC_TYPE_FETCH_API_TEMPLATE(Scope) \ + template \ + inline __device__ T ShmemAtomicTypeFetch##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, T val, T compare, \ + core::atomicType amoType, int pe, int qpId = 0) { \ + T result = DISPATCH_TRANSPORT_DATA_TYPE_WITH_RETURN(ShmemAtomicTypeFetch##Scope##Kernel, pe, \ + T, dest, destOffset, &val, &compare, \ + sizeof(T), amoType, pe, qpId); \ + return result; \ } SHMEM_ATOMIC_TYPE_FETCH_API_TEMPLATE(Thread) SHMEM_ATOMIC_TYPE_FETCH_API_TEMPLATE(Warp) -#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(TypeName, T, Scope) \ - inline __device__ T ShmemAtomic##TypeName##Fetch##Scope( \ - const application::SymmMemObjPtr dest, size_t destOffset, \ - const 
application::RdmaMemoryRegion& source, size_t sourceOffset, T val, T compare, int pe, \ - core::atomicType amoType) { \ - return ShmemAtomicTypeFetch##Scope(dest, destOffset, source, sourceOffset, val, compare, \ - pe, amoType); \ +#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(TypeName, T, Scope) \ + inline __device__ T ShmemAtomic##TypeName##Fetch##Scope( \ + const application::SymmMemObjPtr dest, size_t destOffset, T val, T compare, \ + core::atomicType amoType, int pe, int qpId = 0) { \ + return ShmemAtomicTypeFetch##Scope(dest, destOffset, val, compare, amoType, pe, qpId); \ } DEFINE_SHMEM_ATOMIC_TYPE_FETCH_API(Uint32, uint32_t, Thread) diff --git a/include/mori/shmem/shmem_device_kernels.hpp b/include/mori/shmem/shmem_device_kernels.hpp index 12957560..f2611545 100644 --- a/include/mori/shmem/shmem_device_kernels.hpp +++ b/include/mori/shmem/shmem_device_kernels.hpp @@ -33,51 +33,49 @@ template inline __device__ void ShmemPutMemNbiThreadKernel(const application::SymmMemObjPtr dest, size_t destOffset, const application::RdmaMemoryRegion& source, - size_t sourceOffset, size_t bytes, int pe); + size_t sourceOffset, size_t bytes, int pe, + int qpId = 0); template inline __device__ void ShmemPutMemNbiWarpKernel(const application::SymmMemObjPtr dest, size_t destOffset, const application::RdmaMemoryRegion& source, - size_t sourceOffset, size_t bytes, int pe); + size_t sourceOffset, size_t bytes, int pe, + int qpId = 0); template inline __device__ void ShmemPutSizeImmNbiThreadKernel(const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, - int pe); + int pe, int qpId = 0); template inline __device__ void ShmemPutSizeImmNbiWarpKernel(const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, - int pe); + int pe, int qpId = 0); template -inline __device__ void ShmemAtomicSizeNonFetchThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType); +inline __device__ void ShmemAtomicSizeNonFetchThreadKernel(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, + size_t bytes, core::atomicType amoType, + int pe, int qpId = 0); template -inline __device__ void ShmemAtomicSizeNonFetchWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType); +inline __device__ void ShmemAtomicSizeNonFetchWarpKernel(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, size_t bytes, + core::atomicType amoType, int pe, + int qpId = 0); -template -inline __device__ void ShmemAtomicSizeFetchThreadKernel(const application::SymmMemObjPtr dest, - size_t destOffset, - const application::RdmaMemoryRegion& source, - size_t sourceOffset, void* val, - void* compare, size_t bytes, int pe, - core::atomicType amoType); +template +inline __device__ T ShmemAtomicTypeFetchThreadKernel(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, void* compare, + size_t bytes, core::atomicType amoType, int pe, + int qpId = 0); -template -inline __device__ void ShmemAtomicSizeFetchWarpKernel(const application::SymmMemObjPtr dest, - size_t destOffset, - const application::RdmaMemoryRegion& source, - size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, - core::atomicType amoType); +template +inline __device__ T ShmemAtomicTypeFetchWarpKernel(const 
application::SymmMemObjPtr dest, + size_t destOffset, void* val, void* compare, + size_t bytes, core::atomicType amoType, int pe, + int qpId = 0); /* ---------------------------------------------------------------------------------------------- */ /* Synchronization */ @@ -86,7 +84,7 @@ template inline __device__ void ShmemQuietThreadKernel(); template -inline __device__ void ShmemQuietThreadKernel(int pe); +inline __device__ void ShmemQuietThreadKernel(int pe, int qpId = 0); } // namespace shmem } // namespace mori diff --git a/include/mori/shmem/shmem_ibgda_kernels.hpp b/include/mori/shmem/shmem_ibgda_kernels.hpp index 4d183d40..0ec9b24d 100644 --- a/include/mori/shmem/shmem_ibgda_kernels.hpp +++ b/include/mori/shmem/shmem_ibgda_kernels.hpp @@ -62,17 +62,37 @@ namespace shmem { assert(false && "Unsupported or disabled provider type"); \ } +#define DISPATCH_PROVIDER_TYPE_COMPILE_TIME(func, ...) \ + do { \ + if constexpr (DISPATCH_BNXT == 1) { \ + func(__VA_ARGS__); \ + } else { \ + func(__VA_ARGS__); \ + } \ + } while (0) + +#define DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_RETURN(func, type, ...) \ + [&]() { \ + if constexpr (DISPATCH_BNXT == 1) { \ + return func(__VA_ARGS__); \ + } else { \ + return func(__VA_ARGS__); \ + } \ + }() + /* ---------------------------------------------------------------------------------------------- */ /* Synchronization */ /* ---------------------------------------------------------------------------------------------- */ template -inline __device__ void ShmemQuietThreadKernelSerialImpl(int pe) { - if (threadIdx.x != 0) return; +inline __device__ void ShmemQuietThreadKernelSerialImpl(int pe, int qpId) { + if (core::GetActiveLaneNum() != 0) return; GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::CompletionQueueHandle& cq = ep[pe].cqHandle; - core::WorkQueueHandle& wq = ep[pe].wqHandle; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle& wq = ep[epIndex].wqHandle; + core::CompletionQueueHandle& cq = ep[epIndex].cqHandle; + if (!core::AcquireLockOnce(&cq.pollCqLock)) return; while (true) { bool done{false}; uint32_t quiet_amount{0}; @@ -83,7 +103,8 @@ inline __device__ void ShmemQuietThreadKernelSerialImpl(int pe) { uint32_t doneIdx = __hip_atomic_load(&wq.doneIdx, __ATOMIC_SEQ_CST, __HIP_MEMORY_SCOPE_AGENT); // printf("dbTouchIdx: %u, doneIdx: %u\n", dbTouchIdx, doneIdx); if (dbTouchIdx == doneIdx) { - return; + // core::ReleaseLock(&cq.pollCqLock); + break; } my_cq_consumer = @@ -105,7 +126,6 @@ inline __device__ void ShmemQuietThreadKernelSerialImpl(int pe) { if (opcode != BNXT_RE_REQ_ST_OK) { int rank = globalGpuStates->rank; uint32_t my_cq_index = my_cq_consumer % cq.cqeNum; - printf("rank %d dest pe %d consIdx %d opcode %d\n", rank, pe, my_cq_index, opcode); assert(false); } wqe_counter = (wqe_counter + wq.sqWqeNum - 1) % wq.sqWqeNum; @@ -117,100 +137,95 @@ inline __device__ void ShmemQuietThreadKernelSerialImpl(int pe) { __atomic_signal_fence(__ATOMIC_SEQ_CST); __hip_atomic_fetch_max(&wq.doneIdx, wqe_id, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } + core::ReleaseLock(&cq.pollCqLock); } template -inline __device__ void ShmemQuietThreadKernelImpl(int pe) { +inline __device__ void ShmemQuietThreadKernelImpl(int pe, int qpId) { if constexpr (PrvdType == core::ProviderType::BNXT) { - ShmemQuietThreadKernelSerialImpl(pe); + ShmemQuietThreadKernelSerialImpl(pe, qpId); return; - } - GpuStates* globalGpuStates = 
GetGlobalGpuStatesPtr(); - application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::CompletionQueueHandle& cq = ep[pe].cqHandle; - core::WorkQueueHandle& wq = ep[pe].wqHandle; - - constexpr size_t BROADCAST_SIZE = 1024 / warpSize; - __shared__ uint64_t wqe_broadcast[BROADCAST_SIZE]; - uint8_t warp_id = core::FlatBlockThreadId() / warpSize; - wqe_broadcast[warp_id] = 0; - - uint64_t activemask = core::GetActiveLaneMask(); - uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); - uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask); - bool is_leader{my_logical_lane_id == 0}; - const uint64_t leader_phys_lane_id = core::GetFirstActiveLaneID(activemask); - - while (true) { - bool done{false}; - uint32_t quiet_amount{0}; - uint32_t warp_cq_consumer{0}; - while (!done) { - uint32_t active = - __hip_atomic_load(&cq.activeIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint32_t posted = - __hip_atomic_load(&cq.needConsIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint32_t completed = - __hip_atomic_load(&cq.consIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - if (!(posted - completed)) { - return; - } - int32_t quiet_val = posted - active; - if (quiet_val <= 0) { - continue; - } - quiet_amount = min(num_active_lanes, quiet_val); - if (is_leader) { - done = __hip_atomic_compare_exchange_strong(&cq.activeIdx, &active, active + quiet_amount, - __ATOMIC_RELAXED, __ATOMIC_RELAXED, - __HIP_MEMORY_SCOPE_AGENT); - if (done) { - warp_cq_consumer = __hip_atomic_fetch_add(&cq.cq_consumer, quiet_amount, __ATOMIC_RELAXED, - __HIP_MEMORY_SCOPE_AGENT); + } else { + GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); + application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle& wq = ep[epIndex].wqHandle; + core::CompletionQueueHandle& cq = ep[epIndex].cqHandle; + + constexpr size_t BROADCAST_SIZE = 1024 / warpSize; + __shared__ uint64_t wqe_broadcast[BROADCAST_SIZE]; + uint8_t warp_id = core::FlatBlockThreadId() / warpSize; + wqe_broadcast[warp_id] = 0; + + uint64_t activemask = core::GetActiveLaneMask(); + uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); + uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask); + bool is_leader{my_logical_lane_id == 0}; + const uint64_t leader_phys_lane_id = core::GetFirstActiveLaneID(activemask); + + while (true) { + bool done{false}; + uint32_t quiet_amount{0}; + uint32_t warp_cq_consumer{0}; + while (!done) { + uint32_t active = + __hip_atomic_load(&cq.activeIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint32_t posted = + __hip_atomic_load(&cq.needConsIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + uint32_t completed = + __hip_atomic_load(&cq.consIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + if (!(posted - completed)) { + return; + } + int32_t quiet_val = posted - active; + if (quiet_val <= 0) { + continue; } + quiet_amount = min(num_active_lanes, quiet_val); + if (is_leader) { + done = __hip_atomic_compare_exchange_strong(&cq.activeIdx, &active, active + quiet_amount, + __ATOMIC_RELAXED, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + if (done) { + warp_cq_consumer = __hip_atomic_fetch_add(&cq.cq_consumer, quiet_amount, + __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } + } + done = __shfl(done, leader_phys_lane_id); } - done = __shfl(done, leader_phys_lane_id); - } - warp_cq_consumer = __shfl(warp_cq_consumer, leader_phys_lane_id); - uint32_t my_cq_consumer = 
warp_cq_consumer + my_logical_lane_id; - uint32_t my_cq_index = my_cq_consumer % cq.cqeNum; - - if (my_logical_lane_id < quiet_amount) { - uint16_t wqe_counter; - int opcode = core::PollCq(cq.cqAddr, cq.cqeNum, &my_cq_consumer, &wqe_counter); - if constexpr (PrvdType == core::ProviderType::MLX5) { + warp_cq_consumer = __shfl(warp_cq_consumer, leader_phys_lane_id); + uint32_t my_cq_consumer = warp_cq_consumer + my_logical_lane_id; + uint32_t my_cq_index = my_cq_consumer % cq.cqeNum; + + if (my_logical_lane_id < quiet_amount) { + uint16_t wqe_counter; + int opcode = core::PollCq(cq.cqAddr, cq.cqeNum, &my_cq_consumer, &wqe_counter); if (opcode == MLX5_CQE_RESP_ERR || opcode == MLX5_CQE_REQ_ERR) { int rank = globalGpuStates->rank; printf("rank %d dest pe %d consIdx %d opcode %d\n", rank, pe, my_cq_index, opcode); core::DumpMlx5Wqe(wq.sqAddr, my_cq_index); assert(false); } - } else if constexpr (PrvdType == core::ProviderType::BNXT) { - if (opcode != BNXT_RE_REQ_ST_OK) { - int rank = globalGpuStates->rank; - printf("rank %d dest pe %d consIdx %d opcode %d\n", rank, pe, my_cq_index, opcode); - assert(false); - } - wqe_counter = (BNXT_RE_NUM_SLOT_PER_WQE * (wqe_counter + wq.sqWqeNum - 1) % wq.sqWqeNum); + uint64_t wqe_id = wq.outstandingWqe[wqe_counter]; + __hip_atomic_fetch_max(&wqe_broadcast[warp_id], wqe_id, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_WORKGROUP); + __atomic_signal_fence(__ATOMIC_SEQ_CST); + } + if (is_leader) { + uint64_t completed{0}; + do { + completed = __hip_atomic_load(&cq.consIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + } while (completed != warp_cq_consumer); + + core::UpdateCqDbrRecord(cq.dbrRecAddr, + (uint32_t)(warp_cq_consumer + quiet_amount), cq.cqeNum); + + __atomic_signal_fence(__ATOMIC_SEQ_CST); + uint64_t doneIdx = wqe_broadcast[warp_id]; + __hip_atomic_fetch_max(&wq.doneIdx, doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + __hip_atomic_fetch_add(&cq.consIdx, quiet_amount, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); } - uint64_t wqe_id = wq.outstandingWqe[wqe_counter]; - __hip_atomic_fetch_max(&wqe_broadcast[warp_id], wqe_id, __ATOMIC_RELAXED, - __HIP_MEMORY_SCOPE_WORKGROUP); - __atomic_signal_fence(__ATOMIC_SEQ_CST); - } - if (is_leader) { - uint64_t completed{0}; - do { - completed = __hip_atomic_load(&cq.consIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - } while (completed != warp_cq_consumer); - - core::UpdateCqDbrRecord(cq.dbrRecAddr, (uint32_t)(warp_cq_consumer + quiet_amount), - cq.cqeNum); - - __atomic_signal_fence(__ATOMIC_SEQ_CST); - uint64_t doneIdx = wqe_broadcast[warp_id]; - __hip_atomic_fetch_max(&wq.doneIdx, doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_fetch_add(&cq.consIdx, quiet_amount, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } } } @@ -218,24 +233,25 @@ inline __device__ void ShmemQuietThreadKernelImpl(int pe) { template <> inline __device__ void ShmemQuietThreadKernel() { GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); - application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; int rank = globalGpuStates->rank; int worldSize = globalGpuStates->worldSize; - for (int pe = blockIdx.x; pe < worldSize; pe += gridDim.x) { - if (pe != rank && globalGpuStates->transportTypes[pe] == application::TransportType::RDMA) { - DISPATCH_PROVIDER_TYPE_EP(ep, ShmemQuietThreadKernelImpl, pe); + for (int peId = 0; peId < worldSize; peId++) { + if (peId != rank && globalGpuStates->transportTypes[peId] == application::TransportType::RDMA) { + for (int qpId = 0; qpId < globalGpuStates->numQpPerPe; qpId++) { + 
DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemQuietThreadKernelImpl, peId, qpId); + } } } } template <> -inline __device__ void ShmemQuietThreadKernel(int pe) { +inline __device__ void ShmemQuietThreadKernel(int pe, int qpId) { GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; int rank = globalGpuStates->rank; if (pe == rank) return; if (globalGpuStates->transportTypes[pe] != application::TransportType::RDMA) return; - DISPATCH_PROVIDER_TYPE_EP(ep, ShmemQuietThreadKernelImpl, pe); + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemQuietThreadKernelImpl, pe, qpId); } /* ---------------------------------------------------------------------------------------------- */ @@ -245,7 +261,8 @@ template inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMemObjPtr dest, size_t destOffset, const application::RdmaMemoryRegion& source, - size_t sourceOffset, size_t bytes, int pe) { + size_t sourceOffset, size_t bytes, int pe, + int qpId) { if (bytes == 0) return; uintptr_t laddr = source.addr + sourceOffset; uintptr_t raddr = dest->peerPtrs[pe] + destOffset; @@ -253,15 +270,16 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::WorkQueueHandle* wq = &ep[pe].wqHandle; - core::CompletionQueueHandle* cq = &ep[pe].cqHandle; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle* wq = &ep[epIndex].wqHandle; + core::CompletionQueueHandle* cq = &ep[epIndex].cqHandle; + uint32_t qpn = ep[epIndex].handle.qpn; uint64_t activemask = core::GetActiveLaneMask(); uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask); bool is_leader{my_logical_lane_id == num_active_lanes - 1}; const uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask); - uint8_t num_wqes{num_active_lanes}; uint32_t warp_sq_counter{0}; uint32_t warp_msntbl_counter{0}, warp_psn_counter{0}; uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0}; @@ -272,13 +290,13 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem } if (is_leader) { if constexpr (PrvdType == core::ProviderType::MLX5) { - warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED, + warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } else if constexpr (PrvdType == core::ProviderType::BNXT) { - core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, psnCnt * num_wqes, + core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_active_lanes, psnCnt * num_active_lanes, &warp_msntbl_counter, &warp_psn_counter); warp_sq_counter = warp_msntbl_counter; - __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } else { assert(false); @@ -299,26 +317,25 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem while (true) { uint64_t db_touched = - __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); + __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); + uint64_t 
db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); uint64_t num_active_sq_entries = db_touched - db_done; uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries; uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_active_lanes - db_touched; if (num_free_entries > num_entries_until_warp_last_entry) { break; } - ShmemQuietThreadKernelImpl(pe); + ShmemQuietThreadKernelImpl(pe, qpId); } uint64_t dbr_val; if constexpr (PrvdType == core::ProviderType::MLX5) { wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; dbr_val = core::PostWrite(*wq, my_sq_counter, my_sq_counter, my_sq_counter, is_leader, - ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes); + qpn, laddr, source.lkey, raddr, rkey, bytes); } else if constexpr (PrvdType == core::ProviderType::BNXT) { wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter; - dbr_val = - core::PostWrite(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader, - ep[pe].handle.qpn, laddr, source.lkey, raddr, rkey, bytes); + dbr_val = core::PostWrite(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, + is_leader, qpn, laddr, source.lkey, raddr, rkey, bytes); } else { assert(false); } @@ -329,13 +346,13 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } while (db_touched != warp_sq_counter); - core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_wqes); + core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_active_lanes); // __threadfence_system(); core::RingDoorbell(wq->dbrAddr, dbr_val); - __threadfence_system(); + // __threadfence_system(); __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } // __threadfence_system(); @@ -344,28 +361,42 @@ inline __device__ void ShmemPutMemNbiThreadKernelImpl(const application::SymmMem template <> inline __device__ void ShmemPutMemNbiThreadKernel( const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe) { - DISPATCH_PROVIDER_TYPE(ShmemPutMemNbiThreadKernelImpl, dest, destOffset, source, sourceOffset, - bytes, pe); + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe, + int qpId) { + bool need_turn{true}; + uint64_t turns = __ballot(need_turn); + while (turns) { + uint8_t lane = __ffsll((unsigned long long)turns) - 1; + int pe_turn = __shfl(pe, lane); + if (pe_turn == pe) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemPutMemNbiThreadKernelImpl, dest, destOffset, source, + sourceOffset, bytes, pe, qpId); + need_turn = false; + } + turns = __ballot(need_turn); + } } template inline __device__ void ShmemPutMemNbiWarpKernelImpl(const application::SymmMemObjPtr dest, size_t destOffset, const application::RdmaMemoryRegion& source, - size_t sourceOffset, size_t bytes, int pe) { + size_t sourceOffset, size_t bytes, int pe, + int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemPutMemNbiThreadKernelImpl(dest, destOffset, source, sourceOffset, bytes, pe); + ShmemPutMemNbiThreadKernelImpl(dest, destOffset, source, sourceOffset, bytes, pe, + qpId); } } template <> inline __device__ void 
ShmemPutMemNbiWarpKernel( const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe) { - DISPATCH_PROVIDER_TYPE(ShmemPutMemNbiWarpKernelImpl, dest, destOffset, source, sourceOffset, - bytes, pe); + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe, + int qpId) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemPutMemNbiWarpKernelImpl, dest, destOffset, source, + sourceOffset, bytes, pe, qpId); } // TODO: deal with bytes count limit @@ -373,7 +404,7 @@ inline __device__ void ShmemPutMemNbiWarpKernel inline __device__ void ShmemPutSizeImmNbiThreadKernelImpl(const application::SymmMemObjPtr dest, size_t destOffset, void* val, - size_t bytes, int pe) { + size_t bytes, int pe, int qpId) { if (bytes == 0) return; uintptr_t raddr = dest->peerPtrs[pe] + destOffset; @@ -381,32 +412,33 @@ inline __device__ void ShmemPutSizeImmNbiThreadKernelImpl(const application::Sym GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::WorkQueueHandle* wq = &ep[pe].wqHandle; - core::CompletionQueueHandle* cq = &ep[pe].cqHandle; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle* wq = &ep[epIndex].wqHandle; + core::CompletionQueueHandle* cq = &ep[epIndex].cqHandle; + uint32_t qpn = ep[epIndex].handle.qpn; uint64_t activemask = core::GetActiveLaneMask(); uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); uint8_t my_logical_lane_id = core::GetActiveLaneNum(activemask); bool is_leader{my_logical_lane_id == num_active_lanes - 1}; const uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask); - uint8_t num_wqes{num_active_lanes}; uint32_t warp_sq_counter{0}; uint32_t warp_msntbl_counter{0}, warp_psn_counter{0}; uint32_t my_sq_counter{0}, my_msntbl_counter{0}, my_psn_counter{0}; if constexpr (PrvdType == core::ProviderType::MLX5) { if (is_leader) { - warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED, + warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); my_sq_counter = warp_sq_counter + my_logical_lane_id; } else if constexpr (PrvdType == core::ProviderType::BNXT) { if (is_leader) { - core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, num_wqes, &warp_msntbl_counter, - &warp_psn_counter); + core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_active_lanes, num_active_lanes, + &warp_msntbl_counter, &warp_psn_counter); warp_sq_counter = warp_msntbl_counter; - __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); @@ -429,20 +461,18 @@ inline __device__ void ShmemPutSizeImmNbiThreadKernelImpl(const application::Sym if (num_free_entries > num_entries_until_warp_last_entry) { break; } - ShmemQuietThreadKernelImpl(pe); + ShmemQuietThreadKernelImpl(pe, qpId); } uint64_t dbr_val; if constexpr (PrvdType == core::ProviderType::MLX5) { wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; - dbr_val = - core::PostWriteInline(*wq, my_sq_counter, my_sq_counter, my_sq_counter, is_leader, - ep[pe].handle.qpn, val, raddr, rkey, bytes); + dbr_val = 
core::PostWriteInline(*wq, my_sq_counter, my_sq_counter, my_sq_counter, + is_leader, qpn, val, raddr, rkey, bytes); } else if constexpr (PrvdType == core::ProviderType::BNXT) { wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter; - dbr_val = - core::PostWriteInline(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, - is_leader, ep[pe].handle.qpn, val, raddr, rkey, bytes); + dbr_val = core::PostWriteInline(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, + is_leader, qpn, val, raddr, rkey, bytes); } else { assert(false); } @@ -453,13 +483,13 @@ inline __device__ void ShmemPutSizeImmNbiThreadKernelImpl(const application::Sym db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } while (db_touched != warp_sq_counter); - core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_wqes); + core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_active_lanes); // __threadfence_system(); core::RingDoorbell(wq->dbrAddr, dbr_val); - __threadfence_system(); + // __threadfence_system(); __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } // __threadfence_system(); @@ -467,41 +497,58 @@ inline __device__ void ShmemPutSizeImmNbiThreadKernelImpl(const application::Sym template <> inline __device__ void ShmemPutSizeImmNbiThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe) { - DISPATCH_PROVIDER_TYPE(ShmemPutSizeImmNbiThreadKernelImpl, dest, destOffset, val, bytes, pe); + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe, + int qpId) { + bool need_turn{true}; + uint64_t turns = __ballot(need_turn); + while (turns) { + uint8_t lane = __ffsll((unsigned long long)turns) - 1; + int pe_turn = __shfl(pe, lane); + if (pe_turn == pe) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemPutSizeImmNbiThreadKernelImpl, dest, destOffset, val, + bytes, pe, qpId); + need_turn = false; + } + turns = __ballot(need_turn); + } } template inline __device__ void ShmemPutSizeImmNbiWarpKernelImpl(const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, - int pe) { + int pe, int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemPutSizeImmNbiThreadKernelImpl(dest, destOffset, val, bytes, pe); + ShmemPutSizeImmNbiThreadKernelImpl(dest, destOffset, val, bytes, pe, qpId); } } template <> inline __device__ void ShmemPutSizeImmNbiWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe) { - DISPATCH_PROVIDER_TYPE(ShmemPutSizeImmNbiWarpKernelImpl, dest, destOffset, val, bytes, pe); + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe, + int qpId) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemPutSizeImmNbiWarpKernelImpl, dest, destOffset, val, + bytes, pe, qpId); } template inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, + core::atomicType amoType, int pe, int qpId) { if (bytes == 0) return; - uintptr_t raddr = 
dest->peerPtrs[pe] + destOffset; - uintptr_t rkey = dest->peerRkeys[pe]; - uintptr_t laddr = source.addr + sourceOffset; - uintptr_t lkey = source.lkey; GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::WorkQueueHandle* wq = &ep[pe].wqHandle; - core::CompletionQueueHandle* cq = &ep[pe].cqHandle; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle* wq = &ep[epIndex].wqHandle; + core::CompletionQueueHandle* cq = &ep[epIndex].cqHandle; + uint32_t qpn = ep[epIndex].handle.qpn; + core::IbufHandle* ibuf = &ep[epIndex].atomicIbuf; + + uintptr_t raddr = dest->peerPtrs[pe] + destOffset; + uintptr_t rkey = dest->peerRkeys[pe]; + uintptr_t laddr = ibuf->addr; + uintptr_t lkey = ibuf->lkey; uint64_t activemask = core::GetActiveLaneMask(); uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); @@ -512,24 +559,20 @@ inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl( uint32_t warp_sq_counter = 0; uint32_t warp_msntbl_counter = 0, warp_psn_counter = 0; uint32_t my_sq_counter = 0, my_msntbl_counter = 0, my_psn_counter = 0; - uint8_t num_wqes; if constexpr (PrvdType == core::ProviderType::MLX5) { - uint32_t numWqesPerCmd = core::get_num_wqes_in_atomic(amoType, bytes); - num_wqes = num_active_lanes * numWqesPerCmd; if (is_leader) { - warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED, + warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); - my_sq_counter = warp_sq_counter + my_logical_lane_id * numWqesPerCmd; + my_sq_counter = warp_sq_counter + my_logical_lane_id; } else if constexpr (PrvdType == core::ProviderType::BNXT) { - num_wqes = num_active_lanes; if (is_leader) { - core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, num_wqes, &warp_msntbl_counter, - &warp_psn_counter); + core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_active_lanes, num_active_lanes, + &warp_msntbl_counter, &warp_psn_counter); warp_sq_counter = warp_msntbl_counter; - __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); @@ -548,27 +591,26 @@ inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl( uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); uint64_t num_active_sq_entries = db_touched - db_done; uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries; - uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_wqes - db_touched; + uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_active_lanes - db_touched; if (num_free_entries > num_entries_until_warp_last_entry) break; - ShmemQuietThreadKernelImpl(pe); + ShmemQuietThreadKernelImpl(pe, qpId); } if constexpr (PrvdType == core::ProviderType::MLX5) { - wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = - my_sq_counter + core::get_num_wqes_in_atomic(amoType, bytes) - 1; + wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; } else if constexpr (PrvdType == core::ProviderType::BNXT) { wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter; } uint64_t dbr_val; if constexpr (PrvdType == core::ProviderType::MLX5) { - 
dbr_val = core::PostAtomic(*wq, my_sq_counter, my_sq_counter, my_sq_counter, - is_leader, ep[pe].handle.qpn, laddr, lkey, raddr, rkey, - val, val, bytes, amoType); + dbr_val = + core::PostAtomic(*wq, my_sq_counter, my_sq_counter, my_sq_counter, is_leader, qpn, + laddr, lkey, raddr, rkey, val, val, bytes, amoType); } else if constexpr (PrvdType == core::ProviderType::BNXT) { - dbr_val = core::PostAtomic(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, - is_leader, ep[pe].handle.qpn, laddr, lkey, raddr, rkey, - val, val, bytes, amoType); + dbr_val = + core::PostAtomic(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader, + qpn, laddr, lkey, raddr, rkey, val, val, bytes, amoType); } // __threadfence_system(); @@ -578,12 +620,12 @@ inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl( db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } while (db_touched != warp_sq_counter); - core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_wqes); - __threadfence_system(); + core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_active_lanes); + // __threadfence_system(); core::RingDoorbell(wq->dbrAddr, dbr_val); __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -592,50 +634,76 @@ inline __device__ void ShmemAtomicSizeNonFetchThreadKernelImpl( template <> inline __device__ void ShmemAtomicSizeNonFetchThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { - DISPATCH_PROVIDER_TYPE(ShmemAtomicSizeNonFetchThreadKernelImpl, dest, destOffset, source, - sourceOffset, val, bytes, pe, amoType); + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, + core::atomicType amoType, int pe, int qpId) { + bool need_turn{true}; + uint64_t turns = __ballot(need_turn); + while (turns) { + uint8_t lane = __ffsll((unsigned long long)turns) - 1; + int pe_turn = __shfl(pe, lane); + if (pe_turn == pe) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemAtomicSizeNonFetchThreadKernelImpl, dest, destOffset, + val, bytes, amoType, pe, qpId); + need_turn = false; + } + turns = __ballot(need_turn); + } } template -inline __device__ void ShmemAtomicSizeNonFetchWarpKernelImpl( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { +inline __device__ void ShmemAtomicSizeNonFetchWarpKernelImpl(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, + size_t bytes, core::atomicType amoType, + int pe, int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemAtomicSizeNonFetchThreadKernelImpl(dest, destOffset, source, sourceOffset, val, - bytes, pe, amoType); + ShmemAtomicSizeNonFetchThreadKernelImpl(dest, destOffset, val, bytes, amoType, pe, + qpId); } - // ShmemQuietThreadKernelImpl(pe); } template <> inline __device__ void ShmemAtomicSizeNonFetchWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { - 
DISPATCH_PROVIDER_TYPE(ShmemAtomicSizeNonFetchWarpKernelImpl, dest, destOffset, source, - sourceOffset, val, bytes, pe, amoType); + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, + core::atomicType amoType, int pe, int qpId) { + DISPATCH_PROVIDER_TYPE_COMPILE_TIME(ShmemAtomicSizeNonFetchWarpKernelImpl, dest, destOffset, val, + bytes, amoType, pe, qpId); } -template -inline __device__ void ShmemAtomicSizeFetchThreadKernelImpl( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { - if (bytes == 0) return; - uintptr_t raddr = dest->peerPtrs[pe] + destOffset; - uintptr_t rkey = dest->peerRkeys[pe]; - uintptr_t laddr = source.addr + sourceOffset; - uintptr_t lkey = source.lkey; +inline __device__ uint32_t ShmemGetAtomicIbufSlot(core::IbufHandle* ibuf, uint32_t num_slots = 1) { + uint32_t base_slot = atomicAdd(&ibuf->head, num_slots); + uint32_t nslots = ibuf->nslots; + uint32_t last_slot = base_slot + num_slots; + while (last_slot - __hip_atomic_load(&ibuf->tail, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) > + nslots) { + ; + } + __threadfence_block(); + return base_slot; +} +inline __device__ void ShmemReleaseAtomicIbufSlot(core::IbufHandle* ibuf, uint32_t base_slots, + uint32_t num_slots) { + uint32_t last_slot = base_slots + num_slots; + while (atomicCAS(&ibuf->tail, base_slots, last_slot) != base_slots) { + ; + } + __threadfence_block(); +} + +template +inline __device__ T ShmemAtomicTypeFetchThreadKernelImpl(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, + void* compare, size_t bytes, + core::atomicType amoType, int pe, + int qpId) { GpuStates* globalGpuStates = GetGlobalGpuStatesPtr(); application::RdmaEndpoint* ep = globalGpuStates->rdmaEndpoints; - core::WorkQueueHandle* wq = &ep[pe].wqHandle; - core::CompletionQueueHandle* cq = &ep[pe].cqHandle; + int epIndex = pe * globalGpuStates->numQpPerPe + (qpId % globalGpuStates->numQpPerPe); + core::WorkQueueHandle* wq = &ep[epIndex].wqHandle; + core::CompletionQueueHandle* cq = &ep[epIndex].cqHandle; + uint32_t qpn = ep[epIndex].handle.qpn; + core::IbufHandle* ibuf = &ep[epIndex].atomicIbuf; uint64_t activemask = core::GetActiveLaneMask(); uint8_t num_active_lanes = core::GetActiveLaneCount(activemask); @@ -643,27 +711,34 @@ inline __device__ void ShmemAtomicSizeFetchThreadKernelImpl( bool is_leader = (my_logical_lane_id == num_active_lanes - 1); uint64_t leader_phys_lane_id = core::GetLastActiveLaneID(activemask); + uint32_t base_slot = 0; + if (is_leader) { + base_slot = ShmemGetAtomicIbufSlot(ibuf, num_active_lanes); + } + uint32_t my_slot = __shfl(base_slot, leader_phys_lane_id) + my_logical_lane_id; + uint32_t my_slot_index = my_slot & (ibuf->nslots - 1); + uintptr_t laddr = ibuf->addr + (my_slot_index + 1) * application::ATOMIC_IBUF_SLOT_SIZE; + uintptr_t lkey = ibuf->lkey; + uintptr_t raddr = dest->peerPtrs[pe] + destOffset; + uintptr_t rkey = dest->peerRkeys[pe]; + uint32_t warp_sq_counter = 0; uint32_t warp_msntbl_counter = 0, warp_psn_counter = 0; uint32_t my_sq_counter = 0, my_msntbl_counter = 0, my_psn_counter = 0; - uint8_t num_wqes; if constexpr (PrvdType == core::ProviderType::MLX5) { - uint32_t numWqesPerCmd = core::get_num_wqes_in_atomic(amoType, bytes); - num_wqes = num_active_lanes * numWqesPerCmd; if (is_leader) { - warp_sq_counter = __hip_atomic_fetch_add(&wq->postIdx, num_wqes, __ATOMIC_RELAXED, + warp_sq_counter 
= __hip_atomic_fetch_add(&wq->postIdx, num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); - my_sq_counter = warp_sq_counter + my_logical_lane_id * numWqesPerCmd; + my_sq_counter = warp_sq_counter + my_logical_lane_id; } else if constexpr (PrvdType == core::ProviderType::BNXT) { - num_wqes = num_active_lanes; if (is_leader) { - core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_wqes, num_wqes, &warp_msntbl_counter, - &warp_psn_counter); + core::atomic_add_packed_msn_and_psn(&wq->msnPack, num_active_lanes, num_active_lanes, + &warp_msntbl_counter, &warp_psn_counter); warp_sq_counter = warp_msntbl_counter; - __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_fetch_max(&wq->postIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } warp_sq_counter = __shfl(warp_sq_counter, leader_phys_lane_id); @@ -682,27 +757,26 @@ inline __device__ void ShmemAtomicSizeFetchThreadKernelImpl( uint64_t db_done = __hip_atomic_load(&wq->doneIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); uint64_t num_active_sq_entries = db_touched - db_done; uint64_t num_free_entries = wq->sqWqeNum - num_active_sq_entries; - uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_wqes - db_touched; + uint64_t num_entries_until_warp_last_entry = warp_sq_counter + num_active_lanes - db_touched; if (num_free_entries > num_entries_until_warp_last_entry) break; - ShmemQuietThreadKernelImpl(pe); + ShmemQuietThreadKernelImpl(pe, qpId); } if constexpr (PrvdType == core::ProviderType::MLX5) { - wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = - my_sq_counter + core::get_num_wqes_in_atomic(amoType, bytes) - 1; + wq->outstandingWqe[my_sq_counter % OUTSTANDING_TABLE_SIZE] = my_sq_counter; } else if constexpr (PrvdType == core::ProviderType::BNXT) { wq->outstandingWqe[my_sq_counter % wq->sqWqeNum] = my_sq_counter; } uint64_t dbr_val; if constexpr (PrvdType == core::ProviderType::MLX5) { - dbr_val = core::PostAtomic(*wq, my_sq_counter, my_sq_counter, my_sq_counter, - is_leader, ep[pe].handle.qpn, laddr, lkey, raddr, rkey, - val, compare, bytes, amoType); + dbr_val = + core::PostAtomic(*wq, my_sq_counter, my_sq_counter, my_sq_counter, is_leader, qpn, + laddr, lkey, raddr, rkey, val, compare, bytes, amoType); } else if constexpr (PrvdType == core::ProviderType::BNXT) { - dbr_val = core::PostAtomic(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, - is_leader, ep[pe].handle.qpn, laddr, lkey, raddr, rkey, - val, compare, bytes, amoType); + dbr_val = + core::PostAtomic(*wq, my_sq_counter, my_msntbl_counter, my_psn_counter, is_leader, + qpn, laddr, lkey, raddr, rkey, val, compare, bytes, amoType); } // __threadfence_system(); @@ -712,48 +786,80 @@ inline __device__ void ShmemAtomicSizeFetchThreadKernelImpl( db_touched = __hip_atomic_load(&wq->dbTouchIdx, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } while (db_touched != warp_sq_counter); - core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_wqes); - __threadfence_system(); + core::UpdateSendDbrRecord(wq->dbrRecAddr, warp_sq_counter + num_active_lanes); + // __threadfence_system(); core::RingDoorbell(wq->dbrAddr, dbr_val); __hip_atomic_fetch_add(&cq->needConsIdx, 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); - __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_wqes, __ATOMIC_RELAXED, + __hip_atomic_store(&wq->dbTouchIdx, warp_sq_counter + num_active_lanes, __ATOMIC_RELAXED, 
__HIP_MEMORY_SCOPE_AGENT); } - // __threadfence_system(); - // ShmemQuietThreadKernelImpl(pe); -} + ShmemQuietThreadKernelImpl(pe, qpId); + T ret = *reinterpret_cast(laddr); + if (sizeof(T) == 4) ret = BSWAP32((uint32_t)ret); -template <> -inline __device__ void ShmemAtomicSizeFetchThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { - DISPATCH_PROVIDER_TYPE(ShmemAtomicSizeFetchThreadKernelImpl, dest, destOffset, source, - sourceOffset, val, compare, bytes, pe, amoType); + if (is_leader) { + ShmemReleaseAtomicIbufSlot(ibuf, base_slot, num_active_lanes); + } + + return ret; } -template -inline __device__ void ShmemAtomicSizeFetchWarpKernelImpl( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { +#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL(TypeName, T) \ + template <> \ + inline __device__ T ShmemAtomicTypeFetchThreadKernel( \ + const application::SymmMemObjPtr dest, size_t destOffset, void* val, void* compare, \ + size_t bytes, core::atomicType amoType, int pe, int qpId) { \ + bool need_turn{true}; \ + uint64_t turns = __ballot(need_turn); \ + T result{}; \ + while (turns) { \ + uint8_t lane = __ffsll((unsigned long long)turns) - 1; \ + int pe_turn = __shfl(pe, lane); \ + if (pe_turn == pe) { \ + result = DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_RETURN( \ + ShmemAtomicTypeFetchThreadKernelImpl, T, dest, destOffset, val, compare, bytes, \ + amoType, pe, qpId); \ + need_turn = false; \ + } \ + turns = __ballot(need_turn); \ + } \ + return result; \ + } + +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL(Uint32, uint32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL(Uint64, uint64_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL(Int32, int32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL(Int64, int64_t) + +template +inline __device__ T ShmemAtomicTypeFetchWarpKernelImpl(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, void* compare, + size_t bytes, core::atomicType amoType, + int pe, int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemAtomicSizeFetchThreadKernelImpl(dest, destOffset, source, sourceOffset, val, - compare, bytes, pe, amoType); + return ShmemAtomicTypeFetchThreadKernelImpl(dest, destOffset, val, compare, bytes, + amoType, pe, qpId); } + return T{}; } -template <> -inline __device__ void ShmemAtomicSizeFetchWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { - DISPATCH_PROVIDER_TYPE(ShmemAtomicSizeFetchWarpKernelImpl, dest, destOffset, source, sourceOffset, - val, compare, bytes, pe, amoType); -} +#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL(TypeName, T) \ + template <> \ + inline __device__ T ShmemAtomicTypeFetchWarpKernel( \ + const application::SymmMemObjPtr dest, size_t destOffset, void* val, void* compare, \ + size_t bytes, core::atomicType amoType, int pe, int qpId) { \ + return DISPATCH_PROVIDER_TYPE_COMPILE_TIME_WITH_RETURN(ShmemAtomicTypeFetchWarpKernelImpl, T, \ + dest, destOffset, val, compare, bytes, \ + amoType, pe, qpId); \ + } + +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL(Uint32, uint32_t) 
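// ----------------------------------------------------------------------------------------------
// Illustrative usage sketch (editorial aside, not part of the patch): with the reworked fetch
// path above, the fetched value is returned through the per-QP atomic internal buffer (ibuf)
// rather than a caller-supplied source MR, so the public API only needs the destination object,
// the operands, the op type, the PE, and an optional qpId. The function name and symbols below
// are hypothetical placeholders.
__device__ uint64_t FetchAddRemoteCounter(mori::application::SymmMemObjPtr counterObj,
                                          int peerPe, int qpId) {
  // Atomically adds 1 to the remote 64-bit counter and returns its previous value; the compare
  // operand is ignored for AMO_FETCH_ADD.
  return mori::shmem::ShmemAtomicTypeFetchThread<uint64_t>(
      counterObj, /*destOffset=*/0, /*val=*/1ull, /*compare=*/0ull,
      mori::core::AMO_FETCH_ADD, peerPe, qpId);
}
// ----------------------------------------------------------------------------------------------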
+DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL(Uint64, uint64_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL(Int32, int32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL(Int64, int64_t) } // namespace shmem } // namespace mori diff --git a/include/mori/shmem/shmem_p2p_kernels.hpp b/include/mori/shmem/shmem_p2p_kernels.hpp index 56d7cdf9..d44d1200 100644 --- a/include/mori/shmem/shmem_p2p_kernels.hpp +++ b/include/mori/shmem/shmem_p2p_kernels.hpp @@ -24,6 +24,8 @@ #include #include +#include + #include "mori/application/application.hpp" #include "mori/core/core.hpp" #include "mori/shmem/shmem_api.hpp" @@ -38,7 +40,8 @@ namespace shmem { template <> inline __device__ void ShmemPutMemNbiThreadKernel( const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe) { + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe, + int qpId) { uint8_t* srcPtr = reinterpret_cast(source.addr + sourceOffset); uint8_t* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); core::ThreadCopy(destPtr, srcPtr, bytes); @@ -47,7 +50,8 @@ inline __device__ void ShmemPutMemNbiThreadKernel inline __device__ void ShmemPutMemNbiWarpKernel( const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe) { + const application::RdmaMemoryRegion& source, size_t sourceOffset, size_t bytes, int pe, + int qpId) { uint8_t* srcPtr = reinterpret_cast(source.addr + sourceOffset); uint8_t* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); core::WarpCopy(destPtr, srcPtr, bytes); @@ -55,7 +59,8 @@ inline __device__ void ShmemPutMemNbiWarpKernel template <> inline __device__ void ShmemPutSizeImmNbiThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe) { + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe, + int qpId) { uint8_t* srcPtr = reinterpret_cast(val); uint8_t* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); switch (bytes) { @@ -88,7 +93,8 @@ inline __device__ void ShmemPutSizeImmNbiThreadKernel inline __device__ void ShmemPutSizeImmNbiWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe) { + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, int pe, + int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) ShmemPutSizeImmNbiThreadKernel(dest, destOffset, val, bytes, @@ -97,9 +103,8 @@ inline __device__ void ShmemPutSizeImmNbiWarpKernel inline __device__ void ShmemAtomicSizeNonFetchThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t bytes, + core::atomicType amoType, int pe, int qpId) { uint8_t* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); switch (bytes) { case 4: { @@ -217,176 +222,143 @@ inline __device__ void ShmemAtomicSizeNonFetchThreadKernel inline __device__ void ShmemAtomicSizeNonFetchWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, size_t bytes, - int pe, core::atomicType amoType) { + const application::SymmMemObjPtr dest, size_t destOffset, void* val, size_t 
bytes, + core::atomicType amoType, int pe, int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemAtomicSizeNonFetchThreadKernel( - dest, destOffset, source, sourceOffset, val, bytes, pe, amoType); + ShmemAtomicSizeNonFetchThreadKernel(dest, destOffset, val, + bytes, amoType, pe); } } -template <> -inline __device__ void ShmemAtomicSizeFetchThreadKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { - uint8_t* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); - switch (bytes) { - case 4: { - int* fetchResPtr = reinterpret_cast(val); - int cmpVal = (compare != nullptr) ? *reinterpret_cast(compare) : 0; - int* remoteIntPtr = reinterpret_cast(destPtr); - auto casLoop = [=] __device__(int* addr, core::atomicType op, int operand, int cmpVal, - int* oldResult) { - int oldVal = core::AtomicLoadSeqCstSystem(addr); - while (true) { - int newVal = oldVal; - switch (op) { - case core::AMO_FETCH_INC: - newVal = oldVal + 1; - break; - case core::AMO_FETCH_ADD: - newVal = oldVal + operand; - break; - case core::AMO_FETCH_AND: - newVal = oldVal & operand; - break; - case core::AMO_FETCH_OR: - newVal = oldVal | operand; - break; - case core::AMO_FETCH_XOR: - newVal = oldVal ^ operand; - break; - case core::AMO_SWAP: - newVal = operand; - break; - case core::AMO_COMPARE_SWAP: - if (oldVal == cmpVal) { - newVal = operand; - } else { - newVal = oldVal; - } - break; - default: - break; - } +template +inline __device__ T ShmemAtomicTypeFetchThreadKernelImplP2P(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, + void* compare, size_t bytes, + core::atomicType amoType, int pe, + int qpId) { + T* destPtr = reinterpret_cast(dest->peerPtrs[pe] + destOffset); + T* fetchResPtr = reinterpret_cast(val); + T cmpVal = (compare != nullptr) ? 
*reinterpret_cast(compare) : T{}; - int expected = oldVal; - int prev = core::AtomicCompareExchangeSystem(addr, &expected, newVal); - if (prev == oldVal) { - *oldResult = oldVal; - break; - } - } - return oldVal; - }; - int* operandIntPtr = reinterpret_cast(source.addr + sourceOffset); - int operandInt = *operandIntPtr; - switch (amoType) { + auto casLoop = [=] __device__(T * addr, core::atomicType op, T operand, T cmpVal, T * oldResult) { + T oldVal = core::AtomicLoadSeqCstSystem(addr); + while (true) { + T newVal = oldVal; + switch (op) { case core::AMO_FETCH_INC: + newVal = oldVal + T{1}; + break; case core::AMO_FETCH_ADD: + newVal = oldVal + operand; + break; case core::AMO_FETCH_AND: + if constexpr (std::is_integral_v) { + newVal = oldVal & operand; + } + break; case core::AMO_FETCH_OR: + if constexpr (std::is_integral_v) { + newVal = oldVal | operand; + } + break; case core::AMO_FETCH_XOR: + if constexpr (std::is_integral_v) { + newVal = oldVal ^ operand; + } + break; case core::AMO_SWAP: - case core::AMO_COMPARE_SWAP: { - *operandIntPtr = casLoop(remoteIntPtr, amoType, operandInt, cmpVal, fetchResPtr); - } break; - + newVal = operand; + break; + case core::AMO_COMPARE_SWAP: + if (oldVal == cmpVal) { + newVal = operand; + } else { + newVal = oldVal; + } + break; default: - printf("Error: Unsupported 4-byte atomicType (%d) in FetchThreadKernel.\n", amoType); break; } - break; + T expected = oldVal; + T prev = core::AtomicCompareExchangeSystem(addr, &expected, newVal); + if (prev == oldVal) { + *oldResult = oldVal; + break; + } + oldVal = prev; } - case 8: { - long long* fetchResPtr = reinterpret_cast(val); - long long cmpValLL = (compare != nullptr) ? *reinterpret_cast(compare) : 0LL; - long long* remoteLLPtr = reinterpret_cast(destPtr); - auto casLoop64 = [=] __device__(long long* addr, core::atomicType op, long long operand, - long long cmpValLL, long long* oldResult) { - long long oldVal = core::AtomicLoadSeqCstSystem(addr); - while (true) { - long long newVal = oldVal; - switch (op) { - case core::AMO_FETCH_INC: - newVal = oldVal + 1; - break; - case core::AMO_FETCH_ADD: - newVal = oldVal + operand; - break; - case core::AMO_FETCH_AND: - newVal = oldVal & operand; - break; - case core::AMO_FETCH_OR: - newVal = oldVal | operand; - break; - case core::AMO_FETCH_XOR: - newVal = oldVal ^ operand; - break; - case core::AMO_SWAP: - newVal = operand; - break; - case core::AMO_COMPARE_SWAP: - if (oldVal == cmpValLL) { - newVal = operand; - } else { - newVal = oldVal; - } - break; - default: - break; - } + return oldVal; + }; - long long expected = oldVal; - long long prev = core::AtomicCompareExchangeSystem(addr, &expected, newVal); - if (prev == oldVal) { - *oldResult = oldVal; - break; - } - } - return oldVal; - }; - long long* operandLLPtr = reinterpret_cast(source.addr + sourceOffset); - long long operandLL = *operandLLPtr; - switch (amoType) { - case core::AMO_FETCH_INC: - case core::AMO_FETCH_ADD: - case core::AMO_FETCH_AND: - case core::AMO_FETCH_OR: - case core::AMO_FETCH_XOR: - case core::AMO_SWAP: - case core::AMO_COMPARE_SWAP: { - *operandLLPtr = casLoop64(remoteLLPtr, amoType, operandLL, cmpValLL, fetchResPtr); - } break; + T* valPtr = reinterpret_cast(val); + T operand = *valPtr; - default: - printf("Error: Unsupported 8-byte atomicType (%d) in FetchThreadKernel.\n", amoType); - break; - } - break; - } + switch (amoType) { + case core::AMO_FETCH_INC: + case core::AMO_FETCH_ADD: + case core::AMO_FETCH_AND: + case core::AMO_FETCH_OR: + case core::AMO_FETCH_XOR: + case 
core::AMO_SWAP: + case core::AMO_COMPARE_SWAP: { + T result = casLoop(destPtr, amoType, operand, cmpVal, fetchResPtr); + return result; + } break; default: - printf("Error: Unsupported data size (%zu bytes) in FetchThreadKernel.\n", bytes); + if constexpr (sizeof(T) == 4) { + printf("Error: Unsupported 4-byte atomicType (%d) in TypeFetchThreadKernel.\n", amoType); + } else if constexpr (sizeof(T) == 8) { + printf("Error: Unsupported 8-byte atomicType (%d) in TypeFetchThreadKernel.\n", amoType); + } break; } + return T{}; } -template <> -inline __device__ void ShmemAtomicSizeFetchWarpKernel( - const application::SymmMemObjPtr dest, size_t destOffset, - const application::RdmaMemoryRegion& source, size_t sourceOffset, void* val, void* compare, - size_t bytes, int pe, core::atomicType amoType) { +#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL_P2P(TypeName, T) \ + template <> \ + inline __device__ T ShmemAtomicTypeFetchThreadKernel( \ + const application::SymmMemObjPtr dest, size_t destOffset, void* val, void* compare, \ + size_t bytes, core::atomicType amoType, int pe, int qpId) { \ + return ShmemAtomicTypeFetchThreadKernelImplP2P(dest, destOffset, val, compare, bytes, \ + amoType, pe, qpId); \ + } + +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL_P2P(Uint32, uint32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL_P2P(Uint64, uint64_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL_P2P(Int32, int32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_THREAD_KERNEL_P2P(Int64, int64_t) + +template +inline __device__ T ShmemAtomicTypeFetchWarpKernelImpl_P2P(const application::SymmMemObjPtr dest, + size_t destOffset, void* val, + void* compare, size_t bytes, + core::atomicType amoType, int pe, + int qpId) { int laneId = threadIdx.x & (warpSize - 1); if (laneId == 0) { - ShmemAtomicSizeFetchThreadKernel( - dest, destOffset, source, sourceOffset, val, compare, bytes, pe, amoType); + return ShmemAtomicTypeFetchThreadKernelImplP2P(dest, destOffset, val, compare, bytes, + amoType, pe, qpId); } + return T{}; } +#define DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL_P2P(TypeName, T) \ + template <> \ + inline __device__ T ShmemAtomicTypeFetchWarpKernel( \ + const application::SymmMemObjPtr dest, size_t destOffset, void* val, void* compare, \ + size_t bytes, core::atomicType amoType, int pe, int qpId) { \ + return ShmemAtomicTypeFetchThreadKernelImplP2P(dest, destOffset, val, compare, bytes, \ + amoType, pe, qpId); \ + } + +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL_P2P(Uint32, uint32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL_P2P(Uint64, uint64_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL_P2P(Int32, int32_t) +DEFINE_SHMEM_ATOMIC_TYPE_FETCH_WARP_KERNEL_P2P(Int64, int64_t) + /* ---------------------------------------------------------------------------------------------- */ /* Synchronization */ /* ---------------------------------------------------------------------------------------------- */ @@ -394,7 +366,7 @@ template <> inline __device__ void ShmemQuietThreadKernel() {} template <> -inline __device__ void ShmemQuietThreadKernel(int pe) {} +inline __device__ void ShmemQuietThreadKernel(int pe, int qpId) {} } // namespace shmem } // namespace mori diff --git a/src/application/context/context.cpp b/src/application/context/context.cpp index 2b5f15a8..7d2e1b31 100644 --- a/src/application/context/context.cpp +++ b/src/application/context/context.cpp @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -89,12 +90,10 @@ void Context::CollectHostNames() { constexpr int IDENTIFIER_MAX = 
HOST_NAME_MAX + INET_ADDRSTRLEN; std::vector globalIdentifiers(IDENTIFIER_MAX * WorldSize()); - // Create a non-const buffer for Allgather char localBuffer[IDENTIFIER_MAX]; strncpy(localBuffer, hostIdentifier.c_str(), IDENTIFIER_MAX - 1); localBuffer[IDENTIFIER_MAX - 1] = '\0'; - bootNet.Allgather(localBuffer, globalIdentifiers.data(), IDENTIFIER_MAX); for (int i = 0; i < WorldSize(); i++) { @@ -102,9 +101,9 @@ void Context::CollectHostNames() { } if (LocalRank() == 0) { - MORI_APP_INFO("Collected hostnames:"); + MORI_APP_TRACE("Collected hostnames:"); for (int i = 0; i < hostnames.size(); i++) { - MORI_APP_INFO(" rank {}: {}", i, hostnames[i]); + MORI_APP_TRACE(" rank {}: {}", i, hostnames[i]); } } } @@ -178,6 +177,12 @@ void Context::InitializePossibleTransports() { << "[" << devicePortId << "] " << device->Name() << std::endl; } + int numQpPerPe = 4; + const char* envNumQp = std::getenv("MORI_NUM_QP_PER_PE"); + if (envNumQp != nullptr) { + numQpPerPe = std::max(1, std::atoi(envNumQp)); // ensure at least 1 QP + } + this->numQpPerPe = numQpPerPe; // Initialize transport int peerRankInNode = -1; for (int i = 0; i < WorldSize(); i++) { @@ -192,44 +197,66 @@ void Context::InitializePossibleTransports() { if ((i == LocalRank()) || canAccessPeer) { transportTypes.push_back(TransportType::P2P); - rdmaEps.push_back({}); + for (int qp = 0; qp < numQpPerPe; qp++) { + rdmaEps.push_back({}); + } continue; } } } else { if (i == LocalRank()) { transportTypes.push_back(TransportType::P2P); - rdmaEps.push_back({}); + for (int qp = 0; qp < numQpPerPe; qp++) { + rdmaEps.push_back({}); + } continue; } } if (rdmaDeviceContext.get() == nullptr) assert(false && "no rdma device found"); - + // Create multiple QPs for this peer application::RdmaEndpointConfig config; config.portId = portId; config.gidIdx = 3; config.maxMsgsNum = 4096; +#ifdef ENABLE_BNXT + config.maxCqeNum = 4096; +#else config.maxCqeNum = 4096; +#endif config.alignment = 4096; config.onGpu = true; - RdmaEndpoint ep = rdmaDeviceContext->CreateRdmaEndpoint(config); - rdmaEps.push_back(ep); + for (int qp = 0; qp < numQpPerPe; qp++) { + RdmaEndpoint ep = rdmaDeviceContext->CreateRdmaEndpoint(config); + rdmaEps.push_back(ep); + } transportTypes.push_back(TransportType::RDMA); } // All2All rdma eps - // Exchange endpoint handles - std::vector localToPeerEpHandles(WorldSize()); - std::vector peerToLocalEpHandles(WorldSize()); - for (int i = 0; i < WorldSize(); i++) localToPeerEpHandles[i] = rdmaEps[i].handle; + // Exchange endpoint handles (now with multiple QPs per peer) + int totalEps = WorldSize() * numQpPerPe; + std::vector localToPeerEpHandles(totalEps); + std::vector peerToLocalEpHandles(totalEps); + + // Fill local endpoint handles + for (int i = 0; i < rdmaEps.size(); i++) { + localToPeerEpHandles[i] = rdmaEps[i].handle; + } + bootNet.AllToAll(localToPeerEpHandles.data(), peerToLocalEpHandles.data(), - sizeof(RdmaEndpointHandle)); + sizeof(RdmaEndpointHandle) * numQpPerPe); // Connect RDMA endpoints - for (int i = 0; i < WorldSize(); i++) { - if (transportTypes[i] != TransportType::RDMA) continue; - rdmaDeviceContext->ConnectEndpoint(localToPeerEpHandles[i], peerToLocalEpHandles[i]); + for (int peer = 0; peer < WorldSize(); peer++) { + if (transportTypes[peer] != TransportType::RDMA) { + continue; + } + for (int qp = 0; qp < numQpPerPe; qp++) { + int epIndex = peer * numQpPerPe + qp; + rdmaDeviceContext->ConnectEndpoint(localToPeerEpHandles[epIndex], + peerToLocalEpHandles[epIndex], qp); + } } } diff --git 
a/src/application/transport/rdma/providers/bnxt/bnxt.cpp b/src/application/transport/rdma/providers/bnxt/bnxt.cpp index 05067da1..b8e31ba2 100644 --- a/src/application/transport/rdma/providers/bnxt/bnxt.cpp +++ b/src/application/transport/rdma/providers/bnxt/bnxt.cpp @@ -25,7 +25,11 @@ #include #include +#include +#include #include +#include +#include #include "mori/application/utils/check.hpp" #include "mori/application/utils/math.hpp" @@ -48,19 +52,16 @@ static std::ostream& operator<<(std::ostream& s, const bnxt_re_dv_qp_mem_info& m template <> struct fmt::formatter { - constexpr auto parse(format_parse_context& ctx) -> decltype(ctx.begin()) { - return ctx.end(); - } + constexpr auto parse(format_parse_context& ctx) -> decltype(ctx.begin()) { return ctx.end(); } template auto format(const bnxt_re_dv_qp_mem_info& m, FormatContext& ctx) -> decltype(ctx.out()) { return fmt::format_to(ctx.out(), - "qp_handle: 0x{:x} sq_va: 0x{:x} sq_len: {} sq_slots: {} " - "sq_wqe_sz: {} sq_psn_sz: {} sq_npsn: {} rq_va: 0x{:x} " - "rq_len: {} rq_slots: {} rq_wqe_sz: {} comp_mask: 0x{:x}", - m.qp_handle, m.sq_va, m.sq_len, m.sq_slots, m.sq_wqe_sz, - m.sq_psn_sz, m.sq_npsn, m.rq_va, m.rq_len, m.rq_slots, - m.rq_wqe_sz, m.comp_mask); + "qp_handle: 0x{:x} sq_va: 0x{:x} sq_len: {} sq_slots: {} " + "sq_wqe_sz: {} sq_psn_sz: {} sq_npsn: {} rq_va: 0x{:x} " + "rq_len: {} rq_slots: {} rq_wqe_sz: {} comp_mask: 0x{:x}", + m.qp_handle, m.sq_va, m.sq_len, m.sq_slots, m.sq_wqe_sz, m.sq_psn_sz, + m.sq_npsn, m.rq_va, m.rq_len, m.rq_slots, m.rq_wqe_sz, m.comp_mask); } }; #endif // ENABLE_BNXT @@ -68,6 +69,7 @@ struct fmt::formatter { namespace mori { namespace application { #ifdef ENABLE_BNXT + /* ---------------------------------------------------------------------------------------------- */ /* BnxtCqContainer */ /* ---------------------------------------------------------------------------------------------- */ @@ -111,6 +113,9 @@ BnxtCqContainer::BnxtCqContainer(ibv_context* context, const RdmaEndpointConfig& int status = bnxt_re_dv_init_obj(&dv_obj, BNXT_RE_DV_OBJ_CQ); assert(!status); cqn = dvcq.cqn; + + MORI_APP_TRACE("BNXT CQ created: cqn={}, cqeNum={}, cqSize={}, cqUmemAddr=0x{:x}", cqn, cqeNum, + cqSize, reinterpret_cast(cqUmemAddr)); } BnxtCqContainer::~BnxtCqContainer() { @@ -167,8 +172,8 @@ int bnxt_re_calc_dv_qp_mem_info(struct ibv_pd* ibvpd, struct ibv_qp_init_attr* a } BnxtQpContainer::BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig& config, ibv_cq* cq, - ibv_pd* pd) - : context(context), config(config) { + ibv_pd* pd, BnxtDeviceContext* device_context) + : context(context), config(config), device_context(device_context) { struct ibv_qp_init_attr ib_qp_attr; struct bnxt_re_dv_umem_reg_attr umem_attr; struct bnxt_re_dv_qp_init_attr dv_qp_attr; @@ -257,19 +262,84 @@ BnxtQpContainer::BnxtQpContainer(ibv_context* context, const RdmaEndpointConfig& qp = bnxt_re_dv_create_qp(pd, &dv_qp_attr); assert(qp); qpn = qp->qp_num; - MORI_APP_INFO(qpMemInfo); + + // Allocate and register atomic internal buffer (ibuf), uncached so the GPU can poll NIC-written results + atomicIbufSize = (RoundUpPowOfTwo(config.atomicIbufSlots) + 1) * ATOMIC_IBUF_SLOT_SIZE; + if (config.onGpu) { + HIP_RUNTIME_CHECK( + hipExtMallocWithFlags(&atomicIbufAddr,
atomicIbufSize, hipDeviceMallocUncached)); + HIP_RUNTIME_CHECK(hipMemset(atomicIbufAddr, 0, atomicIbufSize)); + } else { + err = posix_memalign(&atomicIbufAddr, config.alignment, atomicIbufSize); + memset(atomicIbufAddr, 0, atomicIbufSize); + assert(!err); + } + + // Register atomic ibuf as independent memory region + atomicIbufMr = ibv_reg_mr(pd, atomicIbufAddr, atomicIbufSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC); + assert(atomicIbufMr); + + MORI_APP_TRACE( + "BNXT Atomic ibuf allocated: addr=0x{:x}, slots={}, size={}, lkey=0x{:x}, rkey=0x{:x}", + reinterpret_cast(atomicIbufAddr), RoundUpPowOfTwo(config.atomicIbufSlots), + atomicIbufSize, atomicIbufMr->lkey, atomicIbufMr->rkey); + MORI_APP_TRACE(qpMemInfo); } BnxtQpContainer::~BnxtQpContainer() { DestroyQueuePair(); } void BnxtQpContainer::DestroyQueuePair() { - if (sqUmemAddr) HIP_RUNTIME_CHECK(hipFree(sqUmemAddr)); - if (rqUmemAddr) HIP_RUNTIME_CHECK(hipFree(rqUmemAddr)); - if (qpDbrUmemAddr) HIP_RUNTIME_CHECK(hipFree(qpDbrUmemAddr)); + // Clean up atomic internal buffer + if (atomicIbufMr) { + ibv_dereg_mr(atomicIbufMr); + atomicIbufMr = nullptr; + } + if (atomicIbufAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(atomicIbufAddr)); + } else { + free(atomicIbufAddr); + } + atomicIbufAddr = nullptr; + } + if (sqUmem) bnxt_re_dv_umem_dereg(sqUmem); if (rqUmem) bnxt_re_dv_umem_dereg(rqUmem); + if (sqUmemAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(sqUmemAddr)); + } else { + free(sqUmemAddr); + } + } + if (rqUmemAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(rqUmemAddr)); + } else { + free(rqUmemAddr); + } + } + if (qpDbrUmemAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(qpDbrUmemAddr)); + } else { + free(qpDbrUmemAddr); + } + } if (qpUar) { - HIP_RUNTIME_CHECK(hipHostUnregister(qpUar)); + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipHostUnregister(qpUar)); + } } if (qp) bnxt_re_dv_destroy_qp(qp); } @@ -298,7 +368,7 @@ void BnxtQpContainer::ModifyRst2Init() { void BnxtQpContainer::ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, const ibv_port_attr& portAttr, - const ibv_device_attr_ex& deviceAttr) { + const ibv_device_attr_ex& deviceAttr, uint32_t qpId) { struct ibv_qp_attr attr; int attr_mask; @@ -319,10 +389,22 @@ void BnxtQpContainer::ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, attr.max_dest_rd_atomic = deviceAttr.orig_attr.max_qp_rd_atom; attr.min_rnr_timer = 12; + // Use qpId to select UDP sport value from the shared configuration (round-robin) + uint16_t selected_udp_sport = GetDeviceContext()->GetUdpSport(qpId); + MORI_APP_TRACE("QP {} using UDP sport {} (qpId={}, index={})", qpn, selected_udp_sport, qpId, + qpId % RDMA_UDP_SPORT_ARRAY_SIZE); + int status = bnxt_re_dv_modify_qp_udp_sport(qp, selected_udp_sport); + if (status) { + MORI_APP_ERROR("Failed to set UDP sport {} for QP {}: error code {}", selected_udp_sport, qpn, + status); + } + assert(!status); + MORI_APP_TRACE("bnxt_re_dv_modify_qp_udp_sport is done, return {}", status); + attr_mask = IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_RQ_PSN | IBV_QP_DEST_QPN | IBV_QP_AV | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER; - int status = bnxt_re_dv_modify_qp(qp, &attr, attr_mask, 0, 0); + status = bnxt_re_dv_modify_qp(qp, &attr, attr_mask, 0, 0); assert(!status); } @@ -368,7 +450,7 @@ RdmaEndpoint BnxtDeviceContext::CreateRdmaEndpoint(const RdmaEndpointConfig& con BnxtCqContainer* cq = new BnxtCqContainer(context, config); - BnxtQpContainer* qp = new 
BnxtQpContainer(context, config, cq->cq, pd); + BnxtQpContainer* qp = new BnxtQpContainer(context, config, cq->cq, pd, this); int ret; RdmaEndpoint endpoint; @@ -421,14 +503,26 @@ RdmaEndpoint BnxtDeviceContext::CreateRdmaEndpoint(const RdmaEndpointConfig& con endpoint.cqHandle.cqeNum = cq->cqeNum; endpoint.cqHandle.cqeSize = GetBnxtCqeSize(); + // Set atomic internal buffer information + endpoint.atomicIbuf.addr = reinterpret_cast(qp->atomicIbufAddr); + endpoint.atomicIbuf.lkey = qp->atomicIbufMr->lkey; + endpoint.atomicIbuf.rkey = qp->atomicIbufMr->rkey; + endpoint.atomicIbuf.nslots = RoundUpPowOfTwo(config.atomicIbufSlots); + cqPool.insert({cq->cqn, cq}); qpPool.insert({qp->qpn, qp}); + MORI_APP_TRACE( + "BNXT endpoint created: qpn={}, cqn={}, portId={}, gidIdx={}, atomicIbuf addr=0x{:x}, " + "nslots={}", + qp->qpn, cq->cqn, config.portId, config.gidIdx, endpoint.atomicIbuf.addr, + endpoint.atomicIbuf.nslots); + return endpoint; } void BnxtDeviceContext::ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) { + const RdmaEndpointHandle& remote, uint32_t qpId) { uint32_t local_qpn = local.qpn; assert(qpPool.find(local_qpn) != qpPool.end()); BnxtQpContainer* qp = qpPool.at(local_qpn); @@ -436,7 +530,7 @@ void BnxtDeviceContext::ConnectEndpoint(const RdmaEndpointHandle& local, const ibv_device_attr_ex& deviceAttr = *(rdmaDevice->GetDeviceAttr()); const ibv_port_attr& portAttr = *(rdmaDevice->GetPortAttrMap()->find(local.portId)->second); qp->ModifyRst2Init(); - qp->ModifyInit2Rtr(remote, portAttr, deviceAttr); + qp->ModifyInit2Rtr(remote, portAttr, deviceAttr, qpId); qp->ModifyRtr2Rts(local, remote); } diff --git a/src/application/transport/rdma/providers/ibverbs/ibverbs.cpp b/src/application/transport/rdma/providers/ibverbs/ibverbs.cpp index c01860d5..c2e28c06 100644 --- a/src/application/transport/rdma/providers/ibverbs/ibverbs.cpp +++ b/src/application/transport/rdma/providers/ibverbs/ibverbs.cpp @@ -96,7 +96,7 @@ RdmaEndpoint IBVerbsDeviceContext::CreateRdmaEndpoint(const RdmaEndpointConfig& } void IBVerbsDeviceContext::ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) { + const RdmaEndpointHandle& remote, uint32_t qpId) { ibv_qp_attr attr; int flags; diff --git a/src/application/transport/rdma/providers/mlx5/mlx5.cpp b/src/application/transport/rdma/providers/mlx5/mlx5.cpp index 9a2e578b..18773541 100644 --- a/src/application/transport/rdma/providers/mlx5/mlx5.cpp +++ b/src/application/transport/rdma/providers/mlx5/mlx5.cpp @@ -29,6 +29,7 @@ #include "mori/application/utils/check.hpp" #include "mori/application/utils/math.hpp" +#include "mori/utils/mori_log.hpp" #include "src/application/transport/rdma/providers/mlx5/mlx5_ifc.hpp" #include "src/application/transport/rdma/providers/mlx5/mlx5_prm.hpp" @@ -62,6 +63,9 @@ HcaCapability QueryHcaCap(ibv_context* context) { DEVX_GET(query_hca_cap_out, cmd_cap_out, capability.cmd_hca_cap.log_bf_reg_size); hca_cap.dbrRegSize = 1LLU << logBfRegSize; + MORI_APP_TRACE("MLX5 HCA capabilities: portType={}, dbrRegSize={}", hca_cap.portType, + hca_cap.dbrRegSize); + return hca_cap; } @@ -134,6 +138,9 @@ Mlx5CqContainer::Mlx5CqContainer(ibv_context* context, const RdmaEndpointConfig& assert(cq); cqn = DEVX_GET(create_cq_out, cmd_out, cqn); + + MORI_APP_TRACE("MLX5 CQ created: cqn={}, cqeNum={}, cqSize={}, cqUmemAddr=0x{:x}, uar_page_id={}", + cqn, cqeNum, cqSize, reinterpret_cast(cqUmemAddr), uar->page_id); } Mlx5CqContainer::~Mlx5CqContainer() { @@ -147,8 +154,8 @@ 
Mlx5CqContainer::~Mlx5CqContainer() { /* Mlx5QpContainer */ /* ---------------------------------------------------------------------------------------------- */ Mlx5QpContainer::Mlx5QpContainer(ibv_context* context, const RdmaEndpointConfig& config, - uint32_t cqn, uint32_t pdn) - : context(context), config(config) { + uint32_t cqn, uint32_t pdn, Mlx5DeviceContext* device_context) + : context(context), config(config), device_context(device_context) { ComputeQueueAttrs(config); CreateQueuePair(cqn, pdn); } @@ -167,7 +174,7 @@ void Mlx5QpContainer::ComputeQueueAttrs(const RdmaEndpointConfig& config) { // Send queue attributes sqAttrs.offset = rqAttrs.wqSize; sqAttrs.wqeSize = GetMlx5SqWqeSize(); - sqAttrs.wqSize = RoundUpPowOfTwo(sqAttrs.wqeSize * config.maxMsgsNum); + sqAttrs.wqSize = RoundUpPowOfTwo(sqAttrs.wqeSize * maxMsgsNum); sqAttrs.wqeNum = ceil(sqAttrs.wqSize / MLX5_SEND_WQE_BB); sqAttrs.wqeShift = MLX5_SEND_WQE_SHIFT; @@ -175,9 +182,11 @@ void Mlx5QpContainer::ComputeQueueAttrs(const RdmaEndpointConfig& config) { qpTotalSize = RoundUpPowOfTwo(rqAttrs.wqSize + sqAttrs.wqSize); qpTotalSize = (qpTotalSize + config.alignment - 1) / config.alignment * config.alignment; -#if DEBUG == 1 - std::cout << "rq[ " << rqAttrs << "] sq[ " << sqAttrs << "]" << std::endl; -#endif + MORI_APP_TRACE( + "MLX5 Queue attributes computed - RQ: wqeSize={}, wqSize={}, wqeNum={}, offset={} | SQ: " + "wqeSize={}, wqSize={}, wqeNum={}, offset={} | Total: {}", + rqAttrs.wqeSize, rqAttrs.wqSize, rqAttrs.wqeNum, rqAttrs.offset, sqAttrs.wqeSize, + sqAttrs.wqSize, sqAttrs.wqeNum, sqAttrs.offset, qpTotalSize); } void Mlx5QpContainer::CreateQueuePair(uint32_t cqn, uint32_t pdn) { @@ -216,6 +225,28 @@ void Mlx5QpContainer::CreateQueuePair(uint32_t cqn, uint32_t pdn) { qpDbrUmem = mlx5dv_devx_umem_reg(context, qpDbrUmemAddr, 8, IBV_ACCESS_LOCAL_WRITE); assert(qpDbrUmem); + // Allocate and register atomic internal buffer (ibuf) + atomicIbufSize = (RoundUpPowOfTwo(config.atomicIbufSlots) + 1) * ATOMIC_IBUF_SLOT_SIZE; + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipMalloc(&atomicIbufAddr, atomicIbufSize)); + HIP_RUNTIME_CHECK(hipMemset(atomicIbufAddr, 0, atomicIbufSize)); + } else { + status = posix_memalign(&atomicIbufAddr, config.alignment, atomicIbufSize); + memset(atomicIbufAddr, 0, atomicIbufSize); + assert(!status); + } + + // Register atomic ibuf as independent memory region + atomicIbufMr = ibv_reg_mr(device_context->GetIbvPd(), atomicIbufAddr, atomicIbufSize, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC); + assert(atomicIbufMr); + + MORI_APP_TRACE( + "MLX5 Atomic ibuf allocated: addr=0x{:x}, slots={}, size={}, lkey=0x{:x}, rkey=0x{:x}", + reinterpret_cast(atomicIbufAddr), RoundUpPowOfTwo(config.atomicIbufSlots), + atomicIbufSize, atomicIbufMr->lkey, atomicIbufMr->rkey); + // Allocate user access region qpUar = mlx5dv_devx_alloc_uar(context, MLX5DV_UAR_ALLOC_TYPE_NC); assert(qpUar); @@ -263,18 +294,52 @@ void Mlx5QpContainer::CreateQueuePair(uint32_t cqn, uint32_t pdn) { assert(qp); qpn = DEVX_GET(create_qp_out, cmd_out, qpn); + + MORI_APP_TRACE( + "MLX5 QP created: qpn={}, qpTotalSize={}, sqWqeNum={}, rqWqeNum={}, sqAddr=0x{:x}, " + "rqAddr=0x{:x}", + qpn, qpTotalSize, sqAttrs.wqeNum, rqAttrs.wqeNum, reinterpret_cast(GetSqAddress()), + reinterpret_cast(GetRqAddress())); } void Mlx5QpContainer::DestroyQueuePair() { - if (qpUmemAddr) HIP_RUNTIME_CHECK(hipFree(qpUmemAddr)); - if (qpDbrUmemAddr) HIP_RUNTIME_CHECK(hipFree(qpDbrUmemAddr)); + if (atomicIbufMr) { + 
ibv_dereg_mr(atomicIbufMr); + atomicIbufMr = nullptr; + } + if (atomicIbufAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(atomicIbufAddr)); + } else { + free(atomicIbufAddr); + } + atomicIbufAddr = nullptr; + } + + if (qpUmem) mlx5dv_devx_umem_dereg(qpUmem); + if (qpUmemAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(qpUmemAddr)); + } else { + free(qpUmemAddr); + } + } if (qpDbrUmem) mlx5dv_devx_umem_dereg(qpDbrUmem); + if (qpDbrUmemAddr) { + if (config.onGpu) { + HIP_RUNTIME_CHECK(hipFree(qpDbrUmemAddr)); + } else { + free(qpDbrUmemAddr); + } + } if (qpUar) { - hipPointerAttribute_t attr; - HIP_RUNTIME_CHECK(hipPointerGetAttributes(&attr, qpUar->reg_addr)); - // Multiple qp may share the same uar address, only unregister once - if ((attr.type == hipMemoryTypeHost) && (attr.hostPointer != nullptr)) { - HIP_RUNTIME_CHECK(hipHostUnregister(qpUar->reg_addr)); + if (config.onGpu) { + hipPointerAttribute_t attr; + HIP_RUNTIME_CHECK(hipPointerGetAttributes(&attr, qpUar->reg_addr)); + // Multiple qp may share the same uar address, only unregister once + if ((attr.type == hipMemoryTypeHost) && (attr.hostPointer != nullptr)) { + HIP_RUNTIME_CHECK(hipHostUnregister(qpUar->reg_addr)); + } } mlx5dv_devx_free_uar(qpUar); } @@ -312,7 +377,7 @@ void Mlx5QpContainer::ModifyRst2Init() { } void Mlx5QpContainer::ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, - const ibv_port_attr& portAttr) { + const ibv_port_attr& portAttr, uint32_t qpId) { uint8_t init2rtr_cmd_in[DEVX_ST_SZ_BYTES(init2rtr_qp_in)] = { 0, }; @@ -343,7 +408,11 @@ void Mlx5QpContainer::ModifyInit2Rtr(const RdmaEndpointHandle& remote_handle, sizeof(remote_handle.eth.mac)); DEVX_SET(qpc, qpc, primary_address_path.hop_limit, 64); DEVX_SET(qpc, qpc, primary_address_path.src_addr_index, config.gidIdx); - DEVX_SET(qpc, qpc, primary_address_path.udp_sport, 0xC000); + // Use shared UDP sport configuration with qpId-based selection + uint16_t selected_udp_sport = device_context->GetUdpSport(qpId); + DEVX_SET(qpc, qpc, primary_address_path.udp_sport, selected_udp_sport); + MORI_APP_TRACE("MLX5 QP {} using UDP sport {} (qpId={}, index={})", qpn, selected_udp_sport, + qpId, qpId % RDMA_UDP_SPORT_ARRAY_SIZE); } else if (portAttr.link_layer == IBV_LINK_LAYER_INFINIBAND) { DEVX_SET(qpc, qpc, primary_address_path.rlid, remote_handle.ib.lid); } else { @@ -400,7 +469,7 @@ RdmaEndpoint Mlx5DeviceContext::CreateRdmaEndpoint(const RdmaEndpointConfig& con ibv_context* context = GetIbvContext(); Mlx5CqContainer* cq = new Mlx5CqContainer(context, config); - Mlx5QpContainer* qp = new Mlx5QpContainer(context, config, cq->cqn, pdn); + Mlx5QpContainer* qp = new Mlx5QpContainer(context, config, cq->cqn, pdn, this); RdmaEndpoint endpoint; endpoint.handle.psn = 0; @@ -455,23 +524,42 @@ RdmaEndpoint Mlx5DeviceContext::CreateRdmaEndpoint(const RdmaEndpointConfig& con endpoint.cqHandle.cqeSize = GetMlx5CqeSize(); endpoint.cqHandle.dbrRecAddr = cq->cqDbrUmemAddr; + // Set atomic internal buffer information + endpoint.atomicIbuf.addr = reinterpret_cast(qp->atomicIbufAddr); + endpoint.atomicIbuf.lkey = qp->atomicIbufMr->lkey; + endpoint.atomicIbuf.rkey = qp->atomicIbufMr->rkey; + endpoint.atomicIbuf.nslots = RoundUpPowOfTwo(config.atomicIbufSlots); + cqPool.insert({cq->cqn, std::move(std::unique_ptr(cq))}); qpPool.insert({qp->qpn, std::move(std::unique_ptr(qp))}); + MORI_APP_TRACE( + "MLX5 endpoint created: qpn={}, cqn={}, portId={}, gidIdx={}, atomicIbuf addr=0x{:x}, " + "nslots={}", + qp->qpn, cq->cqn, config.portId, config.gidIdx, 
endpoint.atomicIbuf.addr, + endpoint.atomicIbuf.nslots); + return endpoint; } void Mlx5DeviceContext::ConnectEndpoint(const RdmaEndpointHandle& local, - const RdmaEndpointHandle& remote) { + const RdmaEndpointHandle& remote, uint32_t qpId) { uint32_t local_qpn = local.qpn; assert(qpPool.find(local_qpn) != qpPool.end()); Mlx5QpContainer* qp = qpPool.at(local_qpn).get(); + + MORI_APP_TRACE("MLX5 connecting endpoint: local_qpn={}, remote_qpn={}, qpId={}", local_qpn, + remote.qpn, qpId); + RdmaDevice* rdmaDevice = GetRdmaDevice(); const ibv_device_attr_ex* deviceAttr = rdmaDevice->GetDeviceAttr(); const ibv_port_attr& portAttr = *(rdmaDevice->GetPortAttrMap()->find(local.portId)->second); qp->ModifyRst2Init(); - qp->ModifyInit2Rtr(remote, portAttr); + qp->ModifyInit2Rtr(remote, portAttr, qpId); qp->ModifyRtr2Rts(local); + + MORI_APP_TRACE("MLX5 endpoint connected successfully: local_qpn={}, remote_qpn={}", local_qpn, + remote.qpn); } /* ---------------------------------------------------------------------------------------------- */ diff --git a/src/application/transport/rdma/rdma.cpp b/src/application/transport/rdma/rdma.cpp index 29b04b77..4d89b196 100644 --- a/src/application/transport/rdma/rdma.cpp +++ b/src/application/transport/rdma/rdma.cpp @@ -31,6 +31,7 @@ #include "mori/application/transport/rdma/providers/bnxt/bnxt.hpp" #include "mori/application/transport/rdma/providers/ibverbs/ibverbs.hpp" #include "mori/application/transport/rdma/providers/mlx5/mlx5.hpp" +#include "mori/utils/mori_log.hpp" namespace mori { namespace application { @@ -38,7 +39,9 @@ namespace application { /* ---------------------------------------------------------------------------------------------- */ /* RdmaDeviceContext */ /* ---------------------------------------------------------------------------------------------- */ -RdmaDeviceContext::RdmaDeviceContext(RdmaDevice* device, ibv_pd* inPd) : device(device), pd(inPd) {} +RdmaDeviceContext::RdmaDeviceContext(RdmaDevice* device, ibv_pd* inPd) : device(device), pd(inPd) { + InitializeUdpSportConfiguration(); +} RdmaDeviceContext::~RdmaDeviceContext() { ibv_dealloc_pd(pd); @@ -198,6 +201,7 @@ ActiveDevicePortList GetActiveDevicePortList(const RdmaDeviceList& devices) { /* ---------------------------------------------------------------------------------------------- */ RdmaContext::RdmaContext(RdmaBackendType backendType) : backendType(backendType) { deviceList = ibv_get_device_list(&nums_device); + MORI_APP_TRACE("ibv_get_device_list nums_device: {}", nums_device); Initialize(); } @@ -288,5 +292,67 @@ void RdmaContext::Initialize() { } } +/* ---------------------------------------------------------------------------------------------- */ +/* UDP Sport Configuration */ +/* ---------------------------------------------------------------------------------------------- */ +void RdmaDeviceContext::InitializeUdpSportConfiguration() { + // Default UDP sport configuration + static constexpr uint16_t DEFAULT_UDP_SPORTS[RDMA_UDP_SPORT_ARRAY_SIZE] = { + 0xBE10, 0xBE11, 0xBE12, 0xBE13 + }; + + // Initialize with defaults + for (uint32_t i = 0; i < RDMA_UDP_SPORT_ARRAY_SIZE; i++) { + udp_sport_setting[i] = DEFAULT_UDP_SPORTS[i]; + } + + // Check for environment variable configurations + const char* gor_port_batch = std::getenv("MORI_GOR_PORT"); + if (gor_port_batch != nullptr) { + // Parse comma-separated values + std::string batch_str(gor_port_batch); + std::stringstream ss(batch_str); + std::string item; + uint32_t index = 0; + + while (std::getline(ss, item, ',') && 
index < RDMA_UDP_SPORT_ARRAY_SIZE) { + try { + // Support both decimal and hexadecimal (0x prefix) formats + uint16_t port_val = static_cast(std::stoul(item, nullptr, 0)); + udp_sport_setting[index] = port_val; + index++; + } catch (const std::exception& e) { + MORI_APP_WARN("Invalid UDP sport value in MORI_GOR_PORT: {}, using default value", item); + } + } + } else { + // Check individual environment variables + const char* env_vars[RDMA_UDP_SPORT_ARRAY_SIZE] = { + "MORI_GOR_PORT1", "MORI_GOR_PORT2", "MORI_GOR_PORT3", "MORI_GOR_PORT4" + }; + + for (uint32_t i = 0; i < RDMA_UDP_SPORT_ARRAY_SIZE; i++) { + const char* env_val = std::getenv(env_vars[i]); + if (env_val != nullptr) { + try { + uint16_t port_val = static_cast(std::stoul(env_val, nullptr, 0)); + udp_sport_setting[i] = port_val; + } catch (const std::exception& e) { + MORI_APP_WARN("Invalid UDP sport value in {}: {}, using default value", env_vars[i], env_val); + } + } + } + } + + // Log final configuration + for (uint32_t i = 0; i < RDMA_UDP_SPORT_ARRAY_SIZE; i++) { + MORI_APP_INFO("UDP sport[{}] = 0x{:x}", i, udp_sport_setting[i]); + } +} + +uint16_t RdmaDeviceContext::GetUdpSport(uint32_t qpId) const { + return udp_sport_setting[qpId % RDMA_UDP_SPORT_ARRAY_SIZE]; +} + } // namespace application } // namespace mori diff --git a/src/ops/dispatch_combine/dispatch_combine.cpp b/src/ops/dispatch_combine/dispatch_combine.cpp index 7c2dbe00..e98db14b 100644 --- a/src/ops/dispatch_combine/dispatch_combine.cpp +++ b/src/ops/dispatch_combine/dispatch_combine.cpp @@ -105,6 +105,9 @@ void EpDispatchCombineHandle::InitializeTokenNumSignalBuf() { size_t tokenNumSignalSize = config.worldSize * sizeof(index_t) * 2; recvTokenNumMemObj = ShmemMallocAndReturnMemObjPtr(tokenNumSignalSize, hipDeviceMallocUncached); sendTokenNumMemObj = ShmemMallocAndReturnMemObjPtr(tokenNumSignalSize, hipDeviceMallocUncached); + // The extra *2 is for the laddr. 
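To make the QP-to-UDP-sport mapping above concrete, here is a small host-side sketch of the same round-robin selection. kUdpSportSlots, ParseUdpSports and SportForQp are illustrative stand-ins; the real logic lives in RdmaDeviceContext::InitializeUdpSportConfiguration()/GetUdpSport(), which also honors the per-index MORI_GOR_PORT1..4 variables and wraps parsing in try/catch, both omitted here. Giving each of the numQpPerPe QPs toward a peer a different UDP source port lets RoCE flows to the same destination hash onto different ECMP paths.

#include <array>
#include <cstdint>
#include <cstdlib>
#include <sstream>
#include <string>

constexpr uint32_t kUdpSportSlots = 4;  // assumed equal to RDMA_UDP_SPORT_ARRAY_SIZE

// Defaults mirror the table in rdma.cpp; MORI_GOR_PORT="0xBE20,0xBE21,..." overrides them.
std::array<uint16_t, kUdpSportSlots> ParseUdpSports() {
  std::array<uint16_t, kUdpSportSlots> sports{0xBE10, 0xBE11, 0xBE12, 0xBE13};
  if (const char* env = std::getenv("MORI_GOR_PORT")) {
    std::stringstream ss(env);
    std::string item;
    uint32_t idx = 0;
    while (std::getline(ss, item, ',') && idx < kUdpSportSlots) {
      // Accepts decimal or 0x-prefixed hex; error handling elided in this sketch.
      sports[idx++] = static_cast<uint16_t>(std::stoul(item, nullptr, 0));
    }
  }
  return sports;
}

// Round-robin by qpId, matching GetUdpSport(qpId).
uint16_t SportForQp(const std::array<uint16_t, kUdpSportSlots>& sports, uint32_t qpId) {
  return sports[qpId % kUdpSportSlots];
}

The allocation that follows sizes the per-peer atomic signal buffer consumed by the reworked dispatch handshake further down.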
+ sendAtomicSignalMemObj = ShmemMallocAndReturnMemObjPtr( + (config.worldSize * 2) * sizeof(int64_t) * 2, hipDeviceMallocUncached); HIP_RUNTIME_CHECK(hipMalloc(&totalRecvTokenNum, sizeof(index_t))); HIP_RUNTIME_CHECK(hipMemset(totalRecvTokenNum, 0, sizeof(index_t))); @@ -113,6 +116,7 @@ void EpDispatchCombineHandle::InitializeTokenNumSignalBuf() { void EpDispatchCombineHandle::FinalizeTokenNumSignalBuf() { ShmemFree(recvTokenNumMemObj->localPtr); ShmemFree(sendTokenNumMemObj->localPtr); + ShmemFree(sendAtomicSignalMemObj->localPtr); HIP_RUNTIME_CHECK(hipFree(totalRecvTokenNum)); } @@ -162,7 +166,8 @@ void EpDispatchCombineHandle::InitializeBarrier() { HIP_RUNTIME_CHECK(hipMemset(dispatchGridBarrier, 0, barrierSize)); HIP_RUNTIME_CHECK(hipMalloc(&combineGridBarrier, barrierSize)); HIP_RUNTIME_CHECK(hipMemset(combineGridBarrier, 0, barrierSize)); - crossDeviceBarrierMemObj = ShmemMallocAndReturnMemObjPtr(barrierSize, hipDeviceMallocUncached); + crossDeviceBarrierMemObj = ShmemMallocAndReturnMemObjPtr( + barrierSize * 2 * sizeof(uint64_t) / sizeof(uint32_t), hipDeviceMallocUncached); } void EpDispatchCombineHandle::FinalizeBarrier() { diff --git a/src/ops/dispatch_combine/internode.hpp b/src/ops/dispatch_combine/internode.hpp index bf8f3e1f..6c4fe8fd 100644 --- a/src/ops/dispatch_combine/internode.hpp +++ b/src/ops/dispatch_combine/internode.hpp @@ -55,7 +55,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { int laneId = threadIdx.x & (warpSize - 1); int warpId = thdId / warpSize; - int warpNum = blockDim.x / warpSize; + int warpNum = (blockDim.x + warpSize - 1) / warpSize; int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; int globalThdNum = gridDim.x * blockDim.x; @@ -139,6 +139,12 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { const int startIdx = localBlockId * baseChunk + min(localBlockId, remainder); const int endIdx = startIdx + myChunkSize; + if (localBlockId == 0 && warpId == warpNum - 1) { + shmem::ShmemPutInt32ImmNbiWarp( + args.recvTokenNumMemObj, + (myPe + (args.crossDeviceBarrierFlag & 1) * npes) * sizeof(index_t), totalTokens, destPe, + localBlockId); + } if (destNode == myNode) { // intra node use xgmi for transfer @@ -173,7 +179,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { } shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, peSortedOffset, args.shmemStagingTokMemObj, mapIdxOffset, stagingOffset, - destPe); + destPe, localBlockId); } } else { // inter node use ibgda for transfer @@ -204,7 +210,7 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { size_t dstOffset = dstIdx * stagingOffset; shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, - actualTokenNum * stagingOffset, destPe); + actualTokenNum * stagingOffset, destPe, localBlockId); ++chunkIdx; chunkOffset += chunkTokenSize; @@ -250,25 +256,20 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { __shared__ index_t recvTokenNum; __syncthreads(); - if (thdId == 0) { - // shmem::ShmemAtomicTypeNonFetchWarp(args.recvTokenNumMemObj, myPe * sizeof(index_t), - // args.shmemStagingTokMemObj->GetMemoryRegion(myPe), - // myPe * sizeof(index_t), (int32_t)(totalTokens+1), - // destPe, core::AMO_SET); - int doneBlockNum = atomicAdd(&args.dispatchGridBarrier[destPe], 1); - if (doneBlockNum == numsBlockPerDestPe - 1) { - shmem::ShmemPutInt32ImmNbiThread( - args.recvTokenNumMemObj, - (myPe + (args.crossDeviceBarrierFlag & 1) * npes) * sizeof(index_t), 
totalTokens + 1, - destPe); - __hip_atomic_store(&args.dispatchGridBarrier[destPe], 0, __ATOMIC_RELAXED, - __HIP_MEMORY_SCOPE_AGENT); - } + if (warpId == warpNum - 1) { + shmem::ShmemAtomicTypeNonFetchWarp( + args.sendAtomicSignalMemObj, + (myPe + (args.crossDeviceBarrierFlag & 1) * npes) * sizeof(int64_t), 1, core::AMO_ADD, + destPe, localBlockId); } if (thdId == 0) { - index_t* signal = args.recvTokenNumMemObj->template GetAs() + destPe + + int64_t* signal = args.sendAtomicSignalMemObj->template GetAs() + destPe + (args.crossDeviceBarrierFlag & 1) * npes; - recvTokenNum = shmem::ShmemInt32WaitUntilGreaterThan(signal, 0) - 1; + shmem::ShmemInt64WaitUntilGreaterThan(signal, numsBlockPerDestPe - 1); + recvTokenNum = atomicAdd( + &args.recvTokenNumMemObj + ->template GetAs()[destPe + (args.crossDeviceBarrierFlag & 1) * npes], + 0); if (localBlockId == 0) { atomicAdd(args.totalRecvTokenNum, recvTokenNum); args.destPeTokenCounter[destPe] = 0; @@ -330,7 +331,8 @@ __global__ void EpDispatchInterNodeKernel(EpDispatchCombineArgs args) { /* BarrierKernel */ /* ---------------------------------------------------------------------------------------------- */ template -inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs args) { +inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgs args, + int numQps) { int thdId = threadIdx.x; int laneId = threadIdx.x & (warpSize - 1); int globalThdId = blockIdx.x * blockDim.x + threadIdx.x; @@ -342,23 +344,24 @@ inline __device__ void CrossDeviceBarrierInterNodeKernel(EpDispatchCombineArgstemplate GetAs(); + uint64_t* localBarrierPtr = args.crossDeviceBarrierMemObj->template GetAs(); if (thdId < args.config.worldSize) { - while (core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId) != args.crossDeviceBarrierFlag) { + uint64_t currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId); +#if DEBUG == 1 + printf("Thread %d: localBarrierPtr[%d] = %lu, expected = %lu\n", thdId, thdId, currentVal, + (uint64_t)(args.crossDeviceBarrierFlag * numQps)); +#endif + + while (currentVal != args.crossDeviceBarrierFlag * numQps) { + currentVal = core::AtomicLoadRelaxedSystem(localBarrierPtr + thdId); } } __syncthreads(); } - /* ---------------------------------------------------------------------------------------------- */ /* EpCombineInterNodeKernel */ /* ---------------------------------------------------------------------------------------------- */ @@ -392,8 +395,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { const int srcNode = srcPe / MAX_GPUS_PER_NODE; const int localBlockId = blockIdx.x - srcPe * numsBlockPerSrcPe; const int srcPeTokenNum = *(args.recvTokenNumMemObj->template GetAs() + srcPe + - (args.crossDeviceBarrierFlag & 1) * npes) - - 1; + (args.crossDeviceBarrierFlag & 1) * npes); const int baseChunk = srcPeTokenNum / numsBlockPerSrcPe; const int remainder = srcPeTokenNum % numsBlockPerSrcPe; @@ -459,7 +461,7 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { size_t dstOffset = dstIdx * tokenPackSize; shmem::ShmemPutTypeNbiWarp(args.shmemInpTokMemObj, dstOffset, args.shmemStagingTokMemObj, srcOffset, - actualTokenNum * tokenPackSize, srcPe); + actualTokenNum * tokenPackSize, srcPe, localBlockId); ++chunkIdx; chunkOffset += chunkTokenSize; @@ -490,14 +492,22 @@ __global__ void EpCombineInterNodeKernel(EpDispatchCombineArgs args) { __threadfence_block(); } } + __syncthreads(); + if (warpId == warpNum - 1) { + 
shmem::ShmemAtomicTypeNonFetchWarp(args.crossDeviceBarrierMemObj, + args.config.rank * sizeof(uint64_t), 1, + core::AMO_ADD, srcPe, localBlockId); + } SyncIfDebugEnabled("Combine kernel: send token end"); // Make sure copy on all GPUs are finished - CrossDeviceBarrierInterNodeKernel(args); + CrossDeviceBarrierInterNodeKernel(args, numsBlockPerSrcPe); shmem::ShmemQuietThread(); if (globalThdId < npes) { args.recvTokenNumMemObj ->template GetAs()[globalThdId + (args.crossDeviceBarrierFlag & 1) * npes] = 0; + args.sendAtomicSignalMemObj + ->template GetAs()[globalThdId + (args.crossDeviceBarrierFlag & 1) * npes] = 0; } if (globalThdId == 0) { diff --git a/src/shmem/init.cpp b/src/shmem/init.cpp index 48f2c990..4ebf18b4 100644 --- a/src/shmem/init.cpp +++ b/src/shmem/init.cpp @@ -66,6 +66,7 @@ void GpuStateInit() { GpuStates gpuStates; gpuStates.rank = rank; gpuStates.worldSize = worldSize; + gpuStates.numQpPerPe = rdmaStates->commContext->GetNumQpPerPe(); // Copy transport types to GPU HIP_RUNTIME_CHECK( @@ -76,13 +77,14 @@ void GpuStateInit() { // Copy endpoints to GPU if (rdmaStates->commContext->RdmaTransportEnabled()) { + size_t numEndpoints = gpuStates.worldSize * gpuStates.numQpPerPe; HIP_RUNTIME_CHECK( - hipMalloc(&gpuStates.rdmaEndpoints, sizeof(application::RdmaEndpoint) * worldSize)); + hipMalloc(&gpuStates.rdmaEndpoints, sizeof(application::RdmaEndpoint) * numEndpoints)); HIP_RUNTIME_CHECK( hipMemcpy(gpuStates.rdmaEndpoints, rdmaStates->commContext->GetRdmaEndpoints().data(), - sizeof(application::RdmaEndpoint) * worldSize, hipMemcpyHostToDevice)); + sizeof(application::RdmaEndpoint) * numEndpoints, hipMemcpyHostToDevice)); - size_t lockSize = worldSize * sizeof(uint32_t); + size_t lockSize = numEndpoints * sizeof(uint32_t); HIP_RUNTIME_CHECK(hipMalloc(&gpuStates.endpointLock, lockSize)); HIP_RUNTIME_CHECK(hipMemset(gpuStates.endpointLock, 0, lockSize)); } diff --git a/src/shmem/internal.hpp b/src/shmem/internal.hpp index 31360cdd..80959b71 100644 --- a/src/shmem/internal.hpp +++ b/src/shmem/internal.hpp @@ -80,6 +80,7 @@ struct ShmemStates { struct GpuStates { int rank{-1}; int worldSize{-1}; + int numQpPerPe{4}; // Default to 4 QPs per peer, consistent with Context default application::TransportType* transportTypes{nullptr}; application::RdmaEndpoint* rdmaEndpoints{nullptr}; uint32_t* endpointLock{nullptr}; diff --git a/src/shmem/memory.cpp b/src/shmem/memory.cpp index fe1f3252..33c06fec 100644 --- a/src/shmem/memory.cpp +++ b/src/shmem/memory.cpp @@ -24,6 +24,7 @@ #include "mori/application/memory/symmetric_memory.hpp" #include "mori/shmem/shmem_api.hpp" #include "src/shmem/internal.hpp" +#include "mori/utils/mori_log.hpp" namespace mori { namespace shmem { @@ -32,6 +33,7 @@ void* ShmemMalloc(size_t size) { ShmemStates* states = ShmemStatesSingleton::GetInstance(); states->CheckStatusValid(); application::SymmMemObjPtr obj = states->memoryStates->symmMemMgr->Malloc(size); + MORI_SHMEM_TRACE("Allocated shared memory of size {}", size); if (obj.IsValid()) { return obj.cpu->localPtr; } @@ -43,6 +45,7 @@ void* ShmemExtMallocWithFlags(size_t size, unsigned int flags) { states->CheckStatusValid(); application::SymmMemObjPtr obj = states->memoryStates->symmMemMgr->ExtMallocWithFlags(size, flags); + MORI_SHMEM_TRACE("Allocated shared memory of size {} with flags {}", size, flags); if (obj.IsValid()) { return obj.cpu->localPtr; }
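A closing sketch of the multi-QP bookkeeping the changes above introduce. Hedged: GpuEndpointTable, EndpointIndex and EndpointFor are hypothetical helpers, and the modulo fold of an out-of-range qpId is an assumption about how the shmem layer maps a caller-supplied qpId (such as localBlockId) onto the numQpPerPe endpoints; only the peer-major layout itself is taken directly from Context::InitializePossibleTransports() and GpuStateInit().

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>

struct RdmaEndpointStub {};  // stand-in for application::RdmaEndpoint

// Flat [worldSize x numQpPerPe] tables, matching how rdmaEndpoints and
// endpointLock are sized and copied to the GPU in GpuStateInit().
struct GpuEndpointTable {
  RdmaEndpointStub* endpoints;
  uint32_t* locks;  // one serialization lock per endpoint
  int worldSize;
  int numQpPerPe;
};

// Peer-major layout: all QPs of peer 0, then all QPs of peer 1, ...
// This is the same epIndex = peer * numQpPerPe + qp used when the endpoint
// handles are exchanged over AllToAll and connected.
__host__ __device__ inline size_t EndpointIndex(const GpuEndpointTable& t, int pe, int qpId) {
  return static_cast<size_t>(pe) * t.numQpPerPe + (qpId % t.numQpPerPe);
}

__host__ __device__ inline RdmaEndpointStub* EndpointFor(const GpuEndpointTable& t, int pe,
                                                         int qpId) {
  return &t.endpoints[EndpointIndex(t, pe, qpId)];
}

With this layout, a kernel that receives (pe, qpId), for example the qpId-aware ShmemQuietThreadKernel(pe, qpId) overload added above, can locate its endpoint and lock with plain index arithmetic, without a per-peer indirection table.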