
Commit 045c959

Author: Yuxin Cui

Fix FX Graph Cache issue in register_da8w4_concat_linear_cpu_pass (#2907)

* Fix FX Graph Cache issue in register_da8w4_concat_linear_cpu_pass

  Fix the bug where the FX Graph Cache was bypassed when using
  register_da8w4_concat_linear_cpu_pass, preventing cache hits on subsequent
  model runs. Implement DA8W4ConcatLinearCPUPass, which inherits from
  CustomGraphPass, and ensure it can be serialized and saved as an fxgraph
  properly. Add a unit test: when the fxgraph is saved, fxgraph_cache_bypass
  should remain at 0, confirming that the custom pass is no longer rejected
  by the cache system.

  Signed-off-by: Cui, Yuxin <[email protected]>

* Modify the test description for test_da8w4_cpu

  Signed-off-by: Cui, Yuxin <[email protected]>

* Add more detailed comments

  Signed-off-by: Cui, Yuxin <[email protected]>

---------

Signed-off-by: Cui, Yuxin <[email protected]>

1 parent cc65dc5 commit 045c959
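
For context, the shape of the fix is sketched below. This is a minimal,
hypothetical example (MyPostGradPass is an illustrative name, not part of the
commit): assigning a bare function to post_grad_custom_post_pass gives
Inductor nothing it can hash into the FX graph cache key, so the cache is
bypassed, whereas a CustomGraphPass subclass with uuid() is cache-safe.

    import torch
    from torch._inductor import config as inductor_config
    from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files

    class MyPostGradPass(CustomGraphPass):  # hypothetical example pass
        def __call__(self, graph: torch.fx.Graph) -> None:
            # Mutate the post-grad FX graph in place here.
            pass

        def uuid(self):
            # Hash of this source file: it feeds the FX graph cache key, so
            # cached graphs are invalidated when the pass implementation changes.
            return get_hash_for_files((__file__,))

    # Register an instance (rather than a bare function) so caching keeps working.
    inductor_config.post_grad_custom_post_pass = MyPostGradPass()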

File tree

3 files changed: +27 −4 lines changed


test/quantization/test_da8w4_cpu.py

Lines changed: 14 additions & 1 deletion

@@ -8,6 +8,7 @@
 import unittest

 import torch
+from torch._dynamo.utils import counters
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -120,7 +121,6 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
-        self.skipTest("Disabled for now")
         N, K = 64, 128

         class Mod(torch.nn.Module):
@@ -163,6 +163,15 @@ def forward(self, x):
         # ensure the expected op occurs only once in the code after fusion
         # The trailing "(" is to avoid matching the op in the comment
         assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
+
+        # Ensure that when concat linear is enabled, fxgraph cache works
+        # without being bypassed (fxgraph_cache_bypass = 0), indicating that
+        # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
+        # interface and uuid() function, allowing fxgraph to be saved and hit
+        # on subsequent runs (fxgraph_cache_hit > 0).
+        fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+        assert fx_cache_bypass_count == 0
+
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
         ):
@@ -172,6 +181,10 @@ def forward(self, x):
             )
             assert torch.allclose(y, y_ref)

+            # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
+            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+            assert fx_cache_bypass_count == 0
+

 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
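As a hedged sketch of what the new assertions check (the model, shapes, and
explicit cache flag below are illustrative assumptions, not taken from the
test):

    import torch
    from torch._dynamo.utils import counters
    from torch._inductor import config as inductor_config

    # Assumption: FX graph caching may need to be enabled explicitly on some
    # PyTorch versions (also controllable via TORCHINDUCTOR_FX_GRAPH_CACHE=1).
    inductor_config.fx_graph_cache = True

    counters.clear()
    m = torch.nn.Linear(128, 64)  # hypothetical stand-in for the test model
    compiled = torch.compile(m)
    with torch.no_grad():
        compiled(torch.randn(4, 128))

    # With a properly implemented CustomGraphPass registered, the cache is
    # consulted rather than bypassed, so this counter stays at zero.
    assert counters["inductor"]["fxgraph_cache_bypass"] == 0
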
torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py

Lines changed: 2 additions & 2 deletions

@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):


 # Register the concat linear fusion pass
-# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass

-# register_da8w4_concat_linear_cpu_pass()
+register_da8w4_concat_linear_cpu_pass()

torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py

Lines changed: 11 additions & 1 deletion

@@ -7,6 +7,15 @@
 import operator

 import torch
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+
+
+class DA8W4ConcatLinearCPUPass(CustomGraphPass):
+    def __call__(self, graph: torch.fx.Graph):
+        _concat_linear_dq8w4_cpu(graph)
+
+    def uuid(self):
+        return get_hash_for_files((__file__,))


 # Inductor FX passes for concat linear for DA8W4
@@ -213,4 +222,5 @@ def ...
 def register_da8w4_concat_linear_cpu_pass():
     from torch._inductor import config as inductor_config

-    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
+    da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
+    inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
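
A brief note on the uuid() design choice above: get_hash_for_files hashes the
contents of the given source files, so the key is stable across runs while the
pass is unchanged and rotates when the fusion logic is edited. A tiny sketch
(assuming only the behavior visible in the diff above):

    from torch._inductor.custom_graph_pass import get_hash_for_files

    key_a = get_hash_for_files((__file__,))
    key_b = get_hash_for_files((__file__,))
    assert key_a == key_b  # deterministic while the file contents are unchanged

This is what lets the compiled fxgraph be saved under a stable key and hit on
subsequent runs, while still being invalidated if the pass itself changes.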
