@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
9595
9696 tensor_data_names = ["qdata" , "scale" ]
9797 tensor_attribute_names = []
98+ optional_tensor_data_names = ["test_only_data" ]
9899 optional_tensor_attribute_names = [
99100 "block_size" ,
100101 "mm_config" ,
@@ -109,6 +110,7 @@ def __new__(
109110 cls ,
110111 qdata : torch .Tensor ,
111112 scale : torch .Tensor ,
113+ test_only_data : Optional [torch .Tensor ] = None ,
112114 block_size : Optional [List [int ]] = None ,
113115 mm_config : Optional [Float8MMConfig ] = None ,
114116 hp_value_lb : Optional [float ] = None ,
@@ -128,6 +130,7 @@ def __init__(
128130 self ,
129131 qdata : torch .Tensor ,
130132 scale : torch .Tensor ,
133+ test_only_data : Optional [torch .Tensor ] = None ,
131134 block_size : Optional [List [int ]] = None ,
132135 mm_config : Optional [Float8MMConfig ] = None ,
133136 hp_value_lb : Optional [float ] = None ,
@@ -138,6 +141,7 @@ def __init__(
138141 ):
139142 self .qdata = qdata
140143 self .scale = scale
144+ self .test_only_data = test_only_data
141145 self .block_size = block_size
142146 self .mm_config = mm_config
143147 self .hp_value_lb = hp_value_lb
@@ -152,6 +156,11 @@ def __repr__(self):
152156 f"{ self .shape = } , { self .device = } , { self .dtype = } )"
153157 )
154158
159+ def __setstate__ (self , state ):
160+ torch ._utils ._set_obj_state (self , state )
161+ if "test_only_data" not in self .__dict__ :
162+ self .test_only_data = None
163+
155164 def _quantization_type (self ):
156165 return f"{ self .act_quant_kwargs = } , { self .block_size = } , { self .mm_config = } , { self .scale .shape = } "
157166
0 commit comments