diff --git a/tritonbench/components/do_bench/run.py b/tritonbench/components/do_bench/run.py
index e5e6901c3..9e1a881d6 100644
--- a/tritonbench/components/do_bench/run.py
+++ b/tritonbench/components/do_bench/run.py
@@ -185,6 +185,14 @@ def _do_bench_profiler(
     Returns:
         List of measured kernel times in milliseconds (if return_mode="all") or single value.
     """
+    # we don't want any outside errors propagating into benchmarking
+    torch.cuda.synchronize()
+
+    # warmup `fn` (and catches any failures in the process)
+    for _ in range(3):
+        fn()
+    torch.cuda.synchronize()
+
     # Get cache for L2 cache clearing
     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
@@ -212,17 +220,8 @@ def run_iteration():
         for _ in range(n_repeat):
             run_iteration()
         torch.cuda.synchronize()
-    else:
-        # Regular mode warmup
-        n_warmup = max(1, int(warmup / estimate_ms)) if estimate_ms > 0 else 25
-
-        torch.cuda.synchronize()
-        for _ in range(n_warmup):
-            run_iteration()
-        torch.cuda.synchronize()
 
     n_profiler_runs = 5
-    iterations_per_profiler_run = n_repeat
 
     # Benchmark phase - collect kernel times for each iteration
     all_kernel_times = []
@@ -243,13 +242,28 @@ def run_iteration():
            g.replay()
        else:
            # Execute multiple iterations for regular mode
-            for _ in range(iterations_per_profiler_run):
+            for _ in range(n_repeat):
                run_iteration()
        torch.cuda.synchronize()
 
        # Collect all kernel execution intervals
        kernel_intervals = []
 
+        # check the number of cache clear kernels.
+        # we rely on hard-coded aten op name for excluding cache clear kernels.
+        # this check ensures that pytorch does not dispatch to another kernel.
+        num_cache_clear_kernels = len(
+            [
+                evt
+                for evt in prof.events()
+                if evt.device_type == torch.autograd.DeviceType.CUDA
+                and evt.name == CACHE_CLEAR_KERNEL
+            ]
+        )
+        assert (
+            num_cache_clear_kernels == n_repeat
+        ), f"Expected {n_repeat} cache clear kernels but found {num_cache_clear_kernels}"
+
        # Get raw function events and collect time intervals
        for evt in prof.events():
            # Check for CUDA kernel events, excluding cache clear kernel
@@ -299,9 +313,7 @@ def run_iteration():
            )
 
            # Convert to milliseconds and normalize by iterations
-            total_kernel_time_ms = (
-                total_kernel_time_us / 1000.0
-            ) / iterations_per_profiler_run
+            total_kernel_time_ms = (total_kernel_time_us / 1000.0) / n_repeat
            all_kernel_times.append(total_kernel_time_ms)
 
     times = torch.tensor(all_kernel_times, dtype=torch.float)
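
The new assertion counts cache-clear kernels directly from the profiler trace and the final division normalizes the summed kernel time by n_repeat. Below is a minimal, self-contained sketch of that pattern, not the tritonbench code itself: the workload (torch.mm), the buffer size, and the way the cache-clear kernel name is discovered are all placeholder assumptions, whereas the real patch keys the exclusion on the hard-coded CACHE_CLEAR_KERNEL constant defined elsewhere in run.py and merges overlapping kernel intervals rather than naively summing durations.

import torch
from torch.profiler import ProfilerActivity, profile

n_repeat = 10
cache = torch.empty(256 * 1024 * 1024 // 4, dtype=torch.int, device="cuda")
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Step 1: learn which CUDA kernel PyTorch dispatches the cache clear to on
# this build (assumes zero_() launches exactly one GPU kernel).
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    cache.zero_()
    torch.cuda.synchronize()
cache_clear_kernel = next(
    evt.name
    for evt in prof.events()
    if evt.device_type == torch.autograd.DeviceType.CUDA
)

# Step 2: profile the benchmark loop itself.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(n_repeat):
        cache.zero_()   # flush L2 between iterations
        torch.mm(a, b)  # stand-in for the kernel under test
    torch.cuda.synchronize()

# Keep only GPU-side events; runtime events such as cudaLaunchKernel are
# reported with a CPU device type and fall out of this filter.
cuda_events = [
    evt
    for evt in prof.events()
    if evt.device_type == torch.autograd.DeviceType.CUDA
]

# Sanity check mirroring the patch: exactly one cache-clear kernel per
# iteration, otherwise the name-based exclusion would silently miscount.
num_cache_clear = sum(evt.name == cache_clear_kernel for evt in cuda_events)
assert num_cache_clear == n_repeat, (
    f"Expected {n_repeat} cache clear kernels but found {num_cache_clear}"
)

# Sum measured kernel durations (microseconds) excluding the cache clears,
# then normalize per iteration, mirroring the division by n_repeat above.
total_us = sum(
    evt.time_range.elapsed_us()
    for evt in cuda_events
    if evt.name != cache_clear_kernel
)
per_iteration_ms = (total_us / 1000.0) / n_repeat
print(f"average kernel time per iteration: {per_iteration_ms:.4f} ms")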