Skip to content

Commit 16d06a4

Browse files
committed
Update README.md with some development information.
1 parent 0021343 commit 16d06a4

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

README.md

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,25 @@
1-
# pytorch-complex-tensor
2-
Complex Tensor Subclass
1+
# `complex-tensor`
2+
3+
Subclass of `torch.Tensor` for working with complex numbers.
4+
5+
# Development Setup
6+
7+
For now, the development setup uses the CPU version of PyTorch only.
8+
9+
1. [Install `uv`](https://docs.astral.sh/uv/getting-started/installation/)
10+
2. Run the tests with `uv run pytest -n auto` to get started.
11+
* This will also create a virtual environment in `.venv/`.
12+
13+
# Repository Structure
14+
15+
* The main `torch.Tensor` subclass is found in [`src/complex_tensor/complex_tensor.py`](https://github.com/openteams-ai/pytorch-complex-tensor/blob/main/src/complex_tensor/complex_tensor.py).
16+
* Operations are implemented in the [`src/complex_tensor/ops/`](https://github.com/openteams-ai/pytorch-complex-tensor/tree/main/src/complex_tensor/ops) directory.
17+
* [`_common.py`](https://github.com/openteams-ai/pytorch-complex-tensor/blob/main/src/complex_tensor/ops/_common.py) defines some basic utility functions.
18+
* [`aten.py`](https://github.com/openteams-ai/pytorch-complex-tensor/blob/main/src/complex_tensor/ops/aten.py) defines overloads for `torch.ops.aten`.
19+
* [`prims.py`](https://github.com/openteams-ai/pytorch-complex-tensor/blob/main/src/complex_tensor/ops/aten.py) does the same for `torch.ops.prims`.
20+
* Currently, this directory is empty.
21+
* Tests are located in [`src/complex_tensor/test`](https://github.com/openteams-ai/pytorch-complex-tensor/tree/main/src/complex_tensor/test).
22+
* Testing currently needs to be expanded; currently only tests which provide `OpInfo`s in `torch.testing._internal.common_methods_invocations.op_db` are tested.
23+
* Exceptions are noted in-tree with a `TODO`.
24+
25+
This repository is currently WIP, which means not all ops are implemented, but many common ones are.

src/complex_tensor/ops/prims.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import torch
2+
3+
from ..complex_tensor import ComplexTensor
4+
from ._common import (
5+
complex_to_real_dtype,
6+
register_force_test,
7+
split_complex_tensor,
8+
)
9+
10+
prims = torch.ops.prims
11+
12+
13+
@register_force_test(prims.convert_element_type)
14+
def convert_element_type_impl(x: ComplexTensor, dtype: torch.dtype) -> ComplexTensor:
15+
dtype = complex_to_real_dtype(dtype)
16+
u, v = split_complex_tensor(x)
17+
u_out = prims.convert_element_type(u, dtype)
18+
v_out = prims.convert_element_type(v, dtype)
19+
20+
return ComplexTensor(u_out, v_out)

0 commit comments

Comments
 (0)