
Commit 1b69822

vkuzo authored and liangel-02 committed

float8 kernel test: make more robust (#2847)

Update [ghstack-poisoned]

1 parent 8f7d62b commit 1b69822

1 file changed: +27 −22 lines

test/dtypes/test_affine_quantized_float.py

Lines changed: 27 additions & 22 deletions
@@ -14,7 +14,8 @@
 import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch.profiler import ProfilerActivity, profile
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils

 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
@@ -766,32 +767,36 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
             config,
         )

-        m = torch.compile(m, mode="default")
+        m = torch.compile(m)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-
-        # warm up
-        _ = m(x)
-        # capture trace
-        with profile(activities=[ProfilerActivity.CUDA]) as prof:
-            _ = m(x)
-
-        cuda_kernel_events = [x for x in prof.key_averages() if x.cuda_time > 0]
-
-        if granularity == PerTensor():
+        out, code = run_and_get_code(m, x)
+
+        # triton kernel call looks like:
+        # triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_clone_div_expand_permute_transpose_unsqueeze_view_0.run(arg3_1, buf1, buf2, 128, 256, stream=stream0)
+        # scaled_mm call looks like:
+        # extern_kernels._scaled_mm(buf1, reinterpret_tensor(arg0_1, (256, 512), (1, 256), 0), buf2, reinterpret_tensor(arg1_1, (1, 512), (1, 1), 0), arg2_1, out_dtype=torch.bfloat16, use_fast_accum=True, out=buf3)
+        if granularity == PerRow():
+            # one triton kernel for quantizing the activation
+            FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(
+                code[0]
+            )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])
+        else:
+            assert granularity == PerTensor(), "unsupported"
+            # three triton kernels for quantizing the activation:
             # kernel 1: x_max_tmp = max(x, ...)
             # kernel 2: x_max = max(x_max_tmp)
             # kernel 3: x_float8 = to_float8(x, x_max)
-            # kernel 4: gemm
-            assert len(cuda_kernel_events) == 4, (
-                f"too many cuda kernels: {cuda_kernel_events}"
-            )
-        else:
-            assert granularity == PerRow()
-            # kernel 1: x_float8 = to_float8(x)
-            # kernel 2: gemm
-            assert len(cuda_kernel_events) == 2, (
-                f"too many cuda kernels: {cuda_kernel_events}"
+            FileCheck().check("def call(").check_count(".run(", 3, exactly=True).run(
+                code[0]
             )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
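
The new test boils down to a reusable pattern: compile the model, capture Inductor's generated Python with run_and_get_code, then pin down kernel counts with FileCheck. Below is a minimal sketch of that pattern outside the torchao suite; run_and_get_code and FileCheck are the real PyTorch utilities used in the diff, but the toy function, shapes, and the expected count of exactly one fused kernel are illustrative assumptions (a CUDA device is assumed, as in the test):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck

def f(x):
    # a toy op sequence that Inductor is expected to fuse
    # into a single pointwise triton kernel
    return torch.relu(x) + 1

f_c = torch.compile(f)
x = torch.randn(16, 16, device="cuda")

# returns the model output plus the source of each compiled artifact;
# code[0] is the main generated module containing `def call(...)`
out, code = run_and_get_code(f_c, x)

# anchor on the generated entry point, then require exactly one
# triton kernel launch (".run(") in the generated source
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])

Because the assertion runs against the generated source rather than a profiler trace, it is deterministic: no warm-up iteration is needed, and the result does not depend on which events the CUDA profiler happens to record.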

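For contrast, here is a sketch of the profiler-based counting that this commit removes, mirroring the deleted code (same hypothetical toy function as above, expected count assumed). Counting CUDA events from a trace is sensitive to warm-up state and to incidental kernels such as copies, which is what made the original test brittle:

import torch
from torch.profiler import ProfilerActivity, profile

f_c = torch.compile(lambda x: torch.relu(x) + 1)
x = torch.randn(16, 16, device="cuda")

_ = f_c(x)  # warm up so compilation-time kernels stay out of the trace
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    _ = f_c(x)

# every averaged profiler event that spent time on the GPU
cuda_kernel_events = [e for e in prof.key_averages() if e.cuda_time > 0]
assert len(cuda_kernel_events) == 1, f"unexpected kernels: {cuda_kernel_events}"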