Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/io/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def batch_read_write_example(initiator, target, size, batch_size):
# Perform batch p2p transfer from initiator to target
transfer_uid = initiator.allocate_transfer_uid()
transfer_status = initiator.batch_read(
initiator_mem,
local_offsets,
target_mem,
remote_offsets,
sizes,
transfer_uid,
)
[initiator_mem],
[local_offsets],
[target_mem],
[remote_offsets],
[sizes],
[transfer_uid],
)[0]

while transfer_status.InProgress():
pass
Expand Down
4 changes: 4 additions & 0 deletions include/mori/io/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ struct TransferStatus {
};

using SizeVec = std::vector<size_t>;
using MemDescVec = std::vector<MemoryDesc>;
using BatchSizeVec = std::vector<SizeVec>;
using TransferUniqueIdVec = std::vector<TransferUniqueId>;
using TransferStatusPtrVec = std::vector<TransferStatus*>;

} // namespace io
} // namespace mori
13 changes: 7 additions & 6 deletions include/mori/io/engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,13 @@ class IOEngine {
void Write(const MemoryDesc& localSrc, size_t localOffset, const MemoryDesc& remoteDest,
size_t remoteOffset, size_t size, TransferStatus* status, TransferUniqueId id);

void BatchRead(const MemoryDesc& localDest, const SizeVec& localOffsets,
const MemoryDesc& remoteSrc, const SizeVec& remoteOffsets, const SizeVec& sizes,
TransferStatus* status, TransferUniqueId id);
void BatchWrite(const MemoryDesc& localSrc, const SizeVec& localOffsets,
const MemoryDesc& remoteDest, const SizeVec& remoteOffsets, const SizeVec& sizes,
TransferStatus* status, TransferUniqueId id);
void BatchRead(const MemDescVec& localDest, const BatchSizeVec& localOffsets,
const MemDescVec& remoteSrc, const BatchSizeVec& remoteOffsets,
const BatchSizeVec& sizes, TransferStatusPtrVec& status, TransferUniqueIdVec& ids);
void BatchWrite(const MemDescVec& localSrc, const BatchSizeVec& localOffsets,
const MemDescVec& remoteDest, const BatchSizeVec& remoteOffsets,
const BatchSizeVec& sizes, TransferStatusPtrVec& status,
TransferUniqueIdVec& ids);
// Take the transfer status of an inbound op
bool PopInboundTransferStatus(EngineKey remote, TransferUniqueId id, TransferStatus* status);

Expand Down
2 changes: 1 addition & 1 deletion python/mori/io/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _batch_single_side_transfer(
sizes,
transfer_uid,
):
transfer_status = mori_cpp.TransferStatus()
transfer_status = [mori_cpp.TransferStatus() for _ in range(len(sizes))]
func(
local_dest_mem_desc,
local_offsets,
Expand Down
59 changes: 42 additions & 17 deletions src/io/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,28 +194,53 @@ void IOEngine::Write(const MemoryDesc& localSrc, size_t localOffset, const Memor
}
}

void IOEngine::BatchRead(const MemoryDesc& localDest, const SizeVec& localOffsets,
const MemoryDesc& remoteSrc, const SizeVec& remoteOffsets,
const SizeVec& sizes, TransferStatus* status, TransferUniqueId id) {
void IOEngine::BatchRead(const MemDescVec& localDest, const BatchSizeVec& localOffsets,
const MemDescVec& remoteSrc, const BatchSizeVec& remoteOffsets,
const BatchSizeVec& sizes, TransferStatusPtrVec& status,
TransferUniqueIdVec& ids) {
MORI_IO_FUNCTION_TIMER;
Backend* backend = nullptr;
SELECT_BACKEND_AND_RETURN_IF_NONE(localDest, remoteSrc, status, backend);
backend->BatchRead(localDest, localOffsets, remoteSrc, remoteOffsets, sizes, status, id);
if (status->Failed()) {
MORI_IO_ERROR("Engine batch read error {} message {}", status->CodeUint32(), status->Message());
size_t batchSize = localDest.size();
assert(batchSize == remoteSrc.size());
assert(batchSize == localOffsets.size());
assert(batchSize == remoteOffsets.size());
assert(batchSize == sizes.size());
assert(batchSize == status.size());
assert(batchSize == ids.size());

for (size_t i = 0; i < batchSize; i++) {
Backend* backend = nullptr;
SELECT_BACKEND_AND_RETURN_IF_NONE(localDest[i], remoteSrc[i], status[i], backend);
backend->BatchRead(localDest[i], localOffsets[i], remoteSrc[i], remoteOffsets[i], sizes[i],
status[i], ids[i]);
if (status[i]->Failed()) {
MORI_IO_ERROR("Engine batch read error {} message {}", status[i]->CodeUint32(),
status[i]->Message());
}
}
}

void IOEngine::BatchWrite(const MemoryDesc& localSrc, const SizeVec& localOffsets,
const MemoryDesc& remoteDest, const SizeVec& remoteOffsets,
const SizeVec& sizes, TransferStatus* status, TransferUniqueId id) {
void IOEngine::BatchWrite(const MemDescVec& localSrc, const BatchSizeVec& localOffsets,
const MemDescVec& remoteDest, const BatchSizeVec& remoteOffsets,
const BatchSizeVec& sizes, TransferStatusPtrVec& status,
TransferUniqueIdVec& ids) {
MORI_IO_FUNCTION_TIMER;
Backend* backend = nullptr;
SELECT_BACKEND_AND_RETURN_IF_NONE(localSrc, remoteDest, status, backend);
backend->BatchWrite(localSrc, localOffsets, remoteDest, remoteOffsets, sizes, status, id);
if (status->Failed()) {
MORI_IO_ERROR("Engine batch write error {} message {}", status->CodeUint32(),
status->Message());
size_t batchSize = localSrc.size();
assert(batchSize == remoteDest.size());
assert(batchSize == localOffsets.size());
assert(batchSize == remoteOffsets.size());
assert(batchSize == sizes.size());
assert(batchSize == status.size());
assert(batchSize == ids.size());

for (size_t i = 0; i < batchSize; i++) {
Backend* backend = nullptr;
SELECT_BACKEND_AND_RETURN_IF_NONE(localSrc[i], remoteDest[i], status[i], backend);
backend->BatchWrite(localSrc[i], localOffsets[i], remoteDest[i], remoteOffsets[i], sizes[i],
status[i], ids[i]);
if (status[i]->Failed()) {
MORI_IO_ERROR("Engine batch write error {} message {}", status[i]->CodeUint32(),
status[i]->Message());
}
}
}

Expand Down
14 changes: 7 additions & 7 deletions tests/python/io/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,15 +361,15 @@ def run_batch_once(self, buffer_size, transfer_batch_size):
else self.engine.batch_write
)
args = (
self.mem,
offsets,
self.target_mem,
offsets,
sizes,
transfer_uid,
[self.mem],
[offsets],
[self.target_mem],
[offsets],
[sizes],
[transfer_uid],
)
st = time.time()
transfer_status = func(*args)
transfer_status = func(*args)[0]
transfer_status.Wait()
duration = time.time() - st
assert transfer_status.Succeeded()
Expand Down
9 changes: 8 additions & 1 deletion tests/python/io/stress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,14 @@ def _do_batch_ops(
st = func(offsets_src, offsets_dst, sizes, uid)
else:
func = self.engine.batch_read if is_read else self.engine.batch_write
st = func(self.mem, offsets_src, self.target_mem, offsets_dst, sizes, uid)
st = func(
[self.mem],
[offsets_src],
[self.target_mem],
[offsets_dst],
[sizes],
[uid],
)[0]

while st.InProgress():
time.sleep(0.00005)
Expand Down
13 changes: 9 additions & 4 deletions tests/python/io/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,13 @@ def test_rdma_backend_ops(
transfer_uid = initiator.allocate_transfer_uid()
func = initiator.batch_read if op_type == "read" else initiator.batch_write
transfer_status = func(
initiator_mem, offsets, target_mem, offsets, sizes, transfer_uid
)
[initiator_mem],
[offsets],
[target_mem],
[offsets],
[sizes],
[transfer_uid],
)[0]
uid_status_list.append((transfer_uid, transfer_status))
else:
for i in range(batch_size):
Expand Down Expand Up @@ -334,8 +339,8 @@ def test_no_backend():

transfer_uid = initiator.allocate_transfer_uid()
transfer_status = initiator.batch_read(
initiator_mem, offsets, target_mem, offsets, sizes, transfer_uid
)
[initiator_mem], [offsets], [target_mem], [offsets], [sizes], [transfer_uid]
)[0]

assert transfer_status.Failed()
assert transfer_status.Code() == StatusCode.ERR_BAD_STATE
Expand Down
Loading