 from __future__ import annotations

+from typing import Any
+
 import torch
+import torch.distributed as dist
 from torch._ops import OpOverload, OpOverloadPacket
 from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops
 from torch.testing._internal.common_methods_invocations import op_db
@@ -120,6 +123,54 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
     Descriptor(op=aten.dot, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.mul, variant=Variant.GradCheck): "Numerical inconsistency",
     Descriptor(op=aten.exp, variant=Variant.GradCheck): "Numerical inconsistency",
+    Descriptor(
+        op=aten.any, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.all, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.allclose, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten._conj_physical, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.cumprod, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.index_add, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.diagonal_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.flip, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_fill, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.masked_scatter, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.rsub, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.ne, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(
+        op=aten.squeeze, variant=Variant.Distributed
+    ): "does not have a sharding strategy registered",
+    Descriptor(op=aten.index_select, variant=Variant.Distributed): "Sharding propagation failed",
+    Descriptor(op=aten.real, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.imag, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.isfinite, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.transpose, variant=Variant.Distributed): "No scalar support",
+    Descriptor(op=aten.view_as_real, variant=Variant.Distributed): "No scalar support",
 }

 EXTRA_KWARGS = {
@@ -135,29 +186,65 @@ def get_overload_packet_from_name(name: str) -> OpOverloadPacket:
         "rtol": 2e-2,
         "atol": 2e-6,
     },
+    Descriptor(op=aten.asinh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-5,
+        "atol": 5e-5,
+    },
+    Descriptor(op=aten.tanh, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 1e-4,
+        "atol": 1e-5,
+    },
+    Descriptor(op=aten.pow, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-2,
+        "atol": 2e-6,
+    },
+    Descriptor(op=aten.tan, dtype=torch.complex64, variant=Variant.Distributed): {
+        "rtol": 2e-6,
+        "atol": 1e-2,
+    },
 }

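+# Single-rank process group over an in-memory HashStore plus a one-device CPU mesh,
+# so DTensor inputs can be exercised inside this test process.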
+STORE = dist.HashStore()
+dist.init_process_group(store=STORE, rank=0, world_size=1)
+DEVICE_MESH = dist.init_device_mesh("cpu", mesh_shape=(1,))
+
+
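+# Same as _as_complex_tensor, but wraps the converted tensor in a DTensor on the mesh above.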
+def _as_complex_dtensor(arg: torch.Tensor | Any) -> torch.Tensor | Any:
+    if not isinstance(arg, torch.Tensor):
+        return arg
+
+    return dist.tensor.DTensor.from_local(_as_complex_tensor(arg), device_mesh=DEVICE_MESH)
+
+
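+# Per-variant transform applied to each sample input before the op under test runs.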
+TRANSFORM_FUNCS = {Variant.Op: _as_complex_tensor, Variant.Distributed: _as_complex_dtensor}
+

 class TestComplexTensor(TestCase):
     _default_dtype_check_enabled = True

     @parametrize("compile", [False, True])
     @ops(implemented_op_db, dtypes=OpDTypes.supported, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_consistency(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)

     @parametrize("compile", [False, True])
     @ops(force_test_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
     def test_maybe_error(self, device, dtype, op: OpInfo, compile: bool):
-        self.check_consistency(device, dtype, op, compile)
+        self.check_consistency(device, dtype, op, compile, Variant.Op)

-    def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bool) -> None:
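+    # Distributed variant: rerun the implemented ops on DTensor-wrapped inputs, eager only.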
+    @ops(implemented_op_db, allowed_dtypes=list(COMPLEX_DTYPES))
+    def test_distributed(self, device, dtype, op: OpInfo):
+        self.check_consistency(device, dtype, op, False, Variant.Distributed)
+
+    def check_consistency(
+        self, device: torch.device, dtype, op: OpInfo, compile: bool, variant: Variant
+    ) -> None:
         test_info = Descriptor(
             op=get_overload_packet_from_name(op.name),
             device=device,
             dtype=dtype,
             compile=compile,
-            variant=Variant.Op,
+            variant=variant,
         )
         for xfail_info, reason in SKIPS.items():
             if xfail_info.matches(test_info):
@@ -174,12 +261,14 @@ def check_consistency(self, device: torch.device, dtype, op: OpInfo, compile: bo
         if compile:
             op = torch.compile(op, fullgraph=True)

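+        # Pick the input transform for this variant (complex tensor subclass vs. DTensor wrapper).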
+        transform_fn = TRANSFORM_FUNCS[variant]
+
         for sample_input in sample_inputs:

             def expected(sample_input=sample_input):
                 return op_eager(sample_input.input, *sample_input.args, **sample_input.kwargs)

-            subclass_sample = sample_input.transform(_as_complex_tensor)
+            subclass_sample = sample_input.transform(transform_fn)

             def actual(subclass_sample=subclass_sample):
                 return op(subclass_sample.input, *subclass_sample.args, **subclass_sample.kwargs)