Merged
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_mx_tensor.py
@@ -907,7 +907,7 @@ def test_nvfp4_swizzled_scales_serialization():
tensor_list, ctx = original_tensor.__tensor_flatten__()

# Verify swizzled flag is preserved in context
assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
assert ctx[2] == True

# Test deserialization
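Note on the assertion above: `_is_swizzled_scales` moves from the required `tensor_attribute_names` list into `optional_tensor_attribute_names`, yet the test still expects the flag at `ctx[2]`. That is consistent if the base class builds the flatten context by listing required attribute values first and optional attribute values after them — an assumption about `TorchAOBaseTensor`'s implementation, which is not shown in this diff. A minimal pure-Python sketch of that assumed ordering, with hypothetical attribute values:

# Sketch only; the name lists are copied from the NVFP4Tensor changes further
# down in this PR, and the attribute values are made up for illustration.
tensor_attribute_names = ["_block_size", "_orig_dtype"]
optional_tensor_attribute_names = [
    "_is_swizzled_scales",
    "use_triton_kernel",
    "act_quant_kwargs",
]
attrs = {
    "_block_size": 16,
    "_orig_dtype": "bfloat16",
    "_is_swizzled_scales": True,
    "use_triton_kernel": False,
    "act_quant_kwargs": None,
}
ctx = [attrs[name] for name in tensor_attribute_names + optional_tensor_attribute_names]
assert ctx[2] == True  # mirrors the `assert ctx[2] == True` check in the test above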
2 changes: 1 addition & 1 deletion test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -307,7 +307,7 @@ def test_nvfp4_swizzled_scales_serialization():
tensor_list, ctx = original_tensor.__tensor_flatten__()

# Verify swizzled flag is preserved in context
assert NVFP4Tensor.tensor_attribute_names[2] == "_is_swizzled_scales"
assert NVFP4Tensor.optional_tensor_attribute_names[0] == "_is_swizzled_scales"
assert ctx[2] == True

# Test deserialization
155 changes: 126 additions & 29 deletions test/test_utils.py
@@ -106,6 +106,18 @@ def __init__(self, data):
l.weight = torch.nn.Parameter(MyTensor(l.weight))

def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
# get `all_tensor_data_names` and `all_tensor_attribute_names`
all_tensor_data_names = lp_tensor.tensor_data_names.copy()
if hasattr(lp_tensor, "optional_tensor_data_names"):
for tensor_data_name in lp_tensor.optional_tensor_data_names:
if getattr(lp_tensor, tensor_data_name) is not None:
all_tensor_data_names.append(tensor_data_name)
all_tensor_attribute_names = lp_tensor.tensor_attribute_names.copy()
if hasattr(lp_tensor, "optional_tensor_attribute_names"):
for tensor_attribute_name in lp_tensor.optional_tensor_attribute_names:
if getattr(lp_tensor, tensor_attribute_name) is not None:
all_tensor_attribute_names.append(tensor_attribute_name)

# test __tensor_flatten__ and __tensor_unflatten__
tensor_data_names, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
@@ -116,6 +128,19 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
reconstructed = type(lp_tensor).__tensor_unflatten__(
tensor_data_dict, tensor_attributes, outer_size, outer_stride
)
for tensor_data_name in all_tensor_data_names:
self.assertTrue(
torch.equal(
getattr(lp_tensor, tensor_data_name),
getattr(reconstructed, tensor_data_name),
)
)
for tensor_attribute_name in all_tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor, tensor_attribute_name),
getattr(reconstructed, tensor_attribute_name),
)

self.assertTrue(torch.equal(lp_tensor.qdata, reconstructed.qdata))
self.assertEqual(lp_tensor.attr, reconstructed.attr)

@@ -129,52 +154,81 @@ def _test_default_impls_helper(self, lp_tensor, lp_tensor_for_copy):
# __repr__
_ = str(lp_tensor)

# other ops
# op test: detach
lp_tensor = lp_tensor.detach()
# explicitly testing aten.alias
# op test: alias
lp_tensor = torch.ops.aten.alias(lp_tensor)
lp_tensor = lp_tensor.clone()
# get all tensor_data_names for both

# op test: clone
lp_tensor_clone = lp_tensor.clone()

for tensor_data_name in all_tensor_data_names:
self.assertTrue(
torch.equal(
getattr(lp_tensor_clone, tensor_data_name),
getattr(lp_tensor, tensor_data_name),
)
)
for tensor_attribute_name in all_tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor_clone, tensor_attribute_name),
getattr(lp_tensor, tensor_attribute_name),
)

# op test: transpose
# non optional and valid optional tensors
tensor_data_names = lp_tensor.tensor_data_names.copy()
if hasattr(lp_tensor, "optional_tensor_data_names"):
for tensor_data_name in lp_tensor.optional_tensor_data_names:
if getattr(lp_tensor, tensor_data_name) is not None:
tensor_data_names.append(tensor_data_name)

# for each of the tensor data attributes, we make it
# non-contiguous and then call lp_tensor.contiguous()
# to make sure contiguous() works
for tensor_data_name in tensor_data_names:
for tensor_data_name in all_tensor_data_names:
tensor = getattr(lp_tensor, tensor_data_name)
# making the tensor data not contiguous
tensor = tensor.transpose(0, 1).contiguous()
tensor = tensor.transpose(0, 1)
setattr(lp_tensor, tensor_data_name, tensor)
self.assertFalse(getattr(lp_tensor, tensor_data_name).is_contiguous())
lp_tensor = lp_tensor.contiguous()
# making sure contiguous call works
self.assertTrue(getattr(lp_tensor, tensor_data_name).is_contiguous())

# copy_
lp_tensor_t = lp_tensor.contiguous()

# making sure contiguous call works
for tensor_data_name in all_tensor_data_names:
self.assertTrue(getattr(lp_tensor_t, tensor_data_name).is_contiguous())

# making sure transpose does not change attributes
for tensor_attribute_name in all_tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor_t, tensor_attribute_name),
getattr(lp_tensor, tensor_attribute_name),
)

# op test: copy_
# making sure that initially tensor values are not the same so we can test copy_
self.assertNotEqual(lp_tensor.qdata[0][0], lp_tensor_for_copy.qdata[0][0])
# copy_ requires the attributes to be the same
for tensor_attr_name in lp_tensor.tensor_attribute_names:
for tensor_attribute_name in all_tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor, tensor_attr_name),
getattr(lp_tensor_for_copy, tensor_attr_name),
getattr(lp_tensor_for_copy, tensor_attribute_name),
getattr(lp_tensor, tensor_attribute_name),
)

lp_tensor.copy_(lp_tensor_for_copy)
# after copy_, the tensor values should match
for tensor_data_name in tensor_data_names:
for tensor_data_name in all_tensor_data_names:
self.assertTrue(
torch.equal(
getattr(lp_tensor, tensor_data_name),
getattr(lp_tensor_for_copy, tensor_data_name),
)
)
# after copy_, the tensor attributes still match
# copy_ requires the attributes to be the same
for tensor_attribute_name in all_tensor_attribute_names:
self.assertEqual(
getattr(lp_tensor_for_copy, tensor_attribute_name),
getattr(lp_tensor, tensor_attribute_name),
)

@skip_if_no_cuda()
def test_default_impls(self):
@@ -186,60 +240,103 @@ class MyTensor(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr", "device"]

def __new__(cls, qdata, attr, device=None):
def __new__(cls, qdata, attr, device):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, attr, device=None):
def __init__(self, qdata, attr, device):
self.qdata = qdata
self.attr = attr

l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
lp_tensor = l.weight

another_tensor = torch.nn.Linear(2, 3).weight
# attribute has to be the same
lp_tensor_for_copy = MyTensor(another_tensor, "attr")
lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

@skip_if_no_cuda()
def test_default_impls_with_optional_data(self):
class MyTensorWithOptionalData(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
optional_tensor_data_names = ["zero_point"]
tensor_attribute_names = ["attr", "device"]
optional_tensor_data_names = ["zero_point"]

def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
def __new__(cls, qdata, attr, device, zero_point=None):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
def __init__(self, qdata, attr, device, zero_point=None):
self.qdata = qdata
self.attr = attr
self.zero_point = zero_point

# test both the case where the optional Tensor is None
# and the case where it is not None
l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(
l.weight, "attr", None, torch.zeros_like(l.weight)
)
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(
l.weight, "attr", None, torch.zeros_like(l.weight)
)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

@skip_if_no_cuda()
def test_default_impls_with_optional_attr(self):
class MyTensorWithOptionalData(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr", "device"]
optional_tensor_data_names = ["zero_point"]
optional_tensor_attribute_names = ["optional_attr"]

def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
shape = qdata.shape
if device is None:
device = qdata.device
kwargs = {"device": device}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self, qdata, attr, device, zero_point=None, optional_attr=None
):
self.qdata = qdata
self.attr = attr
self.zero_point = zero_point
self.optional_attr = optional_attr

# test both the case where the optional Tensor is None
# and the case where it is not None
l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
lp_tensor_for_copy = MyTensorWithOptionalData(
l.weight, "attr", None, zero_point=None
)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

l = torch.nn.Linear(2, 3)
lp_tensor = MyTensorWithOptionalData(
l.weight, torch.zeros_like(l.weight), "attr"
l.weight, "attr", None, zero_point=None, optional_attr="value"
)
l = torch.nn.Linear(2, 3)
lp_tensor_for_copy = MyTensorWithOptionalData(
l.weight, torch.zeros_like(l.weight), "attr"
l.weight, "attr", None, zero_point=None, optional_attr="value"
)
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)

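Taken together, the test classes above also fix an argument-ordering convention for these subclasses: required tensor data first, then required attributes, then optional tensor data, then optional attributes (compare `MyTensorWithOptionalData.__new__(cls, qdata, attr, device, zero_point=None, optional_attr=None)`). A hedged sketch of why that canonical order is convenient — an illustrative helper, not torchao code — follows: a generic op implementation can rebuild any such subclass from the class-level name lists alone.

# Illustrative only, not part of torchao: rebuild a subclass instance by reading
# its fields in the canonical constructor order (required data, required
# attributes, optional data, optional attributes).
def rebuild(tensor, transform):
    cls = type(tensor)
    args = [transform(getattr(tensor, name)) for name in cls.tensor_data_names]
    args += [getattr(tensor, name) for name in cls.tensor_attribute_names]
    for name in getattr(cls, "optional_tensor_data_names", []):
        t = getattr(tensor, name)
        args.append(None if t is None else transform(t))
    for name in getattr(cls, "optional_tensor_attribute_names", []):
        args.append(getattr(tensor, name))
    return cls(*args)

For example, `rebuild(lp_tensor, lambda t: t.clone())` sketches a generic clone derived purely from the name lists.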
36 changes: 19 additions & 17 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -79,10 +79,12 @@ class NVFP4Tensor(TorchAOBaseTensor):
"""

tensor_data_names = ["qdata", "_scale_e4m3"]
optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
tensor_attribute_names = [
"_block_size",
"_orig_dtype",
]
optional_tensor_data_names = ["_per_tensor_scale", "_act_per_tensor_scale"]
optional_tensor_attribute_names = [
"_is_swizzled_scales",
"use_triton_kernel",
"act_quant_kwargs",
@@ -92,10 +94,10 @@ def __new__(
cls,
qdata,
blockwise_scales,
per_tensor_scale,
act_per_tensor_scale,
block_size,
orig_dtype,
per_tensor_scale,
act_per_tensor_scale,
is_swizzled_scales=False,
use_triton_kernel=False,
act_quant_kwargs=None,
@@ -116,13 +118,13 @@ def __new__(
requires_grad=False,
)

self._scale_e4m3 = blockwise_scales
self._is_swizzled_scales = is_swizzled_scales
self._per_tensor_scale = per_tensor_scale
self._act_per_tensor_scale = act_per_tensor_scale
self.qdata = qdata
self._scale_e4m3 = blockwise_scales
self._block_size = block_size
self._orig_dtype = orig_dtype
self._per_tensor_scale = per_tensor_scale
self._act_per_tensor_scale = act_per_tensor_scale
self._is_swizzled_scales = is_swizzled_scales
self.use_triton_kernel = use_triton_kernel
self.act_quant_kwargs = act_quant_kwargs
return self
@@ -184,10 +186,10 @@ def to_nvfp4(
return NVFP4Tensor(
data_lp,
blockwise_scales,
per_tensor_scale,
act_per_tensor_scale,
block_size,
data_hp.dtype,
per_tensor_scale,
act_per_tensor_scale,
is_swizzled_scales,
use_triton_kernel,
act_quant_kwargs,
@@ -312,10 +314,10 @@ def nvfp4_to_copy(func, types, args, kwargs):
res = NVFP4Tensor(
tensor.qdata,
tensor._scale_e4m3,
tensor._per_tensor_scale,
tensor._act_per_tensor_scale,
tensor._block_size,
dtype,
tensor._per_tensor_scale,
tensor._act_per_tensor_scale,
tensor._is_swizzled_scales,
tensor.use_triton_kernel,
tensor.act_quant_kwargs,
@@ -513,10 +515,10 @@ def nvfp4_slice(func, types, args, kwargs):
result = NVFP4Tensor(
sliced_data,
sliced_scale,
x._per_tensor_scale,
x._act_per_tensor_scale,
x._block_size,
x._orig_dtype,
x._per_tensor_scale,
x._act_per_tensor_scale,
x._is_swizzled_scales,
x.use_triton_kernel,
x.act_quant_kwargs,
@@ -532,10 +534,10 @@ def nvfp4_t(func, types, args, kwargs):
new = NVFP4Tensor(
old.qdata.t(),
old._scale_e4m3,
old._per_tensor_scale,
old._act_per_tensor_scale,
old._block_size,
old._orig_dtype,
old._per_tensor_scale,
old._act_per_tensor_scale,
old._is_swizzled_scales,
old.use_triton_kernel,
old.act_quant_kwargs,
@@ -552,10 +554,10 @@ def nvfp4_view_op(func, types, args, kwargs):
return NVFP4Tensor(
new_data,
args[0]._scale_e4m3,
args[0]._per_tensor_scale,
args[0]._act_per_tensor_scale,
args[0]._block_size,
args[0]._orig_dtype,
args[0]._per_tensor_scale,
args[0]._act_per_tensor_scale,
args[0]._is_swizzled_scales,
args[0].use_triton_kernel,
args[0].act_quant_kwargs,
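All of the `NVFP4Tensor(...)` call sites above change in the same mechanical way. For reference, the new positional order of the constructor arguments (taken from the updated `__new__` signature in this file) groups them by the same convention as the class-level name lists; the tuple below is illustrative only, not torchao code.

# Reference sketch: the constructor argument groups in their new positional
# order, matching the name lists declared at the top of nvfp4_tensor.py.
NEW_NVFP4_ARG_ORDER = (
    # required tensor data
    "qdata", "blockwise_scales",
    # required attributes
    "block_size", "orig_dtype",
    # optional tensor data (may be None)
    "per_tensor_scale", "act_per_tensor_scale",
    # optional attributes
    "is_swizzled_scales", "use_triton_kernel", "act_quant_kwargs",
)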
2 changes: 1 addition & 1 deletion torchao/quantization/autoquant.py
@@ -344,7 +344,7 @@ def do_autoquant_bench(op, *args, **kwargs):
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
op(*args, **kwargs)
if torch_version_at_least("2.8.0"):
if torch_version_at_least("2.9.0.dev"):
from statistics import median

res = benchmarker.benchmark_gpu(