|
| 1 | +# Wrapping `ComplexTensor` in `DTensor` |
| 2 | + |
| 3 | +## Mechanism |
| 4 | + |
| 5 | +The wrapping is done implicitly, with no special code in either `ComplexTensor` or `DTensor`. |
| 6 | + |
| 7 | +## Tests |
| 8 | + |
| 9 | +The `DTensor` composition tests are placed in |
| 10 | +[`src/complex_tensor/test/test_ops.py::TestComplexTensor.test_distributed`](https://github.com/openteams-ai/pytorch-complex-tensor/blob/main/src/complex_tensor/test/test_ops.py). |
| 11 | + |
| 12 | +We use the `OpInfo`s to perform the tests. We first wrap every `torch.Tensor` with a complex |
| 13 | +`dtype` inside a `ComplexTensor`. We then proceed to wrap all tensors (regular and complex) into |
| 14 | +a `torch.distributed.tensor.DTensor`. |
| 15 | + |
| 16 | +## Outstanding Issues |
| 17 | + |
| 18 | +### Missing sharding strategies |
| 19 | + |
| 20 | +Some operations are missing sharding strategies in the `DTensor` implementation, making it |
| 21 | +impossible to compose certain ops, or other ops that use them. Examples are `aten.ne`, `aten.all` |
| 22 | +and `aten.cumprod`. |
| 23 | + |
| 24 | +### Missing scalar support |
| 25 | + |
| 26 | +Most of the sharding strategies in the `DTensor` implementation assume that input tensors all |
| 27 | +have `ndim >= 1`. This makes it impossible to support scalar inputs that are generated by the |
| 28 | +`OpInfo` samples in any way. |
| 29 | + |
| 30 | +### Can Only Be Composed Inside |
| 31 | + |
| 32 | +All `ComplexTensor`s must be placed inside a `DTensor` for tests to work. The reason for this |
| 33 | +is that `aten.complex` (an op to create an interleaved `Tensor` from its real and imaginary |
| 34 | +parts) is one of the ops that doesn't have a registered sharding strategy. `aten.allclose` |
| 35 | +also doesn't work due to `aten.all` not working. |
| 36 | + |
| 37 | +Since all of these are used extensively during testing, it makes it impossible to test |
| 38 | +a `ComplexTensor` with its real and imaginary parts being a `DTensor`. We can, however, |
| 39 | +test a `DTensor` that wraps a `ComplexTensor`, as these ops are not necessarily used on |
| 40 | +a `DTensor` during those tests; it is usually gathered into a local tensor, therefore |
| 41 | +making these ops possible. |
0 commit comments