Skip to content

Commit 8b558fe

Browse files
committed
update flight recorder usage
Summary: we're making changes to flight recorder's record and reset_id api's. update the usage accordingly.
1 parent ab48005 commit 8b558fe

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
483483
: std::nullopt);
484484

485485
if (record) {
486-
r->trace_id_ = FlightRecorderXCCL::get()->record(
486+
auto traceId = FlightRecorderXCCL::get()->record(
487487
local_id_,
488488
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
489489
seqCollective_,
@@ -497,6 +497,9 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
497497
options_->timeout,
498498
pgStatus_,
499499
isP2P);
500+
501+
r->trace_id_ = traceId.id;
502+
r->trace_reset_epoch_ = traceId.reset_epoch;
500503
}
501504
return r;
502505
}
@@ -803,9 +806,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
803806
c10::ListType::create(c10::TensorType::get()), devices);
804807
work->future_->markCompleted(at::IValue(*work->outputs_));
805808
auto id = work->trace_id_;
809+
auto reset_epoch = work->trace_reset_epoch_;
806810
work->future_->addCallback(
807811
[id](at::ivalue::Future&) {
808-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
812+
FlightRecorderXCCL::get()->retire_id(id, reset_epoch, /*compute_duration*/ false);
809813
},
810814
/*use_future*/ false);
811815
work->blockingWait_ = blockingWait_;
@@ -891,7 +895,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
891895
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
892896
work->outputs_->push_back(tensor);
893897

894-
work->trace_id_ = FlightRecorderXCCL::get()->record(
898+
auto traceId = FlightRecorderXCCL::get()->record(
895899
local_id_,
896900
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
897901
seqCollective_,
@@ -905,6 +909,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
905909
options_->timeout,
906910
pgStatus_,
907911
true);
912+
work->trace_id_ = traceId.id;
913+
work->trace_reset_epoch_ = traceId.reset_epoch;
908914

909915
c10::OptionalDeviceGuard gpuGuard(device);
910916

@@ -922,9 +928,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
922928
c10::ListType::create(c10::TensorType::get()), devices);
923929
work->future_->markCompleted(at::IValue(*work->outputs_));
924930
auto id = work->trace_id_;
931+
auto reset_epoch = work->trace_reset_epoch_;
925932
work->future_->addCallback(
926933
[id](at::ivalue::Future&) {
927-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
934+
FlightRecorderXCCL::get()->retire_id(id, reset_epoch, /*compute_duration*/ false);
928935
},
929936
/*use_future*/ false);
930937

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
110110
uint64_t seq_;
111111
bool isP2P_;
112112
std::optional<uint64_t> trace_id_;
113+
std::optional<uint64_t> trace_reset_epoch_;
113114
size_t numelIn_ = -1;
114115
size_t numelOut_ = -1;
115116

0 commit comments

Comments
 (0)