[do_bench][easy] warmup cudagraph mode in do_bench_profiler #411
@@ -185,6 +185,14 @@ def _do_bench_profiler(
     Returns:
         List of measured kernel times in milliseconds (if return_mode="all") or single value.
     """
+    # we don't want any outside errors propagating into benchmarking
+    torch.cuda.synchronize()
+
+    # warmup `fn` (and catches any failures in the process)
+    for _ in range(3):
+        fn()
+    torch.cuda.synchronize()
+
     # Get cache for L2 cache clearing
     cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
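The eager warmup added here matters for both timing modes: the first calls to `fn` trigger lazy CUDA context setup, kernel compilation, and autotuning, and running them outside any capture or profiling window means failures surface as ordinary Python exceptions. A minimal standalone sketch of the same pattern (illustrative, not this file's code):

import torch

def eager_warmup(fn, iters: int = 3) -> None:
    # Drain any pending GPU work first so errors from earlier code
    # cannot be reported inside our measurement window.
    torch.cuda.synchronize()
    for _ in range(iters):
        fn()  # first calls pay one-time compile/autotune costs
    # Synchronize again: async launch failures from fn surface here,
    # before we commit to graph capture or profiling.
    torch.cuda.synchronize()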
@@ -193,36 +201,28 @@ def _do_bench_profiler(
     # Calculate number of iterations based on target rep time
     if estimate_ms == 0:
-        n_repeat = 100  # Default if function is very fast
+        n_repeat = 1000  # Default if function is very fast
     else:
         n_repeat = max(1, int(rep / estimate_ms))
 
     # Helper function to execute one iteration
-    def run_iteration():
+    def run_iteration(should_clear_cache: bool):
         if grad_to_none is not None:
             for x in grad_to_none:
                 x.grad = None
-        cache.zero_()
+        if should_clear_cache:
+            cache.zero_()
 
         fn()
 
-    if use_cudagraph:
-        # Create CUDA graph
-        g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g):
-            for _ in range(n_repeat):
-                run_iteration()
-        torch.cuda.synchronize()
-    else:
-        # Regular mode warmup
-        n_warmup = max(1, int(warmup / estimate_ms)) if estimate_ms > 0 else 25
-
-        torch.cuda.synchronize()
-        for _ in range(n_warmup):
-            run_iteration()
-        torch.cuda.synchronize()
+    # Regular mode warmup
+    n_warmup = max(1, int(warmup / estimate_ms)) if estimate_ms > 0 else 25
+
+    torch.cuda.synchronize()
+    for _ in range(n_warmup):
+        run_iteration(should_clear_cache=False)
+    torch.cuda.synchronize()
 
-    n_profiler_runs = 5
-    iterations_per_profiler_run = n_repeat
+    n_profiler_runs = 10
 
     # Benchmark phase - collect kernel times for each iteration
     all_kernel_times = []
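`cache` is a buffer sized by Triton to cover the GPU's L2; zeroing it between timed iterations evicts `fn`'s working set so every measured call runs cold. The new `should_clear_cache` flag skips that write during warmup, where the iterations are not measured anyway. A rough standalone equivalent, assuming a 256 MB buffer is large enough to cover L2 (the real code gets a correctly sized buffer from `triton.runtime.driver.active.get_empty_cache_for_benchmark()`):

import torch

# Oversized scratch buffer: writing it evicts fn's data from L2,
# so each measured iteration starts from a cold cache.
cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")

def run_iteration(fn, should_clear_cache: bool) -> None:
    if should_clear_cache:
        cache.zero_()  # one large write sweeps the L2
    fn()

# Warmup: should_clear_cache=False, these calls are not measured.
# Benchmark: should_clear_cache=True for cold-cache numbers.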
@@ -243,8 +243,8 @@ def run_iteration():
                 g.replay()
             else:
                 # Execute multiple iterations for regular mode
-                for _ in range(iterations_per_profiler_run):
-                    run_iteration()
+                for _ in range(n_repeat):
+                    run_iteration(should_clear_cache=True)
             torch.cuda.synchronize()
 
         # Collect all kernel execution intervals
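In cudagraph mode the `n_repeat` iterations are captured once into a single graph, and each profiler run only calls `g.replay()`, so per-iteration Python and launch overhead disappears from the measurement. The hunk shows only the replay; a sketch of the capture step it implies (names follow the diff, exact placement in the file is assumed):

import torch

def capture_benchmark_graph(run_iteration, n_repeat: int) -> torch.cuda.CUDAGraph:
    g = torch.cuda.CUDAGraph()
    # Capture all iterations as one graph; the cache.zero_() and fn()
    # calls inside run_iteration are recorded as graph nodes.
    with torch.cuda.graph(g):
        for _ in range(n_repeat):
            run_iteration(should_clear_cache=True)
    torch.cuda.synchronize()
    return g  # later: g.replay() launches the whole batch at once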
@@ -299,9 +299,7 @@ def run_iteration():
         )
 
         # Convert to milliseconds and normalize by iterations
-        total_kernel_time_ms = (
-            total_kernel_time_us / 1000.0
-        ) / iterations_per_profiler_run
+        total_kernel_time_ms = (total_kernel_time_us / 1000.0) / n_repeat
         all_kernel_times.append(total_kernel_time_ms)
 
     times = torch.tensor(all_kernel_times, dtype=torch.float)
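This change drops the redundant `iterations_per_profiler_run` alias; the unit arithmetic is unchanged. The profiler reports total kernel time in microseconds across all `n_repeat` iterations, so dividing by 1000 and then by `n_repeat` yields mean milliseconds per iteration. A worked example with made-up numbers:

# Hypothetical profiler total: 48,000 us of kernel time over 120 iterations.
total_kernel_time_us = 48_000.0
n_repeat = 120

total_kernel_time_ms = (total_kernel_time_us / 1000.0) / n_repeat
assert abs(total_kernel_time_ms - 0.4) < 1e-12  # 0.4 ms per iteration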