diff --git a/src/openpi/shared/array_typing.py b/src/openpi/shared/array_typing.py index 3c9593072..db399ae5c 100644 --- a/src/openpi/shared/array_typing.py +++ b/src/openpi/shared/array_typing.py @@ -1,6 +1,7 @@ import contextlib import functools as ft import inspect +import sys from typing import TypeAlias, TypeVar, cast import beartype @@ -30,13 +31,30 @@ def _check_dataclass_annotations(self, typechecker): - if not any( - frame.frame.f_globals["__name__"] in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} - for frame in inspect.stack() - ): - return _original_check_dataclass_annotations(self, typechecker) - return None - + #if not any( + # frame.frame.f_globals["__name__"] in {"jax._src.tree_util", "flax.nnx.transforms.compilation"} + # for frame in inspect.stack() + #): + # return _original_check_dataclass_annotations(self, typechecker) + #return None + try: + # get caller of caller + frame = sys._getframe(2) + + # limit walkback to 20 frames + for _ in range(20): + # skip typechecking for JAX internal calls + if frame.f_globals.get("__name__") in {"jax._src.tree_util", "flax.nnx.transforms.compilation"}: + return None + + if frame.f_back is None: + break + + frame = frame.f_back + except (ValueError, AttributeError): + pass + + return _original_check_dataclass_annotations(self, typechecker) jaxtyping._decorator._check_dataclass_annotations = _check_dataclass_annotations # noqa: SLF001