Skip to content

Commit cf9ea8a

Browse files
authored
mxtensor: add serialization support (pytorch#3078)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent d009386 commit cf9ea8a

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
import tempfile
89

910
import pytest
1011
import torch
@@ -100,6 +101,16 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
100101
f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
101102
)
102103

104+
# serialization
105+
with tempfile.NamedTemporaryFile() as f:
106+
torch.save(m_mx.state_dict(), f)
107+
f.seek(0)
108+
109+
# temporary workaround for https://github.com/pytorch/ao/issues/3077
110+
torch.serialization.add_safe_globals([getattr])
111+
112+
_ = torch.load(f, weights_only=True)
113+
103114

104115
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
105116
@pytest.mark.skipif(

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
_validate_elem_dtype,
1818
_validate_gemm_kernel_choice,
1919
)
20-
from torchao.prototype.mx_formats.mx_tensor import MXTensor, QuantizeTensorToMXKwargs
20+
from torchao.prototype.mx_formats.mx_tensor import (
21+
MXTensor,
22+
QuantizeTensorToMXKwargs,
23+
ScaleCalculationMode,
24+
)
2125
from torchao.prototype.mx_formats.nvfp4_tensor import (
2226
NVFP4MMConfig,
2327
NVFP4Tensor,
@@ -206,6 +210,8 @@ def _nvfp4_inference_linear_transform(
206210
NVFP4Tensor,
207211
NVFP4MMConfig,
208212
MXGemmKernelChoice,
213+
QuantizeTensorToMXKwargs,
214+
ScaleCalculationMode,
209215
]
210216
)
211217

0 commit comments

Comments
 (0)