From fdbd53e752fcdde028ba68996402aa73f2c23cd6 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 18 Aug 2025 15:20:42 -0700
Subject: [PATCH] Refactor TorchAOBaseTensor for better BC support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
After this PR, tensors inheriting from TorchAOBaseTensor will have better BC
support: if a subclass later adds an optional tensor data attribute or an
optional non-tensor attribute, old checkpoints will still load without any
additional changes.

More details:
The BC story we are aiming for is that after we land a tensor, e.g.
Int4Tensor or Float8Tensor, future changes should only add optional tensor
data attributes and optional non-tensor attributes to the tensor (bigger
changes will require a version bump, which we still need to add). The
current TorchAOBaseTensor doesn't support this very well.

Also see https://github.com/pytorch/ao/pull/2840 for a real test that adds
both an optional tensor attribute and an optional non-tensor attribute to
Float8Tensor, and verifies that the BC test for Float8Tensor in
https://github.com/pytorch/ao/blob/main/test/integration/test_load_and_run_checkpoint.py
does not fail.

Docs for the current TorchAOBaseTensor:
https://github.com/pytorch/ao/blob/e6b38bb0e1477ae6aaca0a3d30de70598be43290/torchao/utils.py#L726-L731

`tensor_data_names` (List[str]): names of all required tensor data
attributes; order should match the `__init__` signature of the tensor
subclass
`optional_tensor_data_names` (List[str]): optional to define, but needed
when some Tensor data attributes are optional; when defined, a list of names
of Tensors that can be None, and the additional boilerplate functions are
still implemented for you
`tensor_attribute_names` (List[str]): names of non-Tensor attributes; order
should match the `__init__` signature, following all the `tensor_data_names`
and `optional_tensor_data_names` arguments

Problem: `optional_tensor_data_names` is not truly optional today, because
those arguments are followed by `tensor_attribute_names`, which mixes
required and optional attributes. So adding a new optional tensor data
attribute shifts the positional arguments in `__init__` and breaks BC. Here
are a few options:

```
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, act_scale=None,
                 block_size=None, shape=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

```
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    required_tensor_attribute_names = ["block_size", "shape"]
    optional_tensor_attribute_names = ["_demo_only_optional_attr"]

    def __init__(self, qdata, scale, zero_point, block_size, shape,
                 act_scale=None, _demo_only_optional_attr=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```

```
class Int4Tensor(TorchAOBaseTensor):
    tensor_data_names = ["qdata", "scale", "zero_point"]
    tensor_attribute_names = ["block_size", "shape", "_demo_only_optional_attr"]
    optional_tensor_data_names = ["act_scale"]

    def __init__(self, qdata, scale, zero_point, block_size, shape,
                 _demo_only_optional_attr=None, act_scale=None):
        ...

    # for BC
    def __setstate__(self, state):
        torch._utils._set_obj_state(self, state)
        if "act_scale" not in self.__dict__:
            self.act_scale = None
```
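All three options rely on the same `__setstate__` backfill. A minimal,
self-contained sketch of why that keeps old checkpoints loadable
(`DemoTensor` is a hypothetical stand-in, not the actual torchao subclass;
the real implementation goes through `torch._utils._set_obj_state`):

```
import torch

# DemoTensor is a hypothetical stand-in for "v2" of a tensor subclass that
# added the optional attribute "act_scale" after "v1" checkpoints were saved.
class DemoTensor:
    optional_tensor_data_names = ["act_scale"]

    def __init__(self, qdata, act_scale=None):
        self.qdata = qdata
        self.act_scale = act_scale

    def __setstate__(self, state):
        # a plain dict update stands in for torch._utils._set_obj_state,
        # which additionally handles the (dict, slots) state tuple form
        self.__dict__.update(state)
        # backfill optional attributes that old checkpoints do not carry
        for name in self.optional_tensor_data_names:
            if name not in self.__dict__:
                setattr(self, name, None)

# simulate unpickling a v1 checkpoint: its state dict has no "act_scale" key
obj = DemoTensor.__new__(DemoTensor)
obj.__setstate__({"qdata": torch.zeros(2, 2)})
assert obj.act_scale is None  # backfilled, so later attribute access works
```

The diff below follows the second option's split, keeping
`tensor_attribute_names` for required attributes and adding
`optional_tensor_attribute_names`, with the `__setstate__` backfill
installed automatically in `__init_subclass__`.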
Test Plan:
python test/integration/test_load_and_run_checkpoint.py

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: https://github.com/pytorch/ao/pull/2793, branch: jerryzh168/stack/29
---
 test/test_utils.py                                 |  65 +++++--
 .../workflows/float8/float8_tensor.py              |  21 ++-
 .../workflows/int4/int4_preshuffled_tensor.py      |  26 +--
 torchao/utils.py                                   | 175 ++++++++++++++----
 4 files changed, 215 insertions(+), 72 deletions(-)

diff --git a/test/test_utils.py b/test/test_utils.py
index c5bbf45a96..b5d26432ff 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -186,60 +186,103 @@ class MyTensor(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
             tensor_attribute_names = ["attr", "device"]
 
-            def __new__(cls, qdata, attr, device=None):
+            def __new__(cls, qdata, attr, device):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, attr, device=None):
+            def __init__(self, qdata, attr, device):
                 self.qdata = qdata
                 self.attr = attr
 
         l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
         lp_tensor = l.weight
 
         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
     @skip_if_no_cuda()
     def test_default_impls_with_optional_data(self):
         class MyTensorWithOptionalData(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
-            optional_tensor_data_names = ["zero_point"]
             tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
 
-            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
+            def __new__(cls, qdata, attr, device, zero_point=None):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
+            def __init__(self, qdata, attr, device, zero_point=None):
                 self.qdata = qdata
+                self.attr = attr
                 self.zero_point = zero_point
+
+        # test both the optional Tensor is None
+        # and not None
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+        l = torch.nn.Linear(2, 3)
+        lp_tensor = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        l = torch.nn.Linear(2, 3)
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, torch.zeros_like(l.weight)
+        )
+        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
+
+    @skip_if_no_cuda()
+    def test_default_impls_with_optional_attr(self):
+        class MyTensorWithOptionalData(TorchAOBaseTensor):
+            tensor_data_names = ["qdata"]
+            tensor_attribute_names = ["attr", "device"]
+            optional_tensor_data_names = ["zero_point"]
+            optional_tensor_attribute_names = ["optional_attr"]
+
+            def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
+                shape = qdata.shape
+                if device is None:
+                    device = qdata.device
+                kwargs = {"device": device}
+                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+            def __init__(
+                self, qdata, attr, device, zero_point=None, optional_attr=None
+            ):
+                self.qdata = qdata
                 self.attr = attr
+                self.zero_point = zero_point
+                self.optional_attr = optional_attr
 
         # test both the optional Tensor is None
         # and not None
         l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
         l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
+        lp_tensor_for_copy = MyTensorWithOptionalData(
+            l.weight, "attr", None, zero_point=None
+        )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
 
         l = torch.nn.Linear(2, 3)
         lp_tensor = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
         )
         l = torch.nn.Linear(2, 3)
         lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, torch.zeros_like(l.weight), "attr"
+            l.weight, "attr", None, zero_point=None, optional_attr="value"
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 7726b2094c..ff295ae200 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -94,7 +94,8 @@ class Float8Tensor(TorchAOBaseTensor):
     """
 
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = [
+    tensor_attribute_names = []
+    optional_tensor_attribute_names = [
         "block_size",
         "mm_config",
         "hp_value_lb",
@@ -106,15 +107,15 @@ class Float8Tensor(TorchAOBaseTensor):
 
     def __new__(
         cls,
-        qdata,
-        scale,
-        block_size,
-        mm_config,
-        hp_value_lb,
-        hp_value_ub,
-        act_quant_kwargs,
-        kernel_preference,
-        dtype,
+        qdata: torch.Tensor,
+        scale: torch.Tensor,
+        block_size: Optional[List[int]] = None,
+        mm_config: Optional[Float8MMConfig] = None,
+        hp_value_lb: Optional[float] = None,
+        hp_value_ub: Optional[float] = None,
+        act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
+        kernel_preference: KernelPreference = KernelPreference.AUTO,
+        dtype: Optional[torch.dtype] = None,
     ):
         shape = qdata.shape
         kwargs = {}
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
index 50cf261642..7310d975de 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -75,17 +75,17 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     """
 
     tensor_data_names = ["qdata", "group_scale"]
-    optional_tensor_data_names = ["group_zero", "row_scale"]
     tensor_attribute_names = ["block_size", "shape"]
+    optional_tensor_data_names = ["group_zero", "row_scale"]
"row_scale"] def __new__( cls, - qdata, - group_scale, - group_zero, - row_scale, - block_size, - shape, + qdata: torch.Tensor, + group_scale: torch.Tensor, + block_size: List[int], + shape: List[int], + group_zero: Optional[torch.Tensor] = None, + row_scale: Optional[torch.Tensor] = None, ): kwargs = {} kwargs["device"] = qdata.device @@ -97,19 +97,19 @@ def __init__( self, qdata: torch.Tensor, group_scale: torch.Tensor, - group_zero: Optional[torch.Tensor], - row_scale: Optional[torch.Tensor], block_size: List[int], shape: List[int], + group_zero: Optional[torch.Tensor] = None, + row_scale: Optional[torch.Tensor] = None, ): # one and only one of group_scale and group_zero should be None assert group_zero is None or row_scale is None assert not (group_zero is not None and row_scale is not None) self.qdata = qdata - self.group_scale = group_scale - self.group_zero = group_zero self.row_scale = row_scale self.block_size = block_size + self.group_scale = group_scale + self.group_zero = group_zero def _quantization_type(self): return f"shape={self.shape}, block_size={self.block_size}, device={self.device}" @@ -178,10 +178,10 @@ def from_hp( return Int4PreshuffledTensor( qdata=wq, group_scale=group_scale, - group_zero=group_zero, - row_scale=row_scale, block_size=block_size, shape=original_shape, + group_zero=group_zero, + row_scale=row_scale, ) diff --git a/torchao/utils.py b/torchao/utils.py index 9d7e73b541..68d17ededf 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -517,12 +517,21 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: getattr(self, a_name) == getattr(src, a_name) for a_name in self.tensor_attribute_names ) + + _optional_attr_match = True + if hasattr(self, "optional_tensor_attribute_names"): + _optional_attr_match = all( + getattr(self, a_name) == getattr(src, a_name) + for a_name in self.optional_tensor_attribute_names + ) + return ( type(self) == type(src) and self.shape == src.shape and _tensor_shape_match and _optional_tensor_shape_match and _attr_match + and _optional_attr_match ) @implements(aten.copy_.default) @@ -549,22 +558,32 @@ def _(func, types, args, kwargs): tensors = [ getattr(self, name).to(device) for name in self.tensor_data_names ] + optional_tensors = [] if hasattr(self, "optional_tensor_data_names"): for tensor_data_name in self.optional_tensor_data_names: maybe_tensor = getattr(self, tensor_data_name) if maybe_tensor is not None: - tensors.append(maybe_tensor.to(device)) + optional_tensors.append(maybe_tensor.to(device)) else: - tensors.append(None) + optional_tensors.append(None) # change device tensor_attributes = [ getattr(self, attr_name) if attr_name != "device" else device for attr_name in self.tensor_attribute_names ] + optional_tensor_attributes = [] + if hasattr(self, "optional_tensor_attribute_names"): + optional_tensor_attributes = [ + getattr(self, attr_name) if attr_name != "device" else device + for attr_name in self.optional_tensor_attribute_names + ] + t = self.__class__( *tensors, *tensor_attributes, + *optional_tensors, + *optional_tensor_attributes, ) return return_and_correct_aliasing(func, args, kwargs, t) @@ -573,6 +592,26 @@ def _(func, types, args, kwargs): ) +def _torchao_base_tensor__setstate__(self, state): + assert hasattr(self, "tensor_data_names") and hasattr( + self, "tensor_attribute_names" + ) + torch._utils._set_obj_state(self, state) + for optional_tensor_data_name in getattr(self, "optional_tensor_data_names", []): + if optional_tensor_data_name not in self.__dict__ and not hasattr( 
+            self, optional_tensor_data_name
+        ):
+            setattr(self, optional_tensor_data_name, None)
+
+    for optional_tensor_attribute_name in getattr(
+        self, "optional_tensor_attribute_names", []
+    ):
+        if optional_tensor_attribute_name not in self.__dict__ and not hasattr(
+            self, optional_tensor_attribute_name
+        ):
+            setattr(self, optional_tensor_attribute_name, None)
+
+
 def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
     """Use this util function for a common `__torch_function__` implementation
     that dispatches to ops/functions registered with `_implements`
@@ -725,10 +764,13 @@ class PlainAQTTensorImpl(...):
     class variables to define to simplify implmentation of tensor subclasses:
       `tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match
-       the `__init__` list of tensor subclass
-      `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor attributes, when defined, this will be a list of names of Tensors that can be optional
+       the `__init__` list of tensor subclass
       `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes,
-        order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments and `optional_tensor_data_names`
+        order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments
+      `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions be implemented for you, but this will be needed if there are some optional Tensor data attributes; when defined, this will be a list of names of Tensors that can be optional
+      `optional_tensor_attribute_names` (List[str]): it's optional to define this field to have the additional boilerplate functions be implemented for you, but this will be needed if there are some optional non-Tensor attributes; when defined, this will be a list of names of attributes that can be optional
+    Note: Argument order in __init__ and __new__ should match exactly with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present)
+
 
     If `tensor_data_names` and `tensor_attribute_names` are defined, there are some additional
     functions that will be added, this includes:
@@ -739,23 +781,31 @@ class variables to define to simplify implmentation of tensor subclasses:
         recreate a new subclassed Tensor with the transformed tensor data
       `__repr__`: the string representation of the subclassed tensor instance
       `_same_metadata`: returns whether the metadata is the same between two instances of cls
+      `__setstate__`: when loading a serialized tensor subclass checkpoint, sets any new
+       optional tensor data or tensor attribute that is not saved in the old checkpoint to None,
+       to maintain BC of old checkpoints when we add new optional tensor data or attributes to
+       the tensor subclass
       torch ops: torch.Tensor.contiguous
       aten ops: aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default,
       aten.copy_.default, aten._to_copy.default (enables t.to)
 
     Example:
     class MyTensor(torch.Tensor):
         tensor_data_names = ["a", "b"]
-        optional_tensor_data_names = ["c", "d"]
-        tensor_attribute_names = ["e", "f"]
+        tensor_attribute_names = ["c", "d"]
+        optional_tensor_data_names = ["e", "f"]
+        optional_tensor_attribute_names = ["g", "h"]
+
         def __new__(
             cls,
             a: Tensor,
             b: Tensor,
-            c: Optional[Tensor],
-            d: Optional[Tensor],
-            e: int,
-            f: str
+            c: int,
+            d: str,
+            e: Optional[Tensor] = None,
+            f: Optional[Tensor] = None,
+            g: Optional[int] = None,
+            h: Optional[int] = None,
         ):
             pass
 
@@ -763,10 +813,12 @@ def __init__(
             self,
             a: Tensor,
             b: Tensor,
-            c: Optional[Tensor],
-            d: Optional[Tensor],
-            e: int,
-            f: str
+            c: int,
+            d: str,
+            e: Optional[Tensor] = None,
+            f: Optional[Tensor] = None,
+            g: Optional[int] = None,
+            h: Optional[int] = None,
         ):
             pass
 
@@ -780,9 +832,11 @@ def __init_subclass__(cls, **kwargs):
         if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
             cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}
 
-        # define the common ops if the tensor_data_names and tensor_attribute_names are defined
+        # define the common ops and __setstate__ for BC
+        # if the tensor_data_names and tensor_attribute_names are defined
         if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
             cls._implements_common_tensor_ops()
+            cls.__setstate__ = _torchao_base_tensor__setstate__
 
         # inherit the torch function and dispatch implementations from direct parent classes
         # e.g. for `class C(B, A)`, C.__bases__ == (B, A)
@@ -811,47 +865,82 @@ def __tensor_flatten__(self):
                 if maybe_tensor is not None:
                     tensor_data_names.append(tensor_data_name)
 
+            attrs = [getattr(self, attr) for attr in self.tensor_attribute_names]
+            if hasattr(self, "optional_tensor_attribute_names"):
+                attrs += [
+                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
+                ]
+
             # TODO(future PR): also return names of tensor attributes for easier
             # debugging
-            return tensor_data_names, [
-                getattr(self, attr) for attr in self.tensor_attribute_names
-            ]
+            return tensor_data_names, attrs
 
         raise NotImplementedError(
-            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
+            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
         )
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
-        tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
-        if hasattr(cls, "optional_tensor_data_names"):
-            for tensor_data_name in cls.optional_tensor_data_names:
-                if tensor_data_name in tensor_data_dict:
-                    tensors.append(tensor_data_dict[tensor_data_name])
-                else:
-                    tensors.append(None)
-        return cls(*tensors, *tensor_attributes)
+        if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
+            required_tensors = [
+                tensor_data_dict[name] for name in cls.tensor_data_names
+            ]
+            optional_tensors = []
+            if hasattr(cls, "optional_tensor_data_names"):
+                for tensor_data_name in cls.optional_tensor_data_names:
+                    if tensor_data_name in tensor_data_dict:
+                        optional_tensors.append(tensor_data_dict[tensor_data_name])
+                    else:
+                        optional_tensors.append(None)
+
+            required_attributes = tensor_attributes[: len(cls.tensor_attribute_names)]
+            optional_attributes = []
+            if hasattr(cls, "optional_tensor_attribute_names"):
+                optional_attributes = tensor_attributes[
+                    len(cls.tensor_attribute_names) :
+                ]
+
+            return cls(
+                *required_tensors,
+                *required_attributes,
+                *optional_tensors,
+                *optional_attributes,
+            )
+        raise NotImplementedError(
+            "Subclasses should implement __tensor_unflatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
+        )
 
     def _apply_fn_to_data(self, fn):
         if hasattr(self, "tensor_data_names") and hasattr(
self, "tensor_attribute_names" ): - tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names] + required_tensors = [ + fn(getattr(self, attr)) for attr in self.tensor_data_names + ] + optional_tensors = [] if hasattr(self, "optional_tensor_data_names"): for tensor_data_name in self.optional_tensor_data_names: maybe_tensor = getattr(self, tensor_data_name) if maybe_tensor is not None: - tensors.append(fn(maybe_tensor)) + optional_tensors.append(fn(maybe_tensor)) else: - tensors.append(None) + optional_tensors.append(None) - tensor_attributes = [ + required_attributes = [ getattr(self, attr) for attr in self.tensor_attribute_names ] + optional_attributes = [] + if hasattr(self, "optional_tensor_attribute_names"): + optional_attributes = [ + getattr(self, attr) for attr in self.optional_tensor_attribute_names + ] + return self.__class__( - *tensors, - *tensor_attributes, + *required_tensors, + *required_attributes, + *optional_tensors, + *optional_attributes, ) raise NotImplementedError( @@ -863,19 +952,29 @@ def __repr__(self): self, "tensor_attribute_names" ): repr_str = "" + # required tensor data repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}" for tensor_data_name in self.tensor_data_names[1:]: repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}" + + # required attributes + for tensor_attribute_name in self.tensor_attribute_names: + repr_str += ( + f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" + ) + + # optional tensor data if hasattr(self, "optional_tensor_data_names"): for tensor_data_name in self.optional_tensor_data_names: repr_str += ( f", {tensor_data_name}={getattr(self, tensor_data_name)}" ) - for tensor_attribute_name in self.tensor_attribute_names: - repr_str += ( - f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" - ) + # optional tensor attributes + if hasattr(self, "optional_tensor_attribute_names"): + for tensor_attribute_name in self.optional_tensor_attribute_names: + repr_str += f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" + return f"{self.__class__.__name__}({repr_str})" raise NotImplementedError(