@@ -515,7 +515,7 @@ def _run_single_profile(test: TestCase) -> str:
     return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)


-def _run_distributed_profile(test: TestCase, rank: int) -> profile:
+def _run_distributed_profile(test: TestCase, rank: int) -> "EventList":
     """
     Runs a single profiling case. Do not call directly
     """
@@ -536,11 +536,46 @@ def _run_distributed_profile(test: TestCase, rank: int) -> profile:
             submission_output = custom_kernel(data)
             torch.cuda.synchronize()

-        return prof
+        return prof.events()

     finally:
         dist.destroy_process_group()

+
+def _combine_traces(traces: list["EventList"]) -> "EventList":
+    """
+    Combine multiple event traces obtained from multiple (distributed) torch.profiler
+    activities. This function simply aggregates the data like `prof.key_averages()`,
+    except over multiple traces. Most of this function is reimplemented
+    from `torch.autograd.profiler_util.EventList.key_averages()`.
+    """
+    from torch.autograd.profiler_util import FunctionEventAvg, EventList
+    from collections import defaultdict
+
+    def get_key(event) -> tuple[str, ...]:
+        return (
+            str(event.key),
+            str(event.node_id),
+            str(event.device_type),
+            str(event.is_legacy),
+            str(event.is_user_annotation),
+        )
+
+    stats: dict[tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg)
+
+    for events in traces:
+        for event in events:
+            stats[get_key(event)].add(event)
+
+    avg_list = EventList(stats.values())
+    for event in avg_list:
+        event.stack = []
+        event.input_shapes = ""
+        event.overload_name = ""
+
+    return avg_list
+
+
 def run_multi_gpu_profile(pool: multiprocessing.Pool, test: TestCase, world_size: int) -> str:
     """
     Runs a single test in another process.
@@ -556,9 +591,8 @@ def run_multi_gpu_profile(pool: multiprocessing.Pool, test: TestCase, world_size
     )

     rets = [el.get(120) for el in rets]
+    return _combine_traces(rets).table(sort_by="self_cuda_time_total", row_limit=20)

-    # TODO: Combine distributed profiling results?
-    return rets[0].key_averages().table(sort_by="self_cuda_time_total", row_limit=20)

 def run_single_profile(test: TestCase, pool: multiprocessing.Pool) -> str:
     """