diff --git a/test/test_utils.py b/test/test_utils.py
index b5d26432ff..c5bbf45a96 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -186,103 +186,60 @@ class MyTensor(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
             tensor_attribute_names = ["attr", "device"]

-            def __new__(cls, qdata, attr, device):
+            def __new__(cls, qdata, attr, device=None):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-            def __init__(self, qdata, attr, device):
+            def __init__(self, qdata, attr, device=None):
                 self.qdata = qdata
                 self.attr = attr

         l = torch.nn.Linear(2, 3)
-        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
+        l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
         lp_tensor = l.weight

         another_tensor = torch.nn.Linear(2, 3).weight
         # attribute has to be the same
-        lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
+        lp_tensor_for_copy = MyTensor(another_tensor, "attr")
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

     @skip_if_no_cuda()
     def test_default_impls_with_optional_data(self):
         class MyTensorWithOptionalData(TorchAOBaseTensor):
             tensor_data_names = ["qdata"]
-            tensor_attribute_names = ["attr", "device"]
             optional_tensor_data_names = ["zero_point"]
-
-            def __new__(cls, qdata, attr, device, zero_point=None):
-                shape = qdata.shape
-                if device is None:
-                    device = qdata.device
-                kwargs = {"device": device}
-                return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-            def __init__(self, qdata, attr, device, zero_point=None):
-                self.qdata = qdata
-                self.attr = attr
-                self.zero_point = zero_point
-
-        # test both the optional Tensor is None
-        # and not None
-        l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
-        l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
-        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
-
-        l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(
-            l.weight, "attr", None, torch.zeros_like(l.weight)
-        )
-        l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, "attr", None, torch.zeros_like(l.weight)
-        )
-        self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
-
-    @skip_if_no_cuda()
-    def test_default_impls_with_optional_attr(self):
-        class MyTensorWithOptionalData(TorchAOBaseTensor):
-            tensor_data_names = ["qdata"]
             tensor_attribute_names = ["attr", "device"]
-            optional_tensor_data_names = ["zero_point"]
-            optional_tensor_attribute_names = ["optional_attr"]

-            def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
+            def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
                 shape = qdata.shape
                 if device is None:
                     device = qdata.device
                 kwargs = {"device": device}
                 return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-            def __init__(
-                self, qdata, attr, device, zero_point=None, optional_attr=None
-            ):
+            def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
                 self.qdata = qdata
-                self.attr = attr
                 self.zero_point = zero_point
-                self.optional_attr = optional_attr
+                self.attr = attr

         # test both the optional Tensor is None
         # and not None
         l = torch.nn.Linear(2, 3)
-        lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
+        lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
         l = torch.nn.Linear(2, 3)
-        lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, "attr", None, zero_point=None
-        )
+        lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

         l = torch.nn.Linear(2, 3)
         lp_tensor = MyTensorWithOptionalData(
-            l.weight, "attr", None, zero_point=None, optional_attr="value"
+            l.weight, torch.zeros_like(l.weight), "attr"
         )
         l = torch.nn.Linear(2, 3)
         lp_tensor_for_copy = MyTensorWithOptionalData(
-            l.weight, "attr", None, zero_point=None, optional_attr="value"
+            l.weight, torch.zeros_like(l.weight), "attr"
         )
         self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index baf6d493df..3cc3961ef4 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -94,8 +94,7 @@ class Float8Tensor(TorchAOBaseTensor):
     """

     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = []
-    optional_tensor_attribute_names = [
+    tensor_attribute_names = [
         "block_size",
         "mm_config",
         "hp_value_lb",
@@ -107,15 +106,15 @@ class Float8Tensor(TorchAOBaseTensor):

     def __new__(
         cls,
-        qdata: torch.Tensor,
-        scale: torch.Tensor,
-        block_size: Optional[List[int]] = None,
-        mm_config: Optional[Float8MMConfig] = None,
-        hp_value_lb: Optional[float] = None,
-        hp_value_ub: Optional[float] = None,
-        act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
-        kernel_preference: KernelPreference = KernelPreference.AUTO,
-        dtype: Optional[torch.dtype] = None,
+        qdata,
+        scale,
+        block_size,
+        mm_config,
+        hp_value_lb,
+        hp_value_ub,
+        act_quant_kwargs,
+        kernel_preference,
+        dtype,
     ):
         shape = qdata.shape
         kwargs = {}
diff --git a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
index 7310d975de..50cf261642 100644
--- a/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py
@@ -75,17 +75,17 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
     """

     tensor_data_names = ["qdata", "group_scale"]
-    tensor_attribute_names = ["block_size", "shape"]
     optional_tensor_data_names = ["group_zero", "row_scale"]
+    tensor_attribute_names = ["block_size", "shape"]

     def __new__(
         cls,
-        qdata: torch.Tensor,
-        group_scale: torch.Tensor,
-        block_size: List[int],
-        shape: List[int],
-        group_zero: Optional[torch.Tensor] = None,
-        row_scale: Optional[torch.Tensor] = None,
+        qdata,
+        group_scale,
+        group_zero,
+        row_scale,
+        block_size,
+        shape,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
@@ -97,19 +97,19 @@ def __init__(
         self,
         qdata: torch.Tensor,
         group_scale: torch.Tensor,
+        group_zero: Optional[torch.Tensor],
+        row_scale: Optional[torch.Tensor],
         block_size: List[int],
         shape: List[int],
-        group_zero: Optional[torch.Tensor] = None,
-        row_scale: Optional[torch.Tensor] = None,
     ):
         # one and only one of group_scale and group_zero should be None
         assert group_zero is None or row_scale is None
         assert not (group_zero is not None and row_scale is not None)
         self.qdata = qdata
-        self.row_scale = row_scale
-        self.block_size = block_size
         self.group_scale = group_scale
         self.group_zero = group_zero
+        self.row_scale = row_scale
+        self.block_size = block_size

     def _quantization_type(self):
         return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
device={self.device}" @@ -178,10 +178,10 @@ def from_hp( return Int4PreshuffledTensor( qdata=wq, group_scale=group_scale, - block_size=block_size, - shape=original_shape, group_zero=group_zero, row_scale=row_scale, + block_size=block_size, + shape=original_shape, ) diff --git a/torchao/utils.py b/torchao/utils.py index 68d17ededf..9d7e73b541 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -517,21 +517,12 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool: getattr(self, a_name) == getattr(src, a_name) for a_name in self.tensor_attribute_names ) - - _optional_attr_match = True - if hasattr(self, "optional_tensor_attribute_names"): - _optional_attr_match = all( - getattr(self, a_name) == getattr(src, a_name) - for a_name in self.optional_tensor_attribute_names - ) - return ( type(self) == type(src) and self.shape == src.shape and _tensor_shape_match and _optional_tensor_shape_match and _attr_match - and _optional_attr_match ) @implements(aten.copy_.default) @@ -558,32 +549,22 @@ def _(func, types, args, kwargs): tensors = [ getattr(self, name).to(device) for name in self.tensor_data_names ] - optional_tensors = [] if hasattr(self, "optional_tensor_data_names"): for tensor_data_name in self.optional_tensor_data_names: maybe_tensor = getattr(self, tensor_data_name) if maybe_tensor is not None: - optional_tensors.append(maybe_tensor.to(device)) + tensors.append(maybe_tensor.to(device)) else: - optional_tensors.append(None) + tensors.append(None) # change device tensor_attributes = [ getattr(self, attr_name) if attr_name != "device" else device for attr_name in self.tensor_attribute_names ] - optional_tensor_attributes = [] - if hasattr(self, "optional_tensor_attribute_names"): - optional_tensor_attributes = [ - getattr(self, attr_name) if attr_name != "device" else device - for attr_name in self.optional_tensor_attribute_names - ] - t = self.__class__( *tensors, *tensor_attributes, - *optional_tensors, - *optional_tensor_attributes, ) return return_and_correct_aliasing(func, args, kwargs, t) @@ -592,26 +573,6 @@ def _(func, types, args, kwargs): ) -def _torchao_base_tensor__setstate__(self, state): - assert hasattr(self, "tensor_data_names") and hasattr( - self, "tensor_attribute_names" - ) - torch._utils._set_obj_state(self, state) - for optional_tensor_data_name in getattr(self, "optional_tensor_data_names", []): - if optional_tensor_data_name not in self.__dict__ and not hasattr( - self, optional_tensor_data_name - ): - setattr(self, optional_tensor_data_name, None) - - for optional_tensor_attribute_name in getattr( - self, "optional_tensor_attribute_names", [] - ): - if optional_tensor_attribute_name not in self.__dict__ and not hasattr( - self, optional_tensor_attribute_name - ): - setattr(self, optional_tensor_attribute_name, None) - - def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None): """Use this util function for a common `__torch_function__` implementation that dispatches to ops/functions registered with `_implements` @@ -764,13 +725,10 @@ class PlainAQTTensorImpl(...): class variables to define to simplify implmentation of tensor subclasses: `tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match - the `__init__` list of tensor subclass + the `__init__` list of tensor subclass + `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor attributes, 
         `tensor_attribute_names` (List[str]): list of names of non-Tensor attributes,
-            order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments
-        `optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor data attributes, when defined, this will be a list of names of Tensors that can be optional
-        `optional_tensor_attribute_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional non-Tensor attributes, when defined, this will be a list of names of attributes that can be optional
-        Note: Argument order in __init__ and __new__ should match exaclty with tensor_data_names + tensor_attribute_names + optional_tensor_data_names (if present) + optional_tensor_attribute_names (if present)
-
+            order should match the `__init__` list of the tensor subclass, following all the `tensor_data_names` and `optional_tensor_data_names` arguments

     If `tensor_data_names` and `tensor_attribute_names` are defined, there are some additional functions
     that will be added, this includes:
@@ -781,31 +739,23 @@ class variables to define to simplify implmentation of tensor subclasses:
             recreate a new subclassed Tensor with the transformed tensor data
         `__repr__`: the string representation of the subclassed tensor instance
         `_same_metadata`: returns whether the metadata is the same between two instances of cls
-        `__setstate__`: when loading a serialized tensor subclass checkpoints, it sets the new
-            optional tensor and tensor attribute that is saved in the old checkpoint to None,
-            to maintain BC of old checkpoints when we add new optional tensor data or attributes to
-            the tensor subclass
         torch ops: torch.Tensor.contiguous
         aten ops: aten.detach.default, aten.clone.default, aten.alias,default, aten.contiguous.default,
             aten.copy_.default, aten._to_copy.default (enables t.to)

         Example:
             class MyTensor(torch.Tensor):
                 tensor_data_names = ["a", "b"]
-                tensor_attribute_names = ["c", "d"]
-                optional_tensor_data_names = ["e", "f"]
-                optional_tensor_attribute_names = ["g", "h"]
-
+                optional_tensor_data_names = ["c", "d"]
+                tensor_attribute_names = ["e", "f"]

                 def __new__(
                     cls,
                     a: Tensor,
                     b: Tensor,
-                    c: int,
-                    d: str,
-                    e: Optional[Tensor] = None,
-                    f: Optional[Tensor] = None,
-                    g: Optional[int] = None,
-                    h: Optional[int] = None,
+                    c: Optional[Tensor],
+                    d: Optional[Tensor],
+                    e: int,
+                    f: str
                 ):
                     pass

@@ -813,12 +763,10 @@ def __init__(
                     self,
                     a: Tensor,
                     b: Tensor,
-                    c: int,
-                    d: str
-                    e: Optional[Tensor] = None,
-                    f: Optional[Tensor] = None,
-                    g: Optional[int] = None,
-                    h: Optional[int] = None,
+                    c: Optional[Tensor],
+                    d: Optional[Tensor],
+                    e: int,
+                    f: str
                 ):
                     pass

@@ -832,11 +780,9 @@ def __init_subclass__(cls, **kwargs):
         if cls not in cls._ATEN_OP_OR_TORCH_FN_TABLE:
             cls._ATEN_OP_OR_TORCH_FN_TABLE[cls] = {}

-        # define the common ops and __set_state__ for BC
-        # if the tensor_data_names and tensor_attribute_names are defined
+        # define the common ops if the tensor_data_names and tensor_attribute_names are defined
         if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
             cls._implements_common_tensor_ops()
-            cls.__setstate__ = _torchao_base_tensor__setstate__

         # inherit the torch function and dispatch implementations from direct parent classes
         # e.g. for `class C(B, A)`, C.__bases__ == (B, A)
@@ -865,82 +811,47 @@ def __tensor_flatten__(self):
                     if maybe_tensor is not None:
                         tensor_data_names.append(tensor_data_name)

-            attrs = [getattr(self, attr) for attr in self.tensor_attribute_names]
-            if hasattr(self, "optional_tensor_attribute_names"):
-                attrs += [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
-
             # TODO(future PR): also return names of tensor attributes for easier
             # debugging
-            return tensor_data_names, attrs
+            return tensor_data_names, [
+                getattr(self, attr) for attr in self.tensor_attribute_names
+            ]
         raise NotImplementedError(
-            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
+            "Subclasses should implement __tensor_flatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class or tensor instance before using it"
         )

     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
-        if hasattr(cls, "tensor_data_names") and hasattr(cls, "tensor_attribute_names"):
-            required_tensors = [
-                tensor_data_dict[name] for name in cls.tensor_data_names
-            ]
-            optional_tensors = []
-            if hasattr(cls, "optional_tensor_data_names"):
-                for tensor_data_name in cls.optional_tensor_data_names:
-                    if tensor_data_name in tensor_data_dict:
-                        optional_tensors.append(tensor_data_dict[tensor_data_name])
-                    else:
-                        optional_tensors.append(None)
-
-            required_attributes = tensor_attributes[: len(cls.tensor_attribute_names)]
-            optional_attributes = []
-            if hasattr(cls, "optional_tensor_attribute_names"):
-                optional_attributes = tensor_attributes[
-                    len(cls.tensor_attribute_names) :
-                ]
-
-            return cls(
-                *required_tensors,
-                *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
-            )
-        raise NotImplementedError(
-            "Subclasses should implement __tensor_unflatten__ or specify `tensor_data_names` and `tensor_attribute_names` for tensor class before using it"
-        )
+        tensors = [tensor_data_dict[name] for name in cls.tensor_data_names]
+        if hasattr(cls, "optional_tensor_data_names"):
+            for tensor_data_name in cls.optional_tensor_data_names:
+                if tensor_data_name in tensor_data_dict:
+                    tensors.append(tensor_data_dict[tensor_data_name])
+                else:
+                    tensors.append(None)
+        return cls(*tensors, *tensor_attributes)

     def _apply_fn_to_data(self, fn):
         if hasattr(self, "tensor_data_names") and hasattr(
             self, "tensor_attribute_names"
         ):
-            required_tensors = [
-                fn(getattr(self, attr)) for attr in self.tensor_data_names
-            ]
-            optional_tensors = []
+            tensors = [fn(getattr(self, attr)) for attr in self.tensor_data_names]
             if hasattr(self, "optional_tensor_data_names"):
                 for tensor_data_name in self.optional_tensor_data_names:
                     maybe_tensor = getattr(self, tensor_data_name)
                     if maybe_tensor is not None:
-                        optional_tensors.append(fn(maybe_tensor))
+                        tensors.append(fn(maybe_tensor))
                     else:
-                        optional_tensors.append(None)
+                        tensors.append(None)

-            required_attributes = [
+            tensor_attributes = [
                 getattr(self, attr) for attr in self.tensor_attribute_names
             ]
-            optional_attributes = []
-            if hasattr(self, "optional_tensor_attribute_names"):
-                optional_attributes = [
-                    getattr(self, attr) for attr in self.optional_tensor_attribute_names
-                ]
-
             return self.__class__(
-                *required_tensors,
-                *required_attributes,
-                *optional_tensors,
-                *optional_attributes,
+                *tensors,
+                *tensor_attributes,
             )

         raise NotImplementedError(
@@ -952,29 +863,19 @@ def __repr__(self):
             self, "tensor_attribute_names"
         ):
            repr_str = ""
"" - # required tensor data repr_str += f"{self.tensor_data_names[0]}={getattr(self, self.tensor_data_names[0])}" for tensor_data_name in self.tensor_data_names[1:]: repr_str += f", {tensor_data_name}={getattr(self, tensor_data_name)}" - - # required attributes - for tensor_attribute_name in self.tensor_attribute_names: - repr_str += ( - f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" - ) - - # optional tensor data if hasattr(self, "optional_tensor_data_names"): for tensor_data_name in self.optional_tensor_data_names: repr_str += ( f", {tensor_data_name}={getattr(self, tensor_data_name)}" ) - # optional tensor attributes - if hasattr(self, "optional_tensor_attribute_names"): - for tensor_attribute_name in self.optional_tensor_attribute_names: - repr_str += f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" - + for tensor_attribute_name in self.tensor_attribute_names: + repr_str += ( + f", {tensor_attribute_name}={getattr(self, tensor_attribute_name)}" + ) return f"{self.__class__.__name__}({repr_str})" raise NotImplementedError(