1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -2,6 +2,7 @@ cmake_minimum_required(VERSION 3.19)

set(CMAKE_CXX_FLAGS_DEBUG "-g -O0")
set(CMAKE_CXX_FLAGS_RELEASE "-O3")
set(CMAKE_CXX_FLAGS_O3DEBUG "-ggdb -O3")

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
304 changes: 166 additions & 138 deletions examples/dist_rdma_ops/dist_write.cpp

Large diffs are not rendered by default.

@@ -254,8 +254,8 @@ def run_test_once(self, op, test_data, error_round, round):
)
# TODO: test output scales

if self.config.rank == 0:
print("Dispatch Pass")
if self.rank % self.gpu_per_node == 0:
print(f"Node {self.rank // self.gpu_per_node} Dispatch Pass")

dist.barrier()

@@ -311,8 +311,8 @@ def run_test_once(self, op, test_data, error_round, round):
)
assert weight_match, f"Weight assertion failed for token {i}"

if self.config.rank == 0:
print("Combine Pass")
if self.rank % self.gpu_per_node == 0:
print(f"Node {self.rank // self.gpu_per_node} Combine Pass")

def test_dispatch_combine(self):
op = mori.ops.EpDispatchCombineOp(self.config)
89 changes: 66 additions & 23 deletions examples/shmem/atomic_fetch_thread.cpp
@@ -53,16 +53,16 @@ __global__ void AtomicFetchThreadKernel(int myPe, const SymmMemObjPtr memObj) {
int threadOffset = globalTid * sizeof(T);

if (myPe == sendPe) {
RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(sendPe);

ShmemAtomicTypeFetchThread<T>(memObj, threadOffset, source, threadOffset, sendPe, recvPe,
recvPe, AMO_FETCH_AND);
T ret = ShmemAtomicTypeFetchThread<T>(memObj, 2 * sizeof(T), 1, 0, AMO_FETCH_ADD, recvPe);
__threadfence_system();
if (ret == gridDim.x * blockDim.x) {
printf("globalTid: %d ret = %lu atomic fetch is ok!~\n", globalTid, (uint64_t)ret);
}

ShmemQuietThread();
// __syncthreads();
} else {
while (atomicAdd(reinterpret_cast<T*>(memObj->localPtr) + globalTid, 0) != sendPe) {
while (AtomicLoadRelaxed(reinterpret_cast<T*>(memObj->localPtr) + 2) !=
gridDim.x * blockDim.x + 1) {
}
if (globalTid == 0) {
printf("atomic fetch is ok!~\n");
@@ -97,24 +97,67 @@ void testAtomicFetchThread() {
SymmMemObjPtr buffObj = ShmemQueryMemObjPtr(buff);
assert(buffObj.IsValid());

// Run uint64 atomic nonfetch
AtomicFetchThreadKernel<uint64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff) + numEle - 1));
for (int iteration = 0; iteration < 10; iteration++) {
if (myPe == 0) {
printf("========== Iteration %d ==========\n", iteration + 1);
}

buffSize = numEle * sizeof(uint32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + numEle - 1));
// Run uint32 atomic nonfetch
AtomicFetchThreadKernel<uint32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + numEle - 1));
// Run uint64 atomic nonfetch
myHipMemsetD64(buff, myPe, numEle);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff)));
AtomicFetchThreadKernel<uint64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff) + 2));

// Test int64_t atomic nonfetch
buffSize = numEle * sizeof(int64_t);
myHipMemsetD64(buff, myPe, numEle);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
*(reinterpret_cast<int64_t*>(buff)));
// Run int64 atomic nonfetch
AtomicFetchThreadKernel<int64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
*(reinterpret_cast<int64_t*>(buff) + 2));

// Test uint32_t atomic nonfetch
buffSize = numEle * sizeof(uint32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff)));
// Run uint32 atomic nonfetch
AtomicFetchThreadKernel<uint32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + 2));

// Test int32_t atomic nonfetch
buffSize = numEle * sizeof(int32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<int32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
*(reinterpret_cast<int32_t*>(buff)));
// Run int32 atomic nonfetch
AtomicFetchThreadKernel<int32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
*(reinterpret_cast<int32_t*>(buff) + 2));

MPI_Barrier(MPI_COMM_WORLD); // Ensure all processes complete this iteration before next
if (myPe == 0) {
printf("Iteration %d completed\n", iteration + 1);
}
sleep(1);
}

// Finalize
ShmemFree(buff);
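
The reworked kernel above replaces the per-thread AMO_FETCH_AND on per-thread offsets with a single shared counter: every sender thread fetch-adds 1 into element 2 of the remote buffer, and the receiver spins with AtomicLoadRelaxed until the counter reaches gridDim.x * blockDim.x + 1 (the buffer is pre-filled with the receiver's PE id, 1). The host loop also initializes the 64-bit runs through myHipMemsetD64, which is presumably a small fill helper, since HIP itself only ships 8/16/32-bit memsets. A minimal sketch of the new pattern follows; the ShmemAtomicTypeFetchThread, ShmemQuietThread, and AtomicLoadRelaxed signatures are inferred from the diff, not taken from the library headers:

    // Sketch only: API signatures inferred from the diff above.
    template <typename T>
    __global__ void FetchAddCounterKernel(int myPe, const SymmMemObjPtr memObj) {
      constexpr int sendPe = 0, recvPe = 1;
      constexpr size_t counterOffset = 2 * sizeof(T);   // counter lives at element index 2
      const T nThreads = static_cast<T>(gridDim.x * blockDim.x);

      if (myPe == sendPe) {
        // Remote fetch-add of 1 into recvPe's counter; the previous value comes back.
        T old = ShmemAtomicTypeFetchThread<T>(memObj, counterOffset, T(1), T(0),
                                              AMO_FETCH_ADD, recvPe);
        __threadfence_system();
        if (old == nThreads) {                          // only the last add can observe this
          printf("last fetch-add saw old value %llu\n", static_cast<unsigned long long>(old));
        }
        ShmemQuietThread();                             // wait for the AMO to complete locally
      } else {
        // Counter starts at recvPe's init value (1), so it ends at nThreads + 1.
        while (AtomicLoadRelaxed(reinterpret_cast<T*>(memObj->localPtr) + 2) != nThreads + 1) {
        }
      }
    }
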
94 changes: 71 additions & 23 deletions examples/shmem/atomic_nonfetch_thread.cpp
@@ -50,20 +50,25 @@ __global__ void AtomicNonFetchThreadKernel(int myPe, const SymmMemObjPtr memObj)
constexpr int recvPe = 1;

int globalTid = blockIdx.x * blockDim.x + threadIdx.x;
int globalWarpId = globalTid / warpSize;
int threadOffset = globalTid * sizeof(T);

if (myPe == sendPe) {
RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(sendPe);

// ShmemAtomicUint64NonFetchThread(memObj, threadOffset, sendPe, recvPe, AMO_SET);
ShmemAtomicTypeNonFetchThread<T>(memObj, threadOffset, source, threadOffset, sendPe, recvPe,
AMO_SET);
// ShmemAtomicUint64NonFetchThread(memObj, threadOffset, sendPe, AMO_SET, recvPe);
if (globalWarpId % 2 == 0) {
ShmemAtomicTypeNonFetchThread<T>(memObj, 2 * sizeof(T), 1, AMO_ADD, recvPe);
} else {
ShmemAtomicTypeNonFetchThread<T>(memObj, 2 * sizeof(T), 1, AMO_ADD, recvPe, 1);
}
__threadfence_system();

ShmemQuietThread();
// __syncthreads();
} else {
while (atomicAdd(reinterpret_cast<T*>(memObj->localPtr) + globalTid, 0) != sendPe) {
while (AtomicLoadRelaxed(reinterpret_cast<T*>(memObj->localPtr) + 2) !=
gridDim.x * blockDim.x + 1) {
}
if (globalTid == 0) {
printf("atomic nonfetch is ok!~\n");
@@ -90,32 +95,75 @@ void testAtomicNonFetchThread() {
int numEle = threadNum * blockNum;
int buffSize = numEle * sizeof(uint64_t);

void* buff = ShmemMalloc(buffSize);
void* buff = ShmemExtMallocWithFlags(buffSize, hipDeviceMallocUncached);
myHipMemsetD64(buff, myPe, numEle);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff) + numEle - 1));
*(reinterpret_cast<uint64_t*>(buff)));
SymmMemObjPtr buffObj = ShmemQueryMemObjPtr(buff);
assert(buffObj.IsValid());

// Run uint64 atomic nonfetch
AtomicNonFetchThreadKernel<uint64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff) + numEle - 1));
// Run atomic operations 10 times
for (int iteration = 0; iteration < 10; iteration++) {
if (myPe == 0) {
printf("========== Iteration %d ==========\n", iteration + 1);
}

buffSize = numEle * sizeof(uint32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + numEle - 1));
// Run uint32 atomic nonfetch
AtomicNonFetchThreadKernel<uint32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + numEle - 1));
// Run uint64 atomic nonfetch
myHipMemsetD64(buff, myPe, numEle);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff)));
AtomicNonFetchThreadKernel<uint64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] uint64: %lu %lu\n", myPe, *(reinterpret_cast<uint64_t*>(buff)),
*(reinterpret_cast<uint64_t*>(buff) + 2));

// Test int64_t atomic nonfetch
buffSize = numEle * sizeof(int64_t);
myHipMemsetD64(buff, myPe, numEle);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
*(reinterpret_cast<int64_t*>(buff)));
// Run int64 atomic nonfetch
AtomicNonFetchThreadKernel<int64_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] int64: %ld %ld\n", myPe, *(reinterpret_cast<int64_t*>(buff)),
*(reinterpret_cast<int64_t*>(buff) + 2));

// Test uint32_t atomic nonfetch
buffSize = numEle * sizeof(uint32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<uint32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff)));
// Run uint32 atomic nonfetch
AtomicNonFetchThreadKernel<uint32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] uint32: %u %u\n", myPe, *(reinterpret_cast<uint32_t*>(buff)),
*(reinterpret_cast<uint32_t*>(buff) + 2));

// Test int32_t atomic nonfetch
buffSize = numEle * sizeof(int32_t);
HIP_RUNTIME_CHECK(hipMemsetD32(reinterpret_cast<int32_t*>(buff), myPe, numEle));
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
printf("before rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
*(reinterpret_cast<int32_t*>(buff)));
// Run int32 atomic nonfetch
AtomicNonFetchThreadKernel<int32_t><<<blockNum, threadNum>>>(myPe, buffObj);
HIP_RUNTIME_CHECK(hipDeviceSynchronize());
MPI_Barrier(MPI_COMM_WORLD);
printf("after rank[%d] int32: %d %d\n", myPe, *(reinterpret_cast<int32_t*>(buff)),
*(reinterpret_cast<int32_t*>(buff) + 2));
MPI_Barrier(MPI_COMM_WORLD); // Ensure all processes complete this iteration before next
if (myPe == 0) {
printf("Iteration %d completed\n", iteration + 1);
}
sleep(1);
}

// Finalize
ShmemFree(buff);
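
Besides the added dtype coverage, two things change in this example: the symmetric buffer is now allocated with ShmemExtMallocWithFlags(buffSize, hipDeviceMallocUncached), presumably so the word the receiver polls is not served from a stale cache line while remote atomics update it, and the sender splits its traffic across queue pairs, with even warps using the default QP and odd warps passing an explicit trailing index of 1 to ShmemAtomicTypeNonFetchThread. Reading that trailing parameter as a QP index (an inference from the diff, not confirmed by headers), the even/odd split generalizes to any QP count; a sketch under that assumption:

    // Spread non-fetching remote adds across numQp queue pairs, one QP per warp.
    template <typename T>
    __device__ void AddOneAcrossQps(const SymmMemObjPtr memObj, int recvPe, int numQp) {
      int globalTid = blockIdx.x * blockDim.x + threadIdx.x;
      int qpId = (globalTid / warpSize) % numQp;        // round-robin warps over QPs
      // Same destination word as the test above: element index 2 of the remote buffer.
      ShmemAtomicTypeNonFetchThread<T>(memObj, 2 * sizeof(T), T(1), AMO_ADD, recvPe, qpId);
    }
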
2 changes: 1 addition & 1 deletion examples/shmem/concurrent_put_thread.cpp
@@ -40,7 +40,7 @@ __global__ void ConcurrentPutThreadKernel(int myPe, const SymmMemObjPtr memObj)
if (myPe == sendPe) {
RdmaMemoryRegion source = memObj->GetRdmaMemoryRegion(myPe);

ShmemPutMemNbiThread(memObj, threadOffset, source, threadOffset, sizeof(uint32_t), recvPe);
ShmemPutMemNbiThread(memObj, threadOffset, source, threadOffset, sizeof(uint32_t), recvPe, 1);
__threadfence_system();

if (blockIdx.x == 0)
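
The only change here is the new trailing argument to ShmemPutMemNbiThread, which this page reads as the same kind of QP index used in the atomic examples above (with 0 presumably being the default when the argument is omitted). A hedged variant that derives the index per warp instead of hard-coding 1:

    // Issue each warp's non-blocking put on a different QP; qpId semantics assumed, not confirmed.
    __device__ void PutOnWarpQp(const SymmMemObjPtr memObj, const RdmaMemoryRegion& source,
                                size_t offset, size_t bytes, int recvPe, int numQp) {
      int warpId = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
      ShmemPutMemNbiThread(memObj, offset, source, offset, bytes, recvPe, warpId % numQp);
    }
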
12 changes: 8 additions & 4 deletions examples/utils/args_parser.cpp
@@ -161,10 +161,11 @@ void BenchmarkConfig::readArgs(int argc, char** argv) {
{"scope", required_argument, 0, 's'},
{"atomic_op", required_argument, 0, 'a'},
{"stride", required_argument, 0, 'i'},
{"num_qp", required_argument, 0, 'q'},
{0, 0, 0, 0}};

int opt_idx = 0;
while ((c = getopt_long(argc, argv, "hb:e:f:n:w:c:t:d:o:s:a:i:", long_opts, &opt_idx)) != -1) {
while ((c = getopt_long(argc, argv, "hb:e:f:n:w:c:t:d:o:s:a:i:q:", long_opts, &opt_idx)) != -1) {
switch (c) {
case 'h':
printf(
@@ -244,6 +245,9 @@ void BenchmarkConfig::readArgs(int argc, char** argv) {
case 'a':
atomicOpParse(optarg);
break;
case 'q':
atolScaled(optarg, &num_qp);
break;
case '?':
if (optopt == 'c') {
fprintf(stderr, "Option -%c requires an argument.\n", optopt);
@@ -269,13 +273,13 @@ void BenchmarkConfig::readArgs(int argc, char** argv) {

printf("Runtime options after parsing command line arguments \n");
printf(
"min_size: %zu, max_size: %zu, step_factor: %zu, iterations: %zu, "
"min_size: %zu, max_size: %zu, num_qp: %zu, step_factor: %zu, iterations: %zu, "
"warmup iterations: %zu, number of ctas: %zu, threads per cta: %zu "
"stride: %zu, datatype: %s, reduce_op: %s, threadgroup_scope: %s, "
"atomic_op: %s, dir: %s, report_msgrate: %d, bidirectional: %d, "
"putget_issue: %s\n",
min_size, max_size, step_factor, iters, warmup_iters, num_blocks, threads_per_block, stride,
datatype.name.c_str(), reduce_op.name.c_str(), threadgroup_scope.name.c_str(),
min_size, max_size, num_qp, step_factor, iters, warmup_iters, num_blocks, threads_per_block,
stride, datatype.name.c_str(), reduce_op.name.c_str(), threadgroup_scope.name.c_str(),
test_amo.name.c_str(), dir.name.c_str(), report_msgrate, bidirectional,
putget_issue.name.c_str());
printf(
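
The parser gains a --num_qp (short -q) option, parsed through atolScaled into num_qp and reported in the summary printf. A minimal consumption sketch; only readArgs and getNumQp come from this diff and the header below, everything else (includes, default construction) is assumed:

    // Assumes #include "args_parser.hpp" and <cstdio>; construction details are a sketch.
    int main(int argc, char** argv) {
      BenchmarkConfig cfg;                 // assumed default-constructible
      cfg.readArgs(argc, argv);            // e.g. ./bench --num_qp 4   (or -q 4)
      printf("QPs per PE requested: %zu\n", cfg.getNumQp());   // 1 unless overridden
      return 0;
    }
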
2 changes: 2 additions & 0 deletions examples/utils/args_parser.hpp
@@ -109,6 +109,7 @@ class BenchmarkConfig {
size_t getStepFactor() const { return step_factor; }
size_t getMaxSizeLog() const { return max_size_log; }
size_t getStride() const { return stride; }
size_t getNumQp() const { return num_qp; }

bool isBidirectional() const { return bidirectional; }
bool isReportMsgrate() const { return report_msgrate; }
@@ -130,6 +131,7 @@ class BenchmarkConfig {
size_t step_factor = 2;
size_t max_size_log = 1;
size_t stride = 1;
size_t num_qp = 1;

bool bidirectional = false;
bool report_msgrate = false;
2 changes: 2 additions & 0 deletions include/mori/application/context/context.hpp
@@ -43,6 +43,7 @@ class Context {

TransportType GetTransportType(int destRank) const { return transportTypes[destRank]; }
std::vector<TransportType> GetTransportTypes() const { return transportTypes; }
int GetNumQpPerPe() const { return numQpPerPe; }

RdmaContext* GetRdmaContext() const { return rdmaContext.get(); }
RdmaDeviceContext* GetRdmaDeviceContext() const { return rdmaDeviceContext.get(); }
@@ -57,6 +58,7 @@ class Context {
private:
BootstrapNetwork& bootNet;
int rankInNode{-1};
int numQpPerPe{4};
std::vector<std::string> hostnames;
std::vector<TransportType> transportTypes;

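
The context now exposes GetNumQpPerPe(), with the member defaulting to 4. How numQpPerPe gets configured is not shown in this diff, so the following is only an illustration of reading the value and mapping work onto it:

    // Illustrative only: how a caller might map warps onto the per-PE QPs; only
    // GetNumQpPerPe() and its default of 4 come from this diff.
    int PickQpForWarp(const Context& ctx, int globalWarpId) {
      int numQp = ctx.GetNumQpPerPe();   // 4 unless the application overrides numQpPerPe
      return globalWarpId % numQp;       // round-robin, matching the qpId use shown earlier
    }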