A dead simple Python package for creating custom JAX pytree objects.
- Strives to be minimal: the implementation is just ~100 lines of code
- Has no dependencies other than JAX
- It's compatible with both dataclasses and regular classes
- It has no intention of supporting Neural Network use cases (e.g. partitioning)
pip install simple-pytree

import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2

You can mark fields as static by assigning static_field() to a class attribute with the same name as the instance attribute:
import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2

Static fields are not included in the pytree leaves; they are passed as pytree metadata instead.
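To make the leaves-versus-metadata split concrete, here is a minimal sketch of how a Pytree-style base class could register its subclasses with JAX. It is not simple_pytree's actual source: MiniPytree, static_field_sketch and _Static are hypothetical names, and the sketch only handles attributes assigned on the instance. It relies on jax.tree_util.register_pytree_node, whose flatten function returns (children, aux_data) and whose unflatten function receives (aux_data, children).

```python
import jax


class _Static:
    """Hypothetical sentinel marking a class attribute as a static field."""


def static_field_sketch():
    # Hypothetical stand-in for simple_pytree.static_field (not the library's code).
    return _Static()


class MiniPytree:
    """Hypothetical base class: subclasses are registered as JAX pytrees."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Collect names declared with static_field_sketch() anywhere in the MRO.
        cls._static_names = frozenset(
            name
            for klass in cls.__mro__
            for name, value in vars(klass).items()
            if isinstance(value, _Static)
        )
        jax.tree_util.register_pytree_node(cls, cls._flatten, cls._unflatten)

    def _flatten(self):
        names = sorted(vars(self))  # instance attributes, in a stable order
        dynamic = tuple(n for n in names if n not in self._static_names)
        static = tuple((n, getattr(self, n)) for n in names if n in self._static_names)
        children = [getattr(self, n) for n in dynamic]
        # Static values go into aux_data, so JAX treats them as metadata, not leaves.
        return children, (dynamic, static)

    @classmethod
    def _unflatten(cls, aux_data, children):
        dynamic, static = aux_data
        obj = object.__new__(cls)  # rebuild without calling __init__
        for name, value in zip(dynamic, children):
            setattr(obj, name, value)
        for name, value in static:
            setattr(obj, name, value)
        return obj


class Foo(MiniPytree):
    y = static_field_sketch()

    def __init__(self, x, y):
        self.x = x
        self.y = y


foo = Foo(1, 2)
leaves, treedef = jax.tree_util.tree_flatten(foo)  # y lives in treedef, not leaves
foo = jax.tree_util.tree_map(lambda v: -v, foo)    # only x is negated
```

Because static values travel in aux_data, they become part of the tree structure: two Foo objects with different y values would have different treedefs, so under jax.jit y would act as a compile-time constant and a change to it would trigger retracing.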
simple_pytree provides a dataclass decorator you can use with classes
that contain static fields:
import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)

foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2

simple_pytree.dataclass is just a wrapper around dataclasses.dataclass, but
when it is used, static analysis tools and IDEs will understand that static_field is a
field specifier, just like dataclasses.field.
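The standard way to tell type checkers that a custom decorator behaves like dataclasses.dataclass, and that a helper function is a field specifier, is typing.dataclass_transform (PEP 681). Whether simple_pytree uses exactly this mechanism is an assumption here; the sketch below shows the pattern in isolation, with hypothetical names dataclass_sketch and static_field_sketch and a hypothetical "static" metadata key.

```python
import dataclasses

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:  # older Pythons: the typing_extensions backport
    from typing_extensions import dataclass_transform


def static_field_sketch(default=dataclasses.MISSING, *, metadata=None, **kwargs):
    """Hypothetical: a dataclasses.field that tags the field as static."""
    metadata = dict(metadata or {})
    metadata["static"] = True  # hypothetical metadata key
    return dataclasses.field(default=default, metadata=metadata, **kwargs)


@dataclass_transform(field_specifiers=(dataclasses.field, static_field_sketch))
def dataclass_sketch(cls):
    # At runtime this is plain dataclasses.dataclass; dataclass_transform only
    # informs static type checkers that static_field_sketch is a field specifier.
    return dataclasses.dataclass(cls)


@dataclass_sketch
class Foo:
    x: int
    y: int = static_field_sketch(default=2)


foo = Foo(1)
static_names = [f.name for f in dataclasses.fields(Foo) if f.metadata.get("static")]
```

The decorator adds nothing at runtime beyond dataclasses.dataclass; the benefit is purely for tooling, which then accepts Foo(1) as a valid call because y has a default.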
Pytree objects are immutable by default after __init__:
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # AttributeError

If you want to make them mutable, you can use the mutable argument in the class definition:
from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # OK

If you want to make a copy of a Pytree object with some fields modified, you can use the .replace() method:
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2

replace works for both mutable and immutable Pytree objects. If the class
is a dataclass, replace internally uses dataclasses.replace.
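The immutability rules, the mutable class argument, and replace() can be approximated in plain Python without JAX. This is a hypothetical sketch, not simple_pytree's code: a metaclass (_FreezeAfterInit) flags each instance while __init__ runs, __setattr__ refuses writes afterwards unless the class was declared with mutable=True (read by __init_subclass__), and replace() defers to dataclasses.replace for dataclasses and otherwise shallow-copies the object and updates the copy's __dict__ directly.

```python
import copy
import dataclasses


class _FreezeAfterInit(type):
    """Hypothetical metaclass: marks the instance as initializing during __init__."""

    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        object.__setattr__(obj, "_initializing", True)
        try:
            obj.__init__(*args, **kwargs)
        finally:
            object.__delattr__(obj, "_initializing")  # frozen from here on
        return obj


class Record(metaclass=_FreezeAfterInit):
    """Hypothetical stand-in for Pytree's immutability and replace() behaviour."""

    _mutable = False

    def __init_subclass__(cls, mutable=False, **kwargs):
        # Receives the keyword from `class Foo(Record, mutable=True)`.
        super().__init_subclass__(**kwargs)
        cls._mutable = mutable

    def __setattr__(self, name, value):
        if not self._mutable and not self.__dict__.get("_initializing", False):
            raise AttributeError(f"{type(self).__name__} is immutable")
        object.__setattr__(self, name, value)

    def replace(self, **updates):
        if dataclasses.is_dataclass(self):
            return dataclasses.replace(self, **updates)
        new = copy.copy(self)         # shallow copy; the original is untouched
        new.__dict__.update(updates)  # bypasses __setattr__, so it works when frozen
        return new


class Foo(Record):
    def __init__(self, x, y):
        self.x = x
        self.y = y


class MutableFoo(Record, mutable=True):
    def __init__(self, x):
        self.x = x


foo = Foo(1, 2)
try:
    foo.x = 3
    blocked = False
except AttributeError:
    blocked = True

bar = MutableFoo(1)
bar.x = 3  # allowed because the class was declared with mutable=True

foo2 = foo.replace(x=10)  # new object; foo keeps x == 1
```

Writing through __dict__ in replace() is what lets it work on immutable instances: only the fresh copy is modified, before any other code holds a reference to it, so the original object is never mutated.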