@@ -95,6 +95,7 @@ class Float8Tensor(TorchAOBaseTensor):
9595
9696 tensor_data_names = ["qdata" , "scale" ]
9797 tensor_attribute_names = []
98+ optional_tensor_data_names = ["test_only_data" ]
9899 optional_tensor_attribute_names = [
99100 "block_size" ,
100101 "mm_config" ,
@@ -103,19 +104,22 @@ class Float8Tensor(TorchAOBaseTensor):
103104 "act_quant_kwargs" ,
104105 "kernel_preference" ,
105106 "dtype" ,
107+ "new_optional_attr" ,
106108 ]
107109
108110 def __new__ (
109111 cls ,
110112 qdata : torch .Tensor ,
111113 scale : torch .Tensor ,
114+ test_only_data : Optional [torch .Tensor ] = None ,
112115 block_size : Optional [List [int ]] = None ,
113116 mm_config : Optional [Float8MMConfig ] = None ,
114117 hp_value_lb : Optional [float ] = None ,
115118 hp_value_ub : Optional [float ] = None ,
116119 act_quant_kwargs : Optional [QuantizeTensorToFloat8Kwargs ] = None ,
117120 kernel_preference : KernelPreference = KernelPreference .AUTO ,
118121 dtype : Optional [torch .dtype ] = None ,
122+ new_optional_attr : Optional [int ] = None ,
119123 ):
120124 shape = qdata .shape
121125 kwargs = {}
@@ -128,22 +132,26 @@ def __init__(
128132 self ,
129133 qdata : torch .Tensor ,
130134 scale : torch .Tensor ,
135+ test_only_data : Optional [torch .Tensor ] = None ,
131136 block_size : Optional [List [int ]] = None ,
132137 mm_config : Optional [Float8MMConfig ] = None ,
133138 hp_value_lb : Optional [float ] = None ,
134139 hp_value_ub : Optional [float ] = None ,
135140 act_quant_kwargs : Optional [QuantizeTensorToFloat8Kwargs ] = None ,
136141 kernel_preference : KernelPreference = KernelPreference .AUTO ,
137142 dtype : Optional [torch .dtype ] = None ,
143+ new_optional_attr : Optional [int ] = None ,
138144 ):
139145 self .qdata = qdata
140146 self .scale = scale
147+ self .test_only_data = test_only_data
141148 self .block_size = block_size
142149 self .mm_config = mm_config
143150 self .hp_value_lb = hp_value_lb
144151 self .hp_value_ub = hp_value_ub
145152 self .act_quant_kwargs = act_quant_kwargs
146153 self .kernel_preference = kernel_preference
154+ self .new_optional_attr = new_optional_attr
147155
148156 def __repr__ (self ):
149157 return (
0 commit comments