Skip to content

Commit cae6ba3

Browse files
Add FlightRecorder tests (#1971)
As a follow-up to #1867, this PR adds tests for the FlightRecorder on XCCL and moves some definitions from ProcessGroupXCCL::Options to Backend::Options.

These tests are largely based on `pytorch/test/distributed/test_c10d_nccl.py`, but do not include the following tests:
- `test_short_json`: JSON dumps are not supported in ProcessGroupXCCL
- `test_trace_while_all_works_retired`: `_wait_for_pending_works` isn't supported by XCCL
- `test_trace_while_active`: XCCL hangs when an op is called on only one rank
- `test_trace_while_stuck`: XCCL hangs when an op is called on only one rank

---------

Co-authored-by: Yu, Guangye <[email protected]>
1 parent 086f20a commit cae6ba3

File tree

3 files changed

+598
-21
lines changed

3 files changed

+598
-21
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
459459
bool isP2P,
460460
const char* profilingTitle,
461461
const std::vector<at::Tensor>& inputs,
462-
const std::vector<at::Tensor>& outputs) {
462+
const std::vector<at::Tensor>& outputs,
463+
bool record) {
463464
auto r = c10::make_intrusive<ProcessGroupXCCL::WorkXCCL>(
464465
device,
465466
rank,
@@ -470,20 +471,22 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
470471
profilingTitle != nullptr ? std::optional<std::vector<at::Tensor>>(inputs)
471472
: std::nullopt);
472473

473-
r->trace_id_ = FlightRecorderXCCL::get()->record(
474-
local_id_,
475-
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
476-
seqCollective_,
477-
seqP2P_,
478-
op_id_,
479-
profilingTitle ? profilingTitle : "",
480-
inputs,
481-
outputs,
482-
nullptr,
483-
r->xcclEndEvent_.get(),
484-
options_->timeout,
485-
pgStatus_,
486-
isP2P);
474+
if (record) {
475+
r->trace_id_ = FlightRecorderXCCL::get()->record(
476+
local_id_,
477+
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
478+
seqCollective_,
479+
seqP2P_,
480+
op_id_,
481+
profilingTitle ? profilingTitle : "",
482+
inputs,
483+
outputs,
484+
nullptr,
485+
r->xcclEndEvent_.get(),
486+
options_->timeout,
487+
pgStatus_,
488+
isP2P);
489+
}
487490
return r;
488491
}
489492

@@ -664,16 +667,19 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
664667
const char* profilingTitle,
665668
bool nanCheck) {
666669
nanCheck &= enableNanCheck_;
667-
seqCollective_++;
668670
auto device = inputs[0].device();
669671
const auto key = std::to_string(device.index());
670672
auto comm = getXCCLComm(key, device, opType);
671673

674+
if (!coalescing_state_) {
675+
seqCollective_++;
676+
}
677+
op_id_++;
678+
672679
if (coalescing_state_ & CoalActive) {
673680
if ((coalescing_state_ & CoalColl) == 0) {
674681
seqCollective_++;
675682
}
676-
op_id_++;
677683
coalescing_state_ |= CoalColl;
678684
if (coalescedDevice_.index() < 0) {
679685
coalescedDevice_ = device;
@@ -714,8 +720,15 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
714720
}
715721

716722
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
717-
work =
718-
initWork(device, rank_, opType, false, profilingTitle, inputs, outputs);
723+
work = initWork(
724+
device,
725+
rank_,
726+
opType,
727+
false,
728+
profilingTitle,
729+
inputs,
730+
outputs,
731+
!coalescing_state_);
719732
if (coalescing_state_) {
720733
FlightRecorderXCCL::get()->record(
721734
local_id_,

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
180180
bool isP2P,
181181
const char* profilingTitle = nullptr,
182182
const std::vector<at::Tensor>& inputs = {},
183-
const std::vector<at::Tensor>& outputs = {});
183+
const std::vector<at::Tensor>& outputs = {},
184+
bool record = false);
184185

185186
template <typename Fn>
186187
c10::intrusive_ptr<Work> collective(

0 commit comments

Comments
 (0)