@@ -14,7 +14,8 @@
 import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch.profiler import ProfilerActivity, profile
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
@@ -766,32 +767,36 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
             config,
         )
 
-        m = torch.compile(m, mode="default")
+        m = torch.compile(m)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-
-        # warm up
-        _ = m(x)
-        # capture trace
-        with profile(activities=[ProfilerActivity.CUDA]) as prof:
-            _ = m(x)
-
-        cuda_kernel_events = [x for x in prof.key_averages() if x.cuda_time > 0]
-
-        if granularity == PerTensor():
+        out, code = run_and_get_code(m, x)
+
+        # triton kernel call looks like:
+        # triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_clone_div_expand_permute_transpose_unsqueeze_view_0.run(arg3_1, buf1, buf2, 128, 256, stream=stream0)
+        # scaled_mm call looks like:
+        # extern_kernels._scaled_mm(buf1, reinterpret_tensor(arg0_1, (256, 512), (1, 256), 0), buf2, reinterpret_tensor(arg1_1, (1, 512), (1, 1), 0), arg2_1, out_dtype=torch.bfloat16, use_fast_accum=True, out=buf3)
+        if granularity == PerRow():
+            # one triton kernel for quantizing the activation
+            FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(
+                code[0]
+            )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])
+        else:
+            assert granularity == PerTensor(), "unsupported"
+            # three triton kernels for quantizing the activation:
             # kernel 1: x_max_tmp = max(x, ...)
             # kernel 2: x_max = max(x_max_tmp)
             # kernel 3: x_float8 = to_float8(x, x_max)
-            # kernel 4: gemm
-            assert len(cuda_kernel_events) == 4, (
-                f"too many cuda kernels: {cuda_kernel_events}"
-            )
-        else:
-            assert granularity == PerRow()
-            # kernel 1: x_float8 = to_float8(x)
-            # kernel 2: gemm
-            assert len(cuda_kernel_events) == 2, (
-                f"too many cuda kernels: {cuda_kernel_events}"
+            FileCheck().check("def call(").check_count(".run(", 3, exactly=True).run(
+                code[0]
             )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
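The diff swaps profiler-based CUDA kernel counting (a warm-up run plus a `torch.profiler` trace capture) for `FileCheck` assertions over the wrapper code that Inductor generates, which is deterministic and does not depend on actual GPU execution timing. Below is a minimal standalone sketch of that pattern; the pointwise function, shapes, and expected kernel count are illustrative assumptions, not taken from the test.

```python
import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def f(x):
    # relu + add should fuse into a single pointwise triton kernel
    # (an assumption for this toy example)
    return torch.relu(x) + 1


compiled_f = torch.compile(f)
x = torch.randn(128, 256, device="cuda")

# run_and_get_code returns (output, [generated_source, ...]);
# code[0] is the Inductor wrapper containing the `def call(...)` entry point
out, code = run_and_get_code(compiled_f, x)

# each triton kernel launch in the wrapper appears as `<kernel_name>.run(...)`,
# so counting ".run(" occurrences after "def call(" counts kernel launches
FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
```

Because the assertion runs against generated source rather than a captured trace, a failure points at the wrapper code itself, which tends to make kernel-count regressions easier to debug than a list of profiler events.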