Skip to content

Commit a4bc7a8

Browse files
authored
Handle classes with __new__ method (#10)
* handle classes with __new__ method * fix release action
1 parent 4190cf2 commit a4bc7a8

File tree

3 files changed

+28
-13
lines changed

3 files changed

+28
-13
lines changed

.github/workflows/publish-package.yml

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,19 @@ jobs:
4040
run: |
4141
poetry install --without dev
4242
43-
- name: Build Docs 🔨
44-
run: |
45-
cp README.md docs/index.md
46-
poetry run mkdocs build
43+
# ----------------------------------------
44+
# No docs for now
45+
# ----------------------------------------
46+
# - name: Build Docs 🔨
47+
# run: |
48+
# cp README.md docs/index.md
49+
# poetry run mkdocs build
4750

48-
- name: Deploy Page 🚀
49-
uses: JamesIves/[email protected]
50-
with:
51-
branch: gh-pages
52-
folder: site
51+
# - name: Deploy Page 🚀
52+
# uses: JamesIves/[email protected]
53+
# with:
54+
# branch: gh-pages
55+
# folder: site
5356

5457
- name: Publish to PyPI
5558
run: |

simple_pytree/pytree.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ def static_field(
6666

6767

6868
class PytreeMeta(ABCMeta):
69-
def __call__(self: tp.Type[P], *args: tp.Any, **kwds: tp.Any) -> P:
70-
obj: P = self.__new__(self)
69+
def __call__(self: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
70+
obj: P = self.__new__(self, *args, **kwargs)
7171
obj.__dict__["_pytree__initializing"] = True
7272
try:
73-
obj.__init__(*args, **kwds)
73+
obj.__init__(*args, **kwargs)
7474
finally:
7575
del obj.__dict__["_pytree__initializing"]
7676
return obj
@@ -172,7 +172,7 @@ def _pytree__unflatten(
172172
) -> P:
173173
node_names, static_fields = metadata
174174
node_fields = dict(zip(node_names, node_values))
175-
pytree = cls.__new__(cls)
175+
pytree = object.__new__(cls)
176176
pytree.__dict__.update(node_fields, **dict(static_fields))
177177
return pytree
178178

tests/test_pytree.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,18 @@ class B(A):
184184
leaves = jax.tree_util.tree_leaves(pytree)
185185
assert leaves == [1, 3]
186186

187+
def test_pytree_with_new(self):
188+
class A(Pytree):
189+
def __init__(self, a):
190+
self.a = a
191+
192+
def __new__(cls, a):
193+
return super().__new__(cls)
194+
195+
pytree = A(a=1)
196+
197+
pytree = jax.tree_map(lambda x: x * 2, pytree)
198+
187199

188200
class TestMutablePytree:
189201
def test_pytree(self):

0 commit comments

Comments
 (0)