
Commit 94f3aec

Merge pull request #18 from hameerabbasi/dtensor-tests

2 parents: 7ea6574 + 5a25675


4 files changed (+107, -9 lines)


src/complex_tensor/ops/_common.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def impl(
         out_flat = [ComplexTensor(ui, vi) for ui, vi in zip(u_flat, v_flat, strict=False)]
         return tree_unflatten(out_flat, u_spec)
 
-    func_name = f"{str(op).split('.', 1)}_impl"
+    func_name = f"{str(op).split('.', 1)[1]}_impl"
     impl.__name__ = func_name
     impl.__qualname__ = func_name
 
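Note: the fix is a one-character indexing bug. `str.split` with `maxsplit=1` returns a list, so the old f-string embedded the whole list in the generated name. A minimal sketch, where the string "aten.add" is an illustrative stand-in for `str(op)`:

op_name = "aten.add"  # illustrative stand-in for str(op) on an aten op

# Old behaviour: the list returned by split() is interpolated verbatim.
print(f"{op_name.split('.', 1)}_impl")     # -> "['aten', 'add']_impl"

# New behaviour: index [1] picks the part after the "aten." namespace.
print(f"{op_name.split('.', 1)[1]}_impl")  # -> "add_impl"
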
src/complex_tensor/ops/aten.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def is_pinned_impl(self: ComplexTensor, device: torch.device | None = None) -> bool:
 ]
 
 for simple_op in SIMPLE_OPS_LIST:
-    globals()[f"{str(simple_op).split('.', 1)}_impl"] = register_simple(simple_op)
+    globals()[f"{str(simple_op).split('.', 1)[1]}_impl"] = register_simple(simple_op)
 
 # TODO (hameerabbasi): Not being tested
 SIMPLE_FORCE_TESTED_OPS = [
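Note: the same expression feeds the dynamic registration loop, so each entry in SIMPLE_OPS_LIST now gets a readable module-level `<name>_impl` binding. A rough sketch of the pattern, with a placeholder `register_simple` and plain strings standing in for the real op objects:

# Hypothetical stand-ins: the real SIMPLE_OPS_LIST holds aten OpOverloadPackets
# and register_simple builds the actual ComplexTensor implementation.
def register_simple(op):
    def impl(*args, **kwargs):
        raise NotImplementedError(f"placeholder for {op}")
    impl.__name__ = f"{str(op).split('.', 1)[1]}_impl"
    return impl

SIMPLE_OPS_LIST = ["aten.abs", "aten.angle"]

for simple_op in SIMPLE_OPS_LIST:
    globals()[f"{str(simple_op).split('.', 1)[1]}_impl"] = register_simple(simple_op)

print(sorted(name for name in globals() if name.endswith("_impl")))
# -> ['abs_impl', 'angle_impl'] rather than "['aten', 'abs']_impl"
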

src/complex_tensor/test/test_ops.py

Lines changed: 94 additions & 5 deletions
@@ -1,6 +1,9 @@
 from __future__ import annotations
 
+from typing import Any
+
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_methods_invocations import op_db
@@ -120,6 +123,54 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
     Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
+    Descriptor(
+        op=aten.any, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.all, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.allclose, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten._conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.cumprod, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.index_add, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.diagonal_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.flip, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_fill, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.rsub, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.ne, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.squeeze, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(op=aten.index_select, variant=Variant.Distributed): "Sharding propagation failed",
+    Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
 }
 
 EXTRA_KWARGS = {
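Note: these new `Variant.Distributed` entries land in the SKIPS table that the test runner already consults via `Descriptor.matches`. The real `Descriptor` lives in src/complex_tensor/test/utils.py; the sketch below is an assumption about its wildcard semantics (fields left unset match anything), meant only to illustrate how a partial key like `Descriptor(op=aten.any, variant=Variant.Distributed)` can skip every device/dtype/compile combination of that op:

from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import Any

class Variant(Enum):
    Op = auto()
    GradCheck = auto()
    Distributed = auto()

@dataclass(frozen=True, kw_only=True)
class Descriptor:
    # Hypothetical reimplementation for illustration only.
    op: Any = None
    device: Any = None
    dtype: Any = None
    compile: bool | None = None
    variant: Variant | None = None

    def matches(self, other: "Descriptor") -> bool:
        # Unset (None) fields act as wildcards; set fields must agree exactly.
        return all(
            getattr(self, f.name) is None or getattr(self, f.name) == getattr(other, f.name)
            for f in fields(self)
        )

skip = Descriptor(op="aten.any", variant=Variant.Distributed)
test = Descriptor(op="aten.any", device="cpu", dtype="complex64",
                  compile=False, variant=Variant.Distributed)
assert skip.matches(test)  # the skip entry applies regardless of device/dtype/compile
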
@@ -135,29 +186,65 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
         "rtol": 2e-2,
         "atol": 2e-6,
     },
+    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-5,
+        "atol": 5e-5,
+    },
+    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 1e-4,
+        "atol": 1e-5,
+    },
+    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-2,
+        "atol": 2e-6,
+    },
+    Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-6,
+        "atol": 1e-2,
+    },
 }
 
+STORE = dist.HashStore()
+dist.init_process_group(store=STORE, rank=0, world_size=1)
+DEVICE_MESH = dist.init_device_mesh("cpu", mesh_shape=(1,))
+
+
+def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, torch.Tensor):
+        return arg
+
+    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg), device_mesh=DEVICE_MESH)
+
+
+TRANSFORM_FUNCS = {Variant.Op: _as_complex_tensor, Variant.Distributed: _as_complex_dtensor}
+
 
 class TestComplexTensor(TestCase):
     _default_dtype_check_enabled = True
 
     @parametrize("compile", [False, True])
     @ops(implemented_op_db, dtypes=OpDTypes.supported, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_consistency(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
     @parametrize("compile", [False, True])
     @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)
 
-    def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
+    @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
+    def test_distributed(self, device, dtype, op: OpInfo):
+        self.check_consistency(device, dtype, op, False, Variant.Distributed)
+
+    def check_consistency(
+        self, device: torch.device, dtype, op: OpInfo, compile: bool, variant: Variant
+    ) -> None:
         test_info = Descriptor(
             op=get_overload_packet_from_name(op.name),
             device=device,
             dtype=dtype,
             compile=compile,
-            variant=Variant.Op,
+            variant=variant,
         )
         for xfail_info, reason in SKIPS.items():
            if xfail_info.matches(test_info):
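Note: the module-level harness added in this hunk spins up distributed state once per test process: a single-rank process group backed by an in-memory HashStore, a one-element CPU device mesh, and a transform that wraps each sample as a replicated DTensor. A self-contained sketch of that setup, assuming a recent PyTorch build where torch.distributed.tensor is available, and using a plain float tensor in place of a ComplexTensor sample:

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

# One-rank process group: HashStore keeps the rendezvous in-process, so no
# network setup is needed for the "distributed" tests.
store = dist.HashStore()
dist.init_process_group(store=store, rank=0, world_size=1)
mesh = dist.init_device_mesh("cpu", mesh_shape=(1,))

local = torch.randn(4)  # stand-in for the ComplexTensor inputs used in the tests
dt = DTensor.from_local(local, device_mesh=mesh)  # replicated placement by default

# full_tensor() materializes the local data again, which is what the _as_local
# helper in utils.py relies on when comparing results.
assert torch.equal(dt.full_tensor(), local)

dist.destroy_process_group()
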
@@ -174,12 +261,14 @@ def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
         if compile:
             op = torch.compile(op, fullgraph=True)
 
+        transform_fn = TRANSFORM_FUNCS[variant]
+
         for sample_input in sample_inputs:
 
             def expected(sample_input=sample_input):
                 return op_eager(sample_input.input, *sample_input.args, **sample_input.kwargs)
 
-            subclass_sample = sample_input.transform(_as_complex_tensor)
+            subclass_sample = sample_input.transform(transform_fn)
 
             def actual(subclass_sample=subclass_sample):
                 return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs)
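Note: `check_consistency` now picks the per-variant transform from TRANSFORM_FUNCS and applies it with `SampleInput.transform`, which maps a callable over a sample's input, args, and kwargs. A small sketch of that mechanism with a trivial transform in place of `_as_complex_tensor` / `_as_complex_dtensor`:

import torch
from torch.testing._internal.common_methods_invocations import SampleInput

sample = SampleInput(torch.randn(3, dtype=torch.complex64), args=(2,))

def double(arg):
    # Mirror the isinstance guard used by _as_complex_tensor so that
    # non-tensor arguments pass through unchanged.
    if not isinstance(arg, torch.Tensor):
        return arg
    return arg * 2

transformed = sample.transform(double)
assert torch.equal(transformed.input, sample.input * 2)
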

src/complex_tensor/test/utils.py

Lines changed: 11 additions & 2 deletions
@@ -6,6 +6,7 @@
 from typing import Any
 
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverloadPacket
 from torch.testing._internal.common_utils import TestCase as PytorchTestCase
 from torch.utils._pytree import tree_flatten
@@ -18,6 +19,14 @@
 class Variant(Enum):
     Op = auto()
     GradCheck = auto()
+    Distributed = auto()
+
+
+def _as_local(arg: dist.tensor.DTensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, dist.tensor.DTensor):
+        return arg
+
+    return arg.full_tensor()
 
 
 @dataclass(frozen=True, kw_only=True)
@@ -79,7 +88,7 @@ def assertSameResult(
             spec_e, spec_a, "Both functions must return a result with the same tree structure."
         )
         for value_e, value_a in zip(flattened_e, flattened_a, strict=True):
-            value_e = _as_interleaved(value_e)
-            value_a = _as_interleaved(value_a)
+            value_e = _as_interleaved(_as_local(value_e))
+            value_a = _as_interleaved(_as_local(value_a))
 
             self.assertEqual(value_e, value_a, *args, **kwargs)
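Note: with DTensor results in play, `assertSameResult` now normalizes each flattened leaf through `_as_local` before the existing `_as_interleaved` conversion and equality check. A standalone sketch of that normalize-then-compare loop using the same pytree utilities (the `_as_interleaved` step is omitted for brevity, and the import is kept local so the sketch also runs where torch.distributed.tensor is unavailable):

import torch
from torch.utils._pytree import tree_flatten

def _as_local(arg):
    # Same idea as the helper added to utils.py: materialize DTensors,
    # leave every other leaf untouched.
    try:
        from torch.distributed.tensor import DTensor
    except ImportError:
        return arg
    return arg.full_tensor() if isinstance(arg, DTensor) else arg

expected = (torch.ones(2), 3.0)
actual = (torch.ones(2), 3.0)

flat_e, spec_e = tree_flatten(expected)
flat_a, spec_a = tree_flatten(actual)
assert spec_e == spec_a, "Both results must share the same tree structure."

for value_e, value_a in zip(flat_e, flat_a, strict=True):
    torch.testing.assert_close(_as_local(value_e), _as_local(value_a))
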
