Skip to content

Commit bf74236

Browse files
committed
example: combine distributed profiler activity
1 parent a7f00e5 commit bf74236

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

examples/eval.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def _run_single_profile(test: TestCase) -> str:
515515
return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
516516

517517

518-
def _run_distributed_profile(test: TestCase, rank: int) -> profile:
518+
def _run_distributed_profile(test: TestCase, rank: int) -> "EventList":
519519
"""
520520
Runs a single profiling case. Do not call directly
521521
"""
@@ -536,11 +536,46 @@ def _run_distributed_profile(test: TestCase, rank: int) -> profile:
536536
submission_output = custom_kernel(data)
537537
torch.cuda.synchronize()
538538

539-
return prof
539+
return prof.events()
540540

541541
finally:
542542
dist.destroy_process_group()
543543

544+
545+
def _combine_traces(traces: list["EventList"]) -> "EventList":
546+
"""
547+
Combine multiple event traces obtained from multiple (distributed) torch.profiler
548+
activities. This function simply aggregates the data as like `prof.key_averages()`,
549+
except over multiple traces. Most of this function is reimplemented
550+
from `torch.autograd.profiler_util.EventList.key_averages()`.
551+
"""
552+
from torch.autograd.profiler_util import FunctionEventAvg, EventList
553+
from collections import defaultdict
554+
555+
def get_key(event) -> tuple[str, ...]:
556+
return (
557+
str(event.key),
558+
str(event.node_id),
559+
str(event.device_type),
560+
str(event.is_legacy),
561+
str(event.is_user_annotation),
562+
)
563+
564+
stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
565+
566+
for events in traces:
567+
for event in events:
568+
stats[get_key(event)].add(event)
569+
570+
avg_list = EventList(stats.values())
571+
for event in avg_list:
572+
event.stack = []
573+
event.input_shapes = ""
574+
event.overload_name = ""
575+
576+
return avg_list
577+
578+
544579
def run_multi_gpu_profile(pool: multiprocessing.Pool, test: TestCase, world_size: int) -> str:
545580
"""
546581
Runs a single test in another process.
@@ -556,9 +591,8 @@ def run_multi_gpu_profile(pool: multiprocessing.Pool, test: TestCase, world_size
556591
)
557592

558593
rets = [el.get(120) for el in rets]
594+
return _combine_traces(rets).table(sort_by="self_cuda_time_total", row_limit=20)
559595

560-
# TODO: Combine distributed profiling results?
561-
return rets[0].key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
562596

563597
def run_single_profile(test: TestCase, pool: multiprocessing.Pool) -> str:
564598
"""

0 commit comments

Comments
 (0)