Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 165 additions & 78 deletions examples/ops/dispatch_combine/test_dispatch_combine_internode.py

Large diffs are not rendered by default.

68 changes: 31 additions & 37 deletions include/mori/core/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,30 +48,28 @@ inline __device__ int FlatBlockWarpId() { return FlatBlockThreadId() / DeviceWar

// Lane index of the calling thread inside its warp, taken from the low bits of the flattened
// block thread id. The bit-mask form equals a modulo only if DeviceWarpSize() is a power of two.
inline __device__ int WarpLaneId() {
  const auto laneBits = DeviceWarpSize() - 1;
  return FlatBlockThreadId() & laneBits;
}

// Returns true when the flattened block thread id is a multiple of DeviceWarpSize().
// NOTE(review): despite the name, this is true for ids 0, W, 2W, ... (W = warp size), not only
// for thread 0 of the block — confirm which of the two behaviours callers expect.
inline __device__ bool IsThreadZeroInBlock() {
  const auto idModWarp = FlatBlockThreadId() % DeviceWarpSize();
  return idModWarp == 0;
}

// Bit mask of the lanes executing this call: every participating lane votes `true` in a ballot.
inline __device__ uint64_t GetActiveLaneMask() {
  const uint64_t votes = __ballot(true);
  return votes;
}

// Number of bits set in activeLaneMask, i.e. how many lanes the mask marks as active.
inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) {
  const unsigned int setBits = __popcll(activeLaneMask);
  return setBits;
}

// Convenience overload: counts the lanes in the mask returned by GetActiveLaneMask().
inline __device__ unsigned int GetActiveLaneCount() {
  const uint64_t currentMask = GetActiveLaneMask();
  return GetActiveLaneCount(currentMask);
}

// Rank of the calling lane among the lanes set in activeLaneMask: the number of set bits that
// fall inside __lanemask_lt(), i.e. that belong to lanes with a smaller lane id.
inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) {
  const uint64_t lowerActiveLanes = activeLaneMask & __lanemask_lt();
  return __popcll(lowerActiveLanes);
}

inline __device__ unsigned int GetActiveLaneNum() {
return GetActiveLaneNum(GetActiveLaneMask());
// Lane index computed from threadIdx.x alone (the y/z thread indices are not consulted).
// The bit-mask form equals a modulo only if warpSize is a power of two.
inline __device__ int WarpLaneId1D() {
  const auto laneBits = warpSize - 1;
  return threadIdx.x & laneBits;
}

// Returns true when the flattened block thread id is a multiple of DeviceWarpSize().
// NOTE(review): despite the name, this is true for ids 0, W, 2W, ... (W = warp size), not only
// for thread 0 of the block — confirm which of the two behaviours callers expect.
inline __device__ bool IsThreadZeroInBlock() {
  const auto idModWarp = FlatBlockThreadId() % DeviceWarpSize();
  return idModWarp == 0;
}

// Bit mask of the lanes executing this call: every participating lane votes `true` in a ballot.
inline __device__ uint64_t GetActiveLaneMask() {
  const uint64_t votes = __ballot(true);
  return votes;
}

// Number of bits set in activeLaneMask, i.e. how many lanes the mask marks as active.
inline __device__ unsigned int GetActiveLaneCount(uint64_t activeLaneMask) {
  const unsigned int setBits = __popcll(activeLaneMask);
  return setBits;
}

// Convenience overload: counts the lanes in the mask returned by GetActiveLaneMask().
inline __device__ unsigned int GetActiveLaneCount() {
  const uint64_t currentMask = GetActiveLaneMask();
  return GetActiveLaneCount(currentMask);
}

// Rank of the calling lane among the lanes set in activeLaneMask: the number of set bits that
// fall inside __lanemask_lt(), i.e. that belong to lanes with a smaller lane id.
inline __device__ unsigned int GetActiveLaneNum(uint64_t activeLaneMask) {
  const uint64_t lowerActiveLanes = activeLaneMask & __lanemask_lt();
  return __popcll(lowerActiveLanes);
}

// Convenience overload: rank of the calling lane within the mask from GetActiveLaneMask().
inline __device__ unsigned int GetActiveLaneNum() {
  const uint64_t currentMask = GetActiveLaneMask();
  return GetActiveLaneNum(currentMask);
}

// Zero-based position of the lowest set bit in activeLaneMask, or -1 when the mask is empty.
inline __device__ int GetFirstActiveLaneID(uint64_t activeLaneMask) {
  if (!activeLaneMask) {
    return -1;
  }
  // __ffsll reports a one-based bit position, hence the -1.
  return __ffsll((unsigned long long int)activeLaneMask) - 1;
}
Expand All @@ -84,21 +82,17 @@ inline __device__ int GetLastActiveLaneID(uint64_t activeLaneMask) {

// Convenience overload: applies GetLastActiveLaneID to the mask from GetActiveLaneMask().
inline __device__ int GetLastActiveLaneID() {
  const uint64_t currentMask = GetActiveLaneMask();
  return GetLastActiveLaneID(currentMask);
}

// True when no lane set in activeLaneMask ranks before the calling lane (its rank is 0).
inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) {
  const unsigned int rank = GetActiveLaneNum(activeLaneMask);
  return rank == 0;
}

// Convenience overload: evaluates IsFirstActiveLane on the mask from GetActiveLaneMask().
inline __device__ bool IsFirstActiveLane() {
  const uint64_t currentMask = GetActiveLaneMask();
  return IsFirstActiveLane(currentMask);
}

// True when the calling lane's rank within activeLaneMask equals (active lane count - 1).
// Both quantities are unsigned, so an empty mask yields a wrapped-around last rank.
inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) {
  const unsigned int rank = GetActiveLaneNum(activeLaneMask);
  const unsigned int lastRank = GetActiveLaneCount(activeLaneMask) - 1;
  return rank == lastRank;
}

// Convenience overload: evaluates IsLastActiveLane on the mask from GetActiveLaneMask().
inline __device__ bool IsLastActiveLane() {
  const uint64_t currentMask = GetActiveLaneMask();
  return IsLastActiveLane(currentMask);
}
// True when no lane set in activeLaneMask ranks before the calling lane (its rank is 0).
inline __device__ bool IsFirstActiveLane(uint64_t activeLaneMask) {
  const unsigned int rank = GetActiveLaneNum(activeLaneMask);
  return rank == 0;
}

// Convenience overload: evaluates IsFirstActiveLane on the mask from GetActiveLaneMask().
inline __device__ bool IsFirstActiveLane() {
  const uint64_t currentMask = GetActiveLaneMask();
  return IsFirstActiveLane(currentMask);
}

// True when the calling lane's rank within activeLaneMask equals (active lane count - 1).
// Both quantities are unsigned, so an empty mask yields a wrapped-around last rank.
inline __device__ bool IsLastActiveLane(uint64_t activeLaneMask) {
  const unsigned int rank = GetActiveLaneNum(activeLaneMask);
  const unsigned int lastRank = GetActiveLaneCount(activeLaneMask) - 1;
  return rank == lastRank;
}

// Convenience overload: evaluates IsLastActiveLane on the mask from GetActiveLaneMask().
inline __device__ bool IsLastActiveLane() {
  const uint64_t currentMask = GetActiveLaneMask();
  return IsLastActiveLane(currentMask);
}

/* ---------------------------------------------------------------------------------------------- */
/* Atomic Operations */
Expand Down
14 changes: 14 additions & 0 deletions include/mori/ops/dispatch_combine/dispatch_combine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace moe {
// Selects which dispatch/combine kernel implementation is launched.
// The numeric values are spelled out explicitly; presumably they must stay in sync with the
// Python-side EpDispatchCombineKernelType enum — TODO confirm against the bindings.
enum KernelType {
// Intra-node kernels.
IntraNode = 0,
// Inter-node kernels.
InterNode = 1,
// Inter-node kernels, dedup variant.
InterNodeDedup = 2,
};

inline const char* HipDataTypeToString(hipDataType dtype) {
Expand Down Expand Up @@ -84,6 +85,8 @@ struct EpDispatchCombineConfig {
// If true, use external buffer which incurs extra copy overhead; otherwise, the kernel assumes
// the provided buffer is shmemInpTokMemObj
bool useExternalInpBuffer{true};
int gpuPerNode{8};
int rdmaBlockNum{1};

inline __host__ __device__ int MaxNumTokensToSendPerRank() const { return maxNumInpTokenPerRank; }

Expand Down Expand Up @@ -215,6 +218,11 @@ class EpDispatchCombineHandle {
index_t* totalRecvTokenNum{nullptr};
mori::application::SymmMemObjPtr crossDeviceBarrierMemObj;
uint32_t crossDeviceBarrierFlag{1};

// Inter-node dedup kernel
mori::application::SymmMemObjPtr recvTokenFlagMemObj;
index_t* destNodeTokenCounter{nullptr};
mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj;
};

template <typename T>
Expand Down Expand Up @@ -252,6 +260,9 @@ struct EpDispatchCombineArgs {
index_t* totalRecvTokenNum{nullptr};
mori::application::SymmMemObjPtr crossDeviceBarrierMemObj;
uint32_t crossDeviceBarrierFlag{1};
mori::application::SymmMemObjPtr recvTokenFlagMemObj;
index_t* destNodeTokenCounter{nullptr};
mori::application::SymmMemObjPtr nodeRecvTokenNumMemObj;
};

using EpDispatchCombineArgsVariant =
Expand Down Expand Up @@ -293,6 +304,9 @@ EpDispatchCombineArgs<T> GetEpDispatchCombineArgs(const EpDispatchCombineHandle&
args.totalRecvTokenNum = handle.totalRecvTokenNum;
args.crossDeviceBarrierMemObj = handle.crossDeviceBarrierMemObj;
args.crossDeviceBarrierFlag = handle.crossDeviceBarrierFlag;
args.recvTokenFlagMemObj = handle.recvTokenFlagMemObj;
args.destNodeTokenCounter = handle.destNodeTokenCounter;
args.nodeRecvTokenNumMemObj = handle.nodeRecvTokenNumMemObj;
return args;
}

Expand Down
9 changes: 8 additions & 1 deletion python/mori/ops/dispatch_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class EpDispatchCombineConfig:
block_num: int = 80
use_external_inp_buf: bool = True
kernel_type: EpDispatchCombineKernelType = EpDispatchCombineKernelType.IntraNode
gpu_per_node: int = 8
rdma_block_num: int = 0


def _cpp_dispatch_combine_factory(entity_name):
Expand All @@ -72,6 +74,8 @@ def __init__(self, config):
warp_num_per_block=config.warp_num_per_block,
block_num=config.block_num,
use_external_inp_buf=config.use_external_inp_buf,
gpu_per_node=config.gpu_per_node,
rdma_block_num=config.rdma_block_num,
)
)

Expand Down Expand Up @@ -176,7 +180,10 @@ def _allgather_with_token_num_padding(self, input, max_token_num):
def get_dispatch_src_token_pos(self):
torch.cuda.synchronize()

if self.config.kernel_type.value == EpDispatchCombineKernelType.IntraNode.value:
if self.config.kernel_type.value in (
EpDispatchCombineKernelType.IntraNode.value,
EpDispatchCombineKernelType.InterNodeDedup.value,
):
return self._get_dispatch_src_token_pos_func(self._handle)

dispatch_sender_token_id_map = self._get_dispatch_sender_token_idx_map_func(
Expand Down
4 changes: 2 additions & 2 deletions src/application/context/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ void Context::InitializePossibleTransports() {
application::RdmaEndpointConfig config;
config.portId = portId;
config.gidIdx = 3;
config.maxMsgsNum = 4096;
config.maxCqeNum = 4096;
config.maxMsgsNum = 8192;
config.maxCqeNum = 8192;
config.alignment = 4096;
config.onGpu = true;
RdmaEndpoint ep = rdmaDeviceContext->CreateRdmaEndpoint(config);
Expand Down
23 changes: 23 additions & 0 deletions src/ops/dispatch_combine/dispatch_combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mori/core/core.hpp"
#include "mori/shmem/shmem.hpp"
#include "src/ops/dispatch_combine/internode.hpp"
#include "src/ops/dispatch_combine/internode_v1.hpp"
#include "src/ops/dispatch_combine/intranode.hpp"

namespace mori {
Expand All @@ -42,6 +43,7 @@ using namespace mori::shmem;
/* EpDispatchCombineHandle */
/* ---------------------------------------------------------------------------------------------- */
EpDispatchCombineHandle::EpDispatchCombineHandle(EpDispatchCombineConfig config) : config(config) {
assert(IsPowerOf2(config.gpuPerNode) && (config.worldSize % config.gpuPerNode == 0));
InitializeShmemBuf();
InitializeTokenNumSignalBuf();
InitializeOrderMapBuf();
Expand Down Expand Up @@ -108,11 +110,16 @@ void EpDispatchCombineHandle::InitializeTokenNumSignalBuf() {

HIP_RUNTIME_CHECK(hipMalloc(&totalRecvTokenNum, sizeof(index_t)));
HIP_RUNTIME_CHECK(hipMemset(totalRecvTokenNum, 0, sizeof(index_t)));

size_t nodeTokenNumSignalSize = config.worldSize / config.gpuPerNode * sizeof(index_t);
nodeRecvTokenNumMemObj =
ShmemMallocAndReturnMemObjPtr(nodeTokenNumSignalSize, hipDeviceMallocUncached);
}

// Releases the token-count signal buffers owned by the handle.
// The three symmetric-memory objects are released through ShmemFree using their local pointers;
// totalRecvTokenNum is a plain device allocation and is released with hipFree, whose status is
// checked by HIP_RUNTIME_CHECK.
void EpDispatchCombineHandle::FinalizeTokenNumSignalBuf() {
ShmemFree(recvTokenNumMemObj->localPtr);
ShmemFree(sendTokenNumMemObj->localPtr);
// Per-node receive counter used by the inter-node dedup kernel.
ShmemFree(nodeRecvTokenNumMemObj->localPtr);
HIP_RUNTIME_CHECK(hipFree(totalRecvTokenNum));
}

Expand All @@ -133,6 +140,11 @@ void EpDispatchCombineHandle::InitializeOrderMapBuf() {
HIP_RUNTIME_CHECK(hipMalloc(&destPeTokenCounter, config.worldSize * sizeof(index_t)));
HIP_RUNTIME_CHECK(hipMemset(destPeTokenCounter, 0, config.worldSize * sizeof(index_t)));

HIP_RUNTIME_CHECK(
hipMalloc(&destNodeTokenCounter, config.worldSize / config.gpuPerNode * sizeof(index_t)));
HIP_RUNTIME_CHECK(
hipMemset(destNodeTokenCounter, 0, config.worldSize / config.gpuPerNode * sizeof(index_t)));

HIP_RUNTIME_CHECK(hipMalloc(&localPeTokenCounter, config.numExpertPerRank * sizeof(index_t)));
HIP_RUNTIME_CHECK(hipMemset(localPeTokenCounter, 0, config.numExpertPerRank * sizeof(index_t)));

Expand All @@ -150,6 +162,7 @@ void EpDispatchCombineHandle::FinalizeOrderMapBuf() {
HIP_RUNTIME_CHECK(hipFree(destPeTokenIdxMap));
HIP_RUNTIME_CHECK(hipFree(srcPeTokenIdxMap));
HIP_RUNTIME_CHECK(hipFree(destPeTokenCounter));
HIP_RUNTIME_CHECK(hipFree(destNodeTokenCounter));
HIP_RUNTIME_CHECK(hipFree(localPeTokenCounter));
ShmemFree(dispTokOffsetMemObj->localPtr);
ShmemFree(dispTokIdToSrcTokIdMemObj->localPtr);
Expand All @@ -163,12 +176,17 @@ void EpDispatchCombineHandle::InitializeBarrier() {
HIP_RUNTIME_CHECK(hipMalloc(&combineGridBarrier, barrierSize));
HIP_RUNTIME_CHECK(hipMemset(combineGridBarrier, 0, barrierSize));
crossDeviceBarrierMemObj = ShmemMallocAndReturnMemObjPtr(barrierSize, hipDeviceMallocUncached);

size_t recvTokenFlagSize =
config.worldSize / config.gpuPerNode * config.MaxNumTokensToRecvPerRank() * sizeof(index_t);
recvTokenFlagMemObj = ShmemMallocAndReturnMemObjPtr(recvTokenFlagSize, hipDeviceMallocUncached);
}

// Releases the barrier-related buffers owned by the handle.
// The two grid barriers are plain device allocations released with hipFree (status checked by
// HIP_RUNTIME_CHECK); the cross-device barrier and the receive-token flag buffer are
// symmetric-memory objects released through ShmemFree using their local pointers.
void EpDispatchCombineHandle::FinalizeBarrier() {
HIP_RUNTIME_CHECK(hipFree(dispatchGridBarrier));
HIP_RUNTIME_CHECK(hipFree(combineGridBarrier));
ShmemFree(crossDeviceBarrierMemObj->localPtr);
// Flag buffer used by the inter-node dedup kernel.
ShmemFree(recvTokenFlagMemObj->localPtr);
}

void EpDispatchCombineHandle::LaunchIntraNodeDispatch(int blockNum, int warpPerBlock,
Expand Down Expand Up @@ -210,6 +228,8 @@ void EpDispatchCombineHandle::LaunchDispatch(KernelType kernelType, int blockNum
if (kernelType == KernelType::InterNode) {
assert(config.useExternalInpBuffer);
EpDispatchInterNodeKernel<<<grid, block, sharedMemSize, stream>>>(args);
} else if (kernelType == KernelType::InterNodeDedup) {
EpDispatchInterNodeKernelV1<<<grid, block, sharedMemSize, stream>>>(args);
} else if (kernelType == KernelType::IntraNode) {
EpDispatchIntraNodeKernel<DataT><<<grid, block, sharedMemSize, stream>>>(args);
} else {
Expand All @@ -236,6 +256,9 @@ void EpDispatchCombineHandle::LaunchCombine(KernelType kernelType, int blockNum,
if (kernelType == KernelType::InterNode) {
assert(config.useExternalInpBuffer);
EpCombineInterNodeKernel<<<grid, block, sharedMemSize, stream>>>(args);
} else if (kernelType == KernelType::InterNodeDedup) {
assert(config.useExternalInpBuffer);
EpCombineInterNodeDedupKernel<<<grid, block, sharedMemSize, stream>>>(args);
} else if (kernelType == KernelType::IntraNode) {
EpCombineIntraNodeKernel<DataT><<<grid, block, sharedMemSize, stream>>>(args);
} else {
Expand Down
Loading