diff --git a/pyglove/core/symbolic/contextual_object.py b/pyglove/core/symbolic/contextual_object.py index 892a1ad..a856a87 100644 --- a/pyglove/core/symbolic/contextual_object.py +++ b/pyglove/core/symbolic/contextual_object.py @@ -46,7 +46,7 @@ class A(pg.ContextualObject): """ import threading -from typing import Annotated, Any, ContextManager, Dict, Optional, Type +from typing import Annotated, Any, ClassVar, ContextManager, Dict, Optional, Type from pyglove.core import utils as pg_utils from pyglove.core.symbolic import base from pyglove.core.symbolic import inferred as pg_inferred @@ -109,13 +109,13 @@ def foo(a): """ # Override __repr__ format to use inferred values when available. - __repr_format_kwargs__ = dict( + __repr_format_kwargs__: ClassVar[Dict[str, Any]] = dict( compact=True, use_inferred=True, ) # Override __str__ format to use inferred values when available. - __str_format_kwargs__ = dict( + __str_format_kwargs__: ClassVar[Dict[str, Any]] = dict( compact=False, verbose=False, use_inferred=True, diff --git a/pyglove/core/symbolic/object.py b/pyglove/core/symbolic/object.py index 65bf0bf..d3e78eb 100644 --- a/pyglove/core/symbolic/object.py +++ b/pyglove/core/symbolic/object.py @@ -28,6 +28,14 @@ from pyglove.core.symbolic import flags +if sys.version_info >= (3, 11): + _dataclass_transform = typing.dataclass_transform +else: + def _dataclass_transform(): + return lambda cls: cls + + +@_dataclass_transform() class ObjectMeta(abc.ABCMeta): """Meta class for pg.Object.""" @@ -153,6 +161,8 @@ def _infer_fields_from_annotations(cls) -> List[pg_typing.Field]: if attr_name == '__kwargs__': # __kwargs__ is speical annotation for enabling keyword arguments. key = pg_typing.StrKey() + if typing.get_origin(attr_annotation) is typing.ClassVar: + attr_annotation = typing.get_args(attr_annotation)[0] elif not attr_name.isupper() and not attr_name.startswith('_'): key = pg_typing.ConstStrKey(attr_name) else: diff --git a/pyglove/core/utils/formatting.py b/pyglove/core/utils/formatting.py index b133b62..6aeff31 100644 --- a/pyglove/core/utils/formatting.py +++ b/pyglove/core/utils/formatting.py @@ -17,7 +17,7 @@ import enum import io import sys -from typing import Any, Callable, ContextManager, Dict, List, Optional, Sequence, Set, Tuple +from typing import Any, Callable, ClassVar, ContextManager, Dict, List, Optional, Sequence, Set, Tuple from pyglove.core.utils import thread_local @@ -55,10 +55,12 @@ class Formattable(metaclass=abc.ABCMeta): """ # Additional format keyword arguments for `__str__`. - __str_format_kwargs__ = dict(compact=False, verbose=True) + __str_format_kwargs__: ClassVar[Dict[str, Any]] = dict( + compact=False, verbose=True + ) # Additional format keyword arguments for `__repr__`. - __repr_format_kwargs__ = dict(compact=True) + __repr_format_kwargs__: ClassVar[Dict[str, Any]] = dict(compact=True) @abc.abstractmethod def format(self,