diff --git a/examples/io/example.py b/examples/io/example.py
index 658c851c..13461b39 100644
--- a/examples/io/example.py
+++ b/examples/io/example.py
@@ -77,13 +77,13 @@ def batch_read_write_example(initiator, target, size, batch_size):
     # Perform batch p2p transfer from initiator to target
     transfer_uid = initiator.allocate_transfer_uid()
     transfer_status = initiator.batch_read(
-        initiator_mem,
-        local_offsets,
-        target_mem,
-        remote_offsets,
-        sizes,
-        transfer_uid,
-    )
+        [initiator_mem],
+        [local_offsets],
+        [target_mem],
+        [remote_offsets],
+        [sizes],
+        [transfer_uid],
+    )[0]
 
     while transfer_status.InProgress():
         pass
diff --git a/include/mori/io/common.hpp b/include/mori/io/common.hpp
index cb5a9c2a..5970bbc5 100644
--- a/include/mori/io/common.hpp
+++ b/include/mori/io/common.hpp
@@ -129,6 +129,10 @@ struct TransferStatus {
 };
 
 using SizeVec = std::vector<size_t>;
+using MemDescVec = std::vector<MemoryDesc>;
+using BatchSizeVec = std::vector<SizeVec>;
+using TransferUniqueIdVec = std::vector<TransferUniqueId>;
+using TransferStatusPtrVec = std::vector<TransferStatus*>;
 
 }  // namespace io
 }  // namespace mori
diff --git a/include/mori/io/engine.hpp b/include/mori/io/engine.hpp
index 06f9373e..142f8eee 100644
--- a/include/mori/io/engine.hpp
+++ b/include/mori/io/engine.hpp
@@ -93,12 +93,13 @@ class IOEngine {
   void Write(const MemoryDesc& localSrc, size_t localOffset, const MemoryDesc& remoteDest,
              size_t remoteOffset, size_t size, TransferStatus* status, TransferUniqueId id);
 
-  void BatchRead(const MemoryDesc& localDest, const SizeVec& localOffsets,
-                 const MemoryDesc& remoteSrc, const SizeVec& remoteOffsets, const SizeVec& sizes,
-                 TransferStatus* status, TransferUniqueId id);
-  void BatchWrite(const MemoryDesc& localSrc, const SizeVec& localOffsets,
-                  const MemoryDesc& remoteDest, const SizeVec& remoteOffsets, const SizeVec& sizes,
-                  TransferStatus* status, TransferUniqueId id);
+  void BatchRead(const MemDescVec& localDest, const BatchSizeVec& localOffsets,
+                 const MemDescVec& remoteSrc, const BatchSizeVec& remoteOffsets,
+                 const BatchSizeVec& sizes, TransferStatusPtrVec& status, TransferUniqueIdVec& ids);
+  void BatchWrite(const MemDescVec& localSrc, const BatchSizeVec& localOffsets,
+                  const MemDescVec& remoteDest, const BatchSizeVec& remoteOffsets,
+                  const BatchSizeVec& sizes, TransferStatusPtrVec& status,
+                  TransferUniqueIdVec& ids);
 
   // Take the transfer status of an inbound op
   bool PopInboundTransferStatus(EngineKey remote, TransferUniqueId id, TransferStatus* status);
diff --git a/python/mori/io/engine.py b/python/mori/io/engine.py
index af3d7af4..8de4e309 100644
--- a/python/mori/io/engine.py
+++ b/python/mori/io/engine.py
@@ -160,7 +160,7 @@ def _batch_single_side_transfer(
     sizes,
     transfer_uid,
 ):
-    transfer_status = mori_cpp.TransferStatus()
+    transfer_status = [mori_cpp.TransferStatus() for _ in range(len(sizes))]
     func(
         local_dest_mem_desc,
         local_offsets,
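Note on the resulting API shape: every engine-level batch argument is now per-session (one MemoryDesc, one offset list, one size list, and one transfer id per session), and the Python wrapper returns one TransferStatus per session. A minimal sketch of a two-session call, assuming two already-registered memory pairs; the names mem_a/mem_b, remote_a/remote_b, and the offset/size/uid variables are hypothetical placeholders, not identifiers from this patch:

    statuses = engine.batch_read(
        [mem_a, mem_b],                        # one local MemoryDesc per session
        [offsets_a, offsets_b],                # one local offset list per session
        [remote_a, remote_b],                  # one remote MemoryDesc per session
        [remote_offsets_a, remote_offsets_b],  # one remote offset list per session
        [sizes_a, sizes_b],                    # one size list per session
        [uid_a, uid_b],                        # one transfer uid per session
    )
    for st in statuses:
        st.Wait()  # or poll st.InProgress(), as the tests below do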
diff --git a/src/io/engine.cpp b/src/io/engine.cpp
index 0f645761..79257581 100644
--- a/src/io/engine.cpp
+++ b/src/io/engine.cpp
@@ -194,28 +194,53 @@ void IOEngine::Write(const MemoryDesc& localSrc, size_t localOffset, const Memor
 }
 
-void IOEngine::BatchRead(const MemoryDesc& localDest, const SizeVec& localOffsets,
-                         const MemoryDesc& remoteSrc, const SizeVec& remoteOffsets,
-                         const SizeVec& sizes, TransferStatus* status, TransferUniqueId id) {
+void IOEngine::BatchRead(const MemDescVec& localDest, const BatchSizeVec& localOffsets,
+                         const MemDescVec& remoteSrc, const BatchSizeVec& remoteOffsets,
+                         const BatchSizeVec& sizes, TransferStatusPtrVec& status,
+                         TransferUniqueIdVec& ids) {
   MORI_IO_FUNCTION_TIMER;
-  Backend* backend = nullptr;
-  SELECT_BACKEND_AND_RETURN_IF_NONE(localDest, remoteSrc, status, backend);
-  backend->BatchRead(localDest, localOffsets, remoteSrc, remoteOffsets, sizes, status, id);
-  if (status->Failed()) {
-    MORI_IO_ERROR("Engine batch read error {} message {}", status->CodeUint32(), status->Message());
-  }
+  size_t batchSize = localDest.size();
+  assert(batchSize == remoteSrc.size());
+  assert(batchSize == localOffsets.size());
+  assert(batchSize == remoteOffsets.size());
+  assert(batchSize == sizes.size());
+  assert(batchSize == status.size());
+  assert(batchSize == ids.size());
+
+  for (size_t i = 0; i < batchSize; i++) {
+    Backend* backend = nullptr;
+    SELECT_BACKEND_AND_RETURN_IF_NONE(localDest[i], remoteSrc[i], status[i], backend);
+    backend->BatchRead(localDest[i], localOffsets[i], remoteSrc[i], remoteOffsets[i], sizes[i],
+                       status[i], ids[i]);
+    if (status[i]->Failed()) {
+      MORI_IO_ERROR("Engine batch read error {} message {}", status[i]->CodeUint32(),
+                    status[i]->Message());
+    }
+  }
 }
 
-void IOEngine::BatchWrite(const MemoryDesc& localSrc, const SizeVec& localOffsets,
-                          const MemoryDesc& remoteDest, const SizeVec& remoteOffsets,
-                          const SizeVec& sizes, TransferStatus* status, TransferUniqueId id) {
+void IOEngine::BatchWrite(const MemDescVec& localSrc, const BatchSizeVec& localOffsets,
+                          const MemDescVec& remoteDest, const BatchSizeVec& remoteOffsets,
+                          const BatchSizeVec& sizes, TransferStatusPtrVec& status,
+                          TransferUniqueIdVec& ids) {
   MORI_IO_FUNCTION_TIMER;
-  Backend* backend = nullptr;
-  SELECT_BACKEND_AND_RETURN_IF_NONE(localSrc, remoteDest, status, backend);
-  backend->BatchWrite(localSrc, localOffsets, remoteDest, remoteOffsets, sizes, status, id);
-  if (status->Failed()) {
-    MORI_IO_ERROR("Engine batch write error {} message {}", status->CodeUint32(),
-                  status->Message());
-  }
+  size_t batchSize = localSrc.size();
+  assert(batchSize == remoteDest.size());
+  assert(batchSize == localOffsets.size());
+  assert(batchSize == remoteOffsets.size());
+  assert(batchSize == sizes.size());
+  assert(batchSize == status.size());
+  assert(batchSize == ids.size());
+
+  for (size_t i = 0; i < batchSize; i++) {
+    Backend* backend = nullptr;
+    SELECT_BACKEND_AND_RETURN_IF_NONE(localSrc[i], remoteDest[i], status[i], backend);
+    backend->BatchWrite(localSrc[i], localOffsets[i], remoteDest[i], remoteOffsets[i], sizes[i],
+                        status[i], ids[i]);
+    if (status[i]->Failed()) {
+      MORI_IO_ERROR("Engine batch write error {} message {}", status[i]->CodeUint32(),
+                    status[i]->Message());
+    }
+  }
 }
diff --git a/tests/python/io/benchmark.py b/tests/python/io/benchmark.py
index c5daf80d..21dcc05d 100644
--- a/tests/python/io/benchmark.py
+++ b/tests/python/io/benchmark.py
@@ -361,15 +361,15 @@ def run_batch_once(self, buffer_size, transfer_batch_size):
             else self.engine.batch_write
         )
         args = (
-            self.mem,
-            offsets,
-            self.target_mem,
-            offsets,
-            sizes,
-            transfer_uid,
+            [self.mem],
+            [offsets],
+            [self.target_mem],
+            [offsets],
+            [sizes],
+            [transfer_uid],
         )
         st = time.time()
-        transfer_status = func(*args)
+        transfer_status = func(*args)[0]
         transfer_status.Wait()
         duration = time.time() - st
         assert transfer_status.Succeeded()
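Two observations on the hunks above. First, the per-vector length checks in IOEngine::BatchRead/BatchWrite are plain assert calls, so they compile out in NDEBUG builds; release-mode callers must guarantee matching vector lengths themselves. Second, the migration for single-session call sites is mechanical, as the remaining test updates below show: wrap each argument in a one-element list and take the first element of the returned status list. A minimal sketch of that pattern (engine, mem, target_mem, offsets, sizes, and uid assumed from the surrounding test context):

    st = engine.batch_read([mem], [offsets], [target_mem], [offsets], [sizes], [uid])[0]
    st.Wait()
    assert st.Succeeded()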
diff --git a/tests/python/io/stress_test.py b/tests/python/io/stress_test.py
index 67d68bc4..117c585e 100644
--- a/tests/python/io/stress_test.py
+++ b/tests/python/io/stress_test.py
@@ -427,7 +427,14 @@ def _do_batch_ops(
             st = func(offsets_src, offsets_dst, sizes, uid)
         else:
             func = self.engine.batch_read if is_read else self.engine.batch_write
-            st = func(self.mem, offsets_src, self.target_mem, offsets_dst, sizes, uid)
+            st = func(
+                [self.mem],
+                [offsets_src],
+                [self.target_mem],
+                [offsets_dst],
+                [sizes],
+                [uid],
+            )[0]
 
         while st.InProgress():
             time.sleep(0.00005)
diff --git a/tests/python/io/test_engine.py b/tests/python/io/test_engine.py
index 4a17c4db..0c374b56 100644
--- a/tests/python/io/test_engine.py
+++ b/tests/python/io/test_engine.py
@@ -252,8 +252,13 @@ def test_rdma_backend_ops(
             transfer_uid = initiator.allocate_transfer_uid()
             func = initiator.batch_read if op_type == "read" else initiator.batch_write
             transfer_status = func(
-                initiator_mem, offsets, target_mem, offsets, sizes, transfer_uid
-            )
+                [initiator_mem],
+                [offsets],
+                [target_mem],
+                [offsets],
+                [sizes],
+                [transfer_uid],
+            )[0]
             uid_status_list.append((transfer_uid, transfer_status))
         else:
             for i in range(batch_size):
@@ -334,8 +339,8 @@ def test_no_backend():
 
     transfer_uid = initiator.allocate_transfer_uid()
     transfer_status = initiator.batch_read(
-        initiator_mem, offsets, target_mem, offsets, sizes, transfer_uid
-    )
+        [initiator_mem], [offsets], [target_mem], [offsets], [sizes], [transfer_uid]
+    )[0]
     assert transfer_status.Failed()
     assert transfer_status.Code() == StatusCode.ERR_BAD_STATE
diff --git a/tests/python/io/test_engine_multi_session_batch.py b/tests/python/io/test_engine_multi_session_batch.py
new file mode 100644
index 00000000..b382f128
--- /dev/null
+++ b/tests/python/io/test_engine_multi_session_batch.py
@@ -0,0 +1,225 @@
+# Copyright © Advanced Micro Devices, Inc. All rights reserved.
+#
+# MIT License
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Unit tests for the new multi-session BatchRead/BatchWrite API.
+
+These tests focus on the new engine-level overloaded APIs that take vectors of:
+  - memory descriptors (one per session)
+  - offset lists (one list per session)
+  - size lists (one list per session)
+  - status pointers (returned; one per session)
+  - transfer unique ids (one per session)
+
+Existing tests in `test_engine.py` already cover the single-session batch form,
+where a single memory-descriptor pair is supplied with per-transfer offsets.
+Here we validate that multiple independent session pairs can be issued in a
+single BatchRead/BatchWrite call and that each completes successfully.
+""" + +import pytest +import torch + +from tests.python.utils import get_free_port +from mori.io import ( + IOEngineConfig, + BackendType, + IOEngine, + RdmaBackendConfig, + set_log_level, +) + + +# ----------------------------------------------------------------------------- +# Helpers / Fixtures +# ----------------------------------------------------------------------------- + + +def create_connected_engine_pair( + name_prefix, qp_per_transfer=1, post_batch_size=-1, num_worker_threads=1 +): + """Create two RDMA-enabled IOEngines and register each other. + + Returns (initiator, target). + """ + config = IOEngineConfig(host="127.0.0.1", port=get_free_port()) + initiator = IOEngine(key=f"{name_prefix}_initiator", config=config) + config.port = get_free_port() + target = IOEngine(key=f"{name_prefix}_target", config=config) + + be_cfg = RdmaBackendConfig( + qp_per_transfer=qp_per_transfer, + post_batch_size=post_batch_size, + num_worker_threads=num_worker_threads, + ) + initiator.create_backend(BackendType.RDMA, be_cfg) + target.create_backend(BackendType.RDMA, be_cfg) + + initiator_desc = initiator.get_engine_desc() + target_desc = target.get_engine_desc() + initiator.register_remote_engine(target_desc) + target.register_remote_engine(initiator_desc) + + return initiator, target + + +@pytest.fixture(scope="module") +def pre_connected_engine_pair(): + set_log_level("info") + normal = create_connected_engine_pair( + "multi_normal", qp_per_transfer=2, num_worker_threads=1 + ) + multhd = create_connected_engine_pair( + "multi_multhd", qp_per_transfer=2, num_worker_threads=2 + ) + engines = { + "normal": normal, + "multhd": multhd, + } + yield engines + # Cleanup references (explicit deregistration not strictly necessary here) + del normal, multhd + + +def wait_status(status): + while status.InProgress(): + pass + + +def wait_inbound_status(engine, remote_engine_key, transfer_uid): + while True: + target_side_status = engine.pop_inbound_transfer_status( + remote_engine_key, transfer_uid + ) + if target_side_status: + return target_side_status + + +# ----------------------------------------------------------------------------- +# Multi-session batch tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.parametrize("engine_type", ("normal", "multhd")) +@pytest.mark.parametrize("op_type", ("read", "write")) +def test_multi_session_batch_read_write( + pre_connected_engine_pair, engine_type, op_type +): + """Issue a single multi-session BatchRead/BatchWrite with >1 memory pair. + + Layout: + - For each session i we allocate independent tensors on device0 (initiator) + and device1 (target) of length BATCH_SIZE * BUFFER_SIZE bytes. + - We register each tensor to obtain MemoryDesc pairs. + - We build vectors of (mem, offsets[], sizes[]) per session and call + engine.batch_read/write with all sessions at once. + - We then wait on each returned TransferStatus and validate data movement. + """ + + initiator, target = pre_connected_engine_pair[engine_type] + + NUM_SESSIONS = 3 + BATCH_SIZE = 4 + BUFFER_SIZE = 256 # bytes per transfer within a session + TOTAL_SIZE = BATCH_SIZE * BUFFER_SIZE + + # Allocate tensors and register memory for each session. 
+    initiator_tensors = []
+    target_tensors = []
+    initiator_mems = []
+    target_mems = []
+
+    device0 = torch.device("cuda", 0)
+    device1 = torch.device("cuda", 1)
+
+    for i in range(NUM_SESSIONS):
+        it = torch.randn(TOTAL_SIZE).to(device0, dtype=torch.uint8)
+        tt = torch.randn(TOTAL_SIZE).to(device1, dtype=torch.uint8)
+        initiator_tensors.append(it)
+        target_tensors.append(tt)
+        initiator_mems.append(initiator.register_torch_tensor(it))
+        target_mems.append(target.register_torch_tensor(tt))
+
+    # Build per-session batch parameters.
+    # Offsets inside a session: contiguous segments.
+    per_session_offsets = [
+        [j * BUFFER_SIZE for j in range(BATCH_SIZE)] for _ in range(NUM_SESSIONS)
+    ]
+    per_session_sizes = [
+        [BUFFER_SIZE for _ in range(BATCH_SIZE)] for _ in range(NUM_SESSIONS)
+    ]
+
+    # Allocate unique transfer IDs per session.
+    transfer_ids = [initiator.allocate_transfer_uid() for _ in range(NUM_SESSIONS)]
+
+    # Call batch_read / batch_write with vectors of descriptors.
+    if op_type == "read":
+        # Read: localDest <- remoteSrc (initiator receives remote data)
+        statuses = initiator.batch_read(
+            initiator_mems,
+            per_session_offsets,
+            target_mems,
+            per_session_offsets,
+            per_session_sizes,
+            transfer_ids,
+        )
+    else:
+        statuses = initiator.batch_write(
+            initiator_mems,
+            per_session_offsets,
+            target_mems,
+            per_session_offsets,
+            per_session_sizes,
+            transfer_ids,
+        )
+
+    assert len(statuses) == NUM_SESSIONS, "Expected one status per session"
+
+    initiator_key = initiator.get_engine_desc().key
+
+    # Wait & validate each session independently.
+    for i in range(NUM_SESSIONS):
+        st = statuses[i]
+        wait_status(st)
+        inbound = wait_inbound_status(target, initiator_key, transfer_ids[i])
+        assert (
+            st.Succeeded()
+        ), f"Initiator status failed for session {i}: {st.Message()}"
+        assert (
+            inbound.Succeeded()
+        ), f"Target status failed for session {i}: {inbound.Message()}"
+
+        if op_type == "read":
+            # After a read, the initiator tensor should equal the original target tensor.
+            assert torch.equal(
+                initiator_tensors[i].cpu(), target_tensors[i].cpu()
+            ), f"Data mismatch (read) on session {i}"
+        else:
+            # After a write, the target tensor should equal the original initiator tensor.
+            assert torch.equal(
+                initiator_tensors[i].cpu(), target_tensors[i].cpu()
+            ), f"Data mismatch (write) on session {i}"
+
+    # Cleanup registrations.
+    for m in initiator_mems:
+        initiator.deregister_memory(m)
+    for m in target_mems:
+        target.deregister_memory(m)
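The test above waits on each session's status in order. Since the returned statuses are independent, a caller can also poll them collectively and overlap completion across sessions. A small sketch using only the status methods exercised in this patch (InProgress/Succeeded); wait_all is a hypothetical helper, not part of the API:

    def wait_all(statuses):
        # Spin until every per-session transfer has left the in-progress state.
        while any(st.InProgress() for st in statuses):
            pass
        # Report whether the whole multi-session batch succeeded.
        return all(st.Succeeded() for st in statuses)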