11#ifdef USE_C10D_XCCL
22
33#include < torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
4+ #include < torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
45#include < xccl/NanCheck_XPU.hpp>
56#include < xccl/ProcessGroupXCCL.hpp>
67
78namespace c10d {
89
10+ using FlightRecorderXCCL = FlightRecorder<at::xpu::XPUEvent>;
11+
912namespace {
1013
1114#if defined(CCL_MAJOR_VERSION) && \
@@ -200,6 +203,17 @@ void syncStream(
200203
201204} // namespace
202205
206+ std::string dump_xccl_trace (
207+ bool includeCollectives,
208+ bool includeStackTraces,
209+ bool onlyActive) {
210+ auto xcclDumpMap = std::unordered_map<
211+ std::string,
212+ std::unordered_map<std::string, std::string>>();
213+ return FlightRecorderXCCL::get ()->dump (
214+ xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
215+ }
216+
203217constexpr int64_t kSynchronizeBusyWaitMillis = 10 ;
204218thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0 ;
205219
@@ -303,6 +317,10 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
303317 return true ;
304318}
305319
320+ ProcessGroupXCCL::Options::Options ()
321+ : Backend::Options(XCCL_BACKEND_NAME) {}
322+
323+
306324static std::atomic<size_t > process_group_id = 0 ;
307325
308326constexpr const char * MULTI_DEVICE_ERROR_MSG =
@@ -332,19 +350,28 @@ const std::string& ProcessGroupXCCL::logPrefix() const {
332350ProcessGroupXCCL::ProcessGroupXCCL (
333351 const c10::intrusive_ptr<Store>& store,
334352 int rank,
335- int size)
353+ int size,
354+ c10::intrusive_ptr<Options> options)
336355 : Backend(rank, size),
337356 store_(store),
357+ options_(std::move(options)),
338358 xcclCommCounter_(0 ),
339359 local_id_(process_group_id++) {
340360 logPrefix_ = createLogPrefix ();
341361 blockingWait_ = getCvarBool (TORCH_XCCL_BLOCKING_WAIT, false );
362+ traceBufferSize_ = getCvarInt ({" TORCH_FR_BUFFER_SIZE" }, 2000 );
363+
364+ this ->setGroupUid (options_->group_name );
365+ // In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
366+ const auto XcclVersion = getXcclVersion ();
367+ FlightRecorderXCCL::get ()->record_pg_ranks (
368+ std::make_tuple (pg_uid_, pg_desc_), groupRanks ());
369+ FlightRecorderXCCL::get ()->record_accelerator_version (XcclVersion);
342370 enableNanCheck_ = getCvarBool (TORCH_XCCL_NAN_CHECK, false );
343371 init ();
344372 const std::string OFF = " OFF" ;
345373 std::string torch_distributed_debug =
346374 getCvarString ({" TORCH_DISTRIBUTED_DEBUG" }, OFF.c_str ());
347- const auto XcclVersion = getXcclVersion ();
348375 LOG (INFO) << logPrefix () << " ProcessGroupXCCL initialization options: "
349376 << " size: " << size << " , global rank: " << rank_;
350377
@@ -353,9 +380,63 @@ ProcessGroupXCCL::ProcessGroupXCCL(
353380 << " , TORCH_XCCL_BLOCKING_WAIT: " << blockingWait_
354381 << " , TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
355382 << " , TORCH_XCCL_NAN_CHECK: " << enableNanCheck_;
383+
384+ // Heartbeat monitor thread dumps debug info on write to pipe
385+ heartbeatMonitor_ = std::make_unique<HeartbeatMonitorXCCL>(this );
386+ heartbeatMonitor_->start ();
387+ }
388+
389+ ProcessGroupXCCL::~ProcessGroupXCCL () {
390+ heartbeatMonitor_->stop ();
391+ // Wait for all threads to finish before returning
392+ heartbeatMonitor_->join ();
356393}
357394
358- ProcessGroupXCCL::~ProcessGroupXCCL () = default ;
395+ bool ProcessGroupXCCL::dumpDebuggingInfo (bool includeStackTrace /* =true*/ ) {
396+ STATIC_SCOPED_WAIT_COUNTER (pytorch.ProcessGroupXCCL__dumpDebuggingInfo );
397+ LOG (ERROR)
398+ << logPrefix ()
399+ << " ProcessGroupXCCL preparing to dump debug info. Include stack trace: "
400+ << includeStackTrace;
401+ if (traceBufferSize_ > 0 ) {
402+ // TODO: dump_xccl_trace
403+ auto xcclTrace = dump_xccl_trace (true , includeStackTrace, false );
404+ DebugInfoWriter& writer = DebugInfoWriter::getWriter (rank_);
405+ LOG (INFO) << logPrefix () << " ProcessGroupXCCL dumping xccl trace to "
406+ << writer.getWriterTarget ();
407+ writer.write (xcclTrace);
408+ LOG (INFO) << logPrefix () << " Flight Recorder trace successfully dumped." ;
409+ return true ;
410+ }
411+ return false ;
412+ }
413+
414+ const std::vector<uint64_t >& ProcessGroupXCCL::groupRanks () const {
415+ if (options_->global_ranks_in_group .empty () && local_id_ == 0 ) {
416+ static std::vector<uint64_t > globalRanks (size_);
417+ std::iota (globalRanks.begin (), globalRanks.end (), 0 );
418+ return globalRanks;
419+ }
420+ return options_->global_ranks_in_group ;
421+ }
422+
423+ void ProcessGroupXCCL::setEnqueuedPgStatus (
424+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
425+ pgStatus_->lastEnqueuedSeq = static_cast <int64_t >(work->getSequencenumber ());
426+ pgStatus_->lastEnqueuedWorkName = opTypeToString (work->opType_ );
427+ pgStatus_->lastEnqueuedNumelIn = work->numelIn_ ;
428+ pgStatus_->lastEnqueuedNumelOut = work->numelOut_ ;
429+ }
430+
431+ void ProcessGroupXCCL::setCompletedPgStatus (
432+ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work) {
433+ pgStatus_->lastCompletedSeq = static_cast <int64_t >(work->getSequencenumber ());
434+ pgStatus_->lastCompletedWorkName = opTypeToString (work->opType_ );
435+ pgStatus_->lastCompletedNumelIn = work->numelIn_ ;
436+ pgStatus_->lastCompletedNumelOut = work->numelOut_ ;
437+ // To avoid complexity, we're not computing duration.
438+ FlightRecorderXCCL::get ()->retire_id (work->trace_id_ , /* compute_duration*/ false );
439+ }
359440
360441void ProcessGroupXCCL::setSequenceNumberForGroup () {}
361442
@@ -384,6 +465,21 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
384465 profilingTitle,
385466 profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
386467 : std::nullopt );
468+
469+ r->trace_id_ = FlightRecorderXCCL::get ()->record (
470+ local_id_,
471+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
472+ seqCollective_,
473+ seqP2P_,
474+ op_id_,
475+ profilingTitle ? profilingTitle : " " ,
476+ inputs,
477+ outputs,
478+ nullptr ,
479+ r->xcclEndEvent_ .get (),
480+ options_->timeout ,
481+ pgStatus_,
482+ isP2P);
387483 return r;
388484}
389485
@@ -538,6 +634,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
538634 groupEnd ();
539635
540636 work->xcclEndEvent_ ->record (stream);
637+ setEnqueuedPgStatus (work);
541638
542639 coalescing_state_ = 0 ;
543640 coalescedComm_ = nullptr ;
@@ -572,6 +669,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
572669 if ((coalescing_state_ & CoalColl) == 0 ) {
573670 seqCollective_++;
574671 }
672+ op_id_++;
575673 coalescing_state_ |= CoalColl;
576674 if (coalescedDevice_.index () < 0 ) {
577675 coalescedDevice_ = device;
@@ -614,6 +712,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
614712 c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
615713 work =
616714 initWork (device, rank_, opType, false , profilingTitle, inputs, outputs);
715+ if (coalescing_state_) {
716+ FlightRecorderXCCL::get ()->record (
717+ local_id_,
718+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
719+ seqCollective_,
720+ seqP2P_,
721+ op_id_,
722+ profilingTitle ? profilingTitle : " " ,
723+ inputs,
724+ outputs,
725+ nullptr ,
726+ nullptr ,
727+ options_->timeout ,
728+ pgStatus_,
729+ false );
730+ }
617731
618732 work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
619733
@@ -653,8 +767,22 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
653767 work->future_ = c10::make_intrusive<at::ivalue::Future>(
654768 c10::ListType::create (c10::TensorType::get ()), devices);
655769 work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
770+ work->future_ ->addCallback (
771+ [this , work](at::ivalue::Future&) {
772+ this ->setCompletedPgStatus (work);
773+ });
656774 work->blockingWait_ = blockingWait_;
657775
776+ work->numelIn_ = 0 ;
777+ work->numelOut_ = 0 ;
778+ for (const auto & input : inputs) {
779+ work->numelIn_ += input.numel ();
780+ }
781+ for (const auto & output : outputs) {
782+ work->numelOut_ += output.numel ();
783+ }
784+ setEnqueuedPgStatus (work);
785+
658786 return asyncOp ? work : nullptr ;
659787}
660788
@@ -687,6 +815,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
687815 }
688816 }
689817
818+ op_id_++;
690819 auto comm = getXCCLComm (key, device, opType, p2pRank, isSendRecvSelf);
691820
692821 if (coalescing_state_ & CoalActive) {
@@ -722,6 +851,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
722851 work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
723852 work->outputs_ ->push_back (tensor);
724853
854+ work->trace_id_ = FlightRecorderXCCL::get ()->record (
855+ local_id_,
856+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
857+ seqCollective_,
858+ seqP2P_,
859+ op_id_,
860+ profilingTitle,
861+ {tensor},
862+ {tensor},
863+ nullptr ,
864+ work->xcclEndEvent_ .get (),
865+ options_->timeout ,
866+ pgStatus_,
867+ true );
868+
725869 c10::OptionalDeviceGuard gpuGuard (device);
726870
727871 c10::xpu::XPUCachingAllocator::recordStream (
@@ -737,8 +881,29 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
737881 work->future_ = c10::make_intrusive<at::ivalue::Future>(
738882 c10::ListType::create (c10::TensorType::get ()), devices);
739883 work->future_ ->markCompleted (at::IValue (*work->outputs_ ));
884+ work->future_ ->addCallback (
885+ [this , work](at::ivalue::Future&) {
886+ this ->setCompletedPgStatus (work);
887+ });
888+
889+ work->numelIn_ = work->numelOut_ = tensor.numel ();
890+ setEnqueuedPgStatus (work);
740891 return work;
741892 } else {
893+ FlightRecorderXCCL::get ()->record (
894+ local_id_,
895+ std::make_tuple (pg_uid_, pg_desc_), // PG name tuple
896+ seqCollective_,
897+ seqP2P_,
898+ op_id_,
899+ profilingTitle,
900+ {tensor},
901+ {tensor},
902+ nullptr ,
903+ nullptr ,
904+ options_->timeout ,
905+ pgStatus_,
906+ true );
742907 c10::OptionalDeviceGuard gpuGuard (device);
743908
744909 c10::xpu::XPUCachingAllocator::recordStream (
@@ -2135,6 +2300,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
21352300 " xccl:all_to_all" );
21362301}
21372302
2303+ std::string getXcclVersion () {
2304+ auto xccl_version = ccl::get_library_version ();
2305+ std::string versionString = std::to_string (xccl_version.major ) + " ." +
2306+ std::to_string (xccl_version.minor ) + " ." +
2307+ std::to_string (xccl_version.update );
2308+ return versionString;
2309+ }
2310+
21382311} // namespace c10d
21392312
21402313#endif // USE_C10D_XCCL
0 commit comments