Skip to content

Commit 7432823

Browse files
committed
add batch unittest
1 parent 76d2f0a commit 7432823

File tree

1 file changed

+227
-0
lines changed

1 file changed

+227
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright © Advanced Micro Devices, Inc. All rights reserved.
2+
#
3+
# MIT License
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
"""Unit tests for new multi-session BatchRead/BatchWrite API.
24+
25+
This focuses on the new engine-level overloaded APIs that take vectors of:
26+
- memory descriptors (one per session)
27+
- offset lists (one list per session)
28+
- size lists (one list per session)
29+
- status pointers (returned; one per session)
30+
- transfer unique ids (one per session)
31+
32+
Existing tests in `test_engine.py` already cover the single-session batch form
33+
where a single memory descriptor pair is supplied with per-transfer offsets.
34+
Here we validate that multiple independent session pairs can be issued in a
35+
single BatchRead / BatchWrite call and each completes successfully.
36+
"""
37+
38+
import pytest
39+
import torch
40+
41+
from tests.python.utils import get_free_port
42+
from mori.io import (
43+
IOEngineConfig,
44+
BackendType,
45+
IOEngine,
46+
RdmaBackendConfig,
47+
set_log_level,
48+
)
49+
50+
51+
# -----------------------------------------------------------------------------
52+
# Helpers / Fixtures
53+
# -----------------------------------------------------------------------------
54+
55+
56+
def create_connected_engine_pair(
    name_prefix, qp_per_transfer=1, post_batch_size=-1, num_worker_threads=1
):
    """Build an (initiator, target) pair of RDMA-backed IOEngines.

    Both engines are given an RDMA backend configured with the supplied
    tuning knobs, and each registers the other's engine descriptor so
    transfers can be issued between them.

    Returns (initiator, target).
    """
    engine_config = IOEngineConfig(host="127.0.0.1", port=get_free_port())
    initiator = IOEngine(key=f"{name_prefix}_initiator", config=engine_config)
    # Reuse the same config object with a fresh port for the second engine.
    engine_config.port = get_free_port()
    target = IOEngine(key=f"{name_prefix}_target", config=engine_config)

    backend_config = RdmaBackendConfig(
        qp_per_transfer=qp_per_transfer,
        post_batch_size=post_batch_size,
        num_worker_threads=num_worker_threads,
    )
    for engine in (initiator, target):
        engine.create_backend(BackendType.RDMA, backend_config)

    # Cross-register descriptors so each engine can address its peer.
    initiator.register_remote_engine(target.get_engine_desc())
    target.register_remote_engine(initiator.get_engine_desc())

    return initiator, target
82+
83+
84+
@pytest.fixture(scope="module")
def pre_connected_engine_pair():
    """Module-scoped fixture yielding two pre-connected engine pairs.

    Keys:
    - "normal": one worker thread per backend.
    - "multhd": two worker threads per backend.
    """
    set_log_level("info")
    engines = {
        "normal": create_connected_engine_pair(
            "multi_normal", qp_per_transfer=2, num_worker_threads=1
        ),
        "multhd": create_connected_engine_pair(
            "multi_multhd", qp_per_transfer=2, num_worker_threads=2
        ),
    }
    yield engines
    # Cleanup references (explicit deregistration not strictly necessary here)
    del engines
100+
101+
102+
def wait_status(status):
    """Busy-wait until *status* reports it is no longer in progress.

    Deliberately spins without sleeping to keep transfer-completion
    latency in the test as low as possible.
    """
    in_flight = True
    while in_flight:
        in_flight = status.InProgress()
105+
106+
107+
def wait_inbound_status(engine, remote_engine_key, transfer_uid):
    """Spin-poll *engine* for the inbound status of one transfer.

    Repeatedly pops the inbound transfer status for the given remote
    engine key and transfer uid, returning the first truthy result.
    """
    status = engine.pop_inbound_transfer_status(remote_engine_key, transfer_uid)
    while not status:
        status = engine.pop_inbound_transfer_status(
            remote_engine_key, transfer_uid
        )
    return status
114+
115+
116+
# -----------------------------------------------------------------------------
117+
# Multi-session batch tests
118+
# -----------------------------------------------------------------------------
119+
120+
121+
@pytest.mark.parametrize("engine_type", ("normal", "multhd"))
@pytest.mark.parametrize("op_type", ("read", "write"))
def test_multi_session_batch_read_write(
    pre_connected_engine_pair, engine_type, op_type
):
    """Issue a single multi-session BatchRead/BatchWrite with >1 memory pair.

    Layout:
    - For each session i we allocate independent tensors on device0 (initiator)
      and device1 (target) of length BATCH_SIZE * BUFFER_SIZE bytes.
    - We register each tensor to obtain MemoryDesc pairs.
    - We build vectors of (mem, offsets[], sizes[]) per session and call
      engine.batch_read/write with all sessions at once.
    - We then wait on each returned TransferStatus and validate data movement.
    """
    initiator, target = pre_connected_engine_pair[engine_type]

    NUM_SESSIONS = 3
    BATCH_SIZE = 4
    BUFFER_SIZE = 256  # bytes per transfer within a session
    TOTAL_SIZE = BATCH_SIZE * BUFFER_SIZE

    # Allocate tensors and register memory for each session.
    initiator_tensors = []
    target_tensors = []
    initiator_mems = []
    target_mems = []

    device0 = torch.device("cuda", 0)
    device1 = torch.device("cuda", 1)

    for _ in range(NUM_SESSIONS):
        # torch.randn does not implement a CUDA kernel for uint8 directly;
        # generate in float and then cast to uint8 to match existing tests'
        # behavior.
        it = torch.randn(TOTAL_SIZE, device=device0).to(torch.uint8)
        tt = torch.randn(TOTAL_SIZE, device=device1).to(torch.uint8)
        initiator_tensors.append(it)
        target_tensors.append(tt)
        initiator_mems.append(initiator.register_torch_tensor(it))
        target_mems.append(target.register_torch_tensor(tt))

    # Build per-session batch parameters.
    # Offsets inside a session: contiguous segments.
    per_session_offsets = [
        [j * BUFFER_SIZE for j in range(BATCH_SIZE)] for _ in range(NUM_SESSIONS)
    ]
    per_session_sizes = [
        [BUFFER_SIZE for _ in range(BATCH_SIZE)] for _ in range(NUM_SESSIONS)
    ]

    # Allocate unique transfer IDs per session.
    transfer_ids = [initiator.allocate_transfer_uid() for _ in range(NUM_SESSIONS)]

    # batch_read and batch_write take identical argument vectors, so select
    # the bound method once instead of duplicating the call site.
    # Read: localDest <- remoteSrc (initiator receives remote data);
    # Write: remoteDest <- localSrc (target receives initiator data).
    batch_op = initiator.batch_read if op_type == "read" else initiator.batch_write
    statuses = batch_op(
        initiator_mems,
        per_session_offsets,
        target_mems,
        per_session_offsets,
        per_session_sizes,
        transfer_ids,
    )

    assert len(statuses) == NUM_SESSIONS, "Expected one status per session"

    initiator_key = initiator.get_engine_desc().key

    # Wait & validate each session independently.
    for i in range(NUM_SESSIONS):
        st = statuses[i]
        wait_status(st)
        inbound = wait_inbound_status(target, initiator_key, transfer_ids[i])
        assert (
            st.Succeeded()
        ), f"Initiator status failed for session {i}: {st.Message()}"
        assert (
            inbound.Succeeded()
        ), f"Target status failed for session {i}: {inbound.Message()}"

        # Either direction leaves the two tensors equal on success: read
        # copies target -> initiator, write copies initiator -> target, so
        # a single equality check validates both cases.
        assert torch.equal(
            initiator_tensors[i].cpu(), target_tensors[i].cpu()
        ), f"Data mismatch ({op_type}) on session {i}"

    # Cleanup registrations.
    for m in initiator_mems:
        initiator.deregister_memory(m)
    for m in target_mems:
        target.deregister_memory(m)

0 commit comments

Comments
 (0)