Skip to content

Commit 2cb6f70

Browse files
committed
update flight recorder usage
Summary: we're making changes to flight recorder's record and reset_id api's. update the usage accordingly.
1 parent 4090fa6 commit 2cb6f70

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
508508
xpuEventCacheEnabled_.load());
509509

510510
if (record) {
511-
r->trace_id_ = FlightRecorderXCCL::get()->record(
511+
auto traceId = FlightRecorderXCCL::get()->record(
512512
local_id_,
513513
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
514514
seqCollective_,
@@ -522,6 +522,9 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
522522
options_->timeout,
523523
pgStatus_,
524524
isP2P);
525+
526+
r->trace_id_ = traceId.id;
527+
r->trace_reset_epoch_ = traceId.reset_epoch;
525528
}
526529
return r;
527530
}
@@ -847,9 +850,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
847850
c10::ListType::create(c10::TensorType::get()), devices);
848851
work->future_->markCompleted(at::IValue(*work->outputs_));
849852
auto id = work->trace_id_;
853+
auto reset_epoch = work->trace_reset_epoch_;
850854
work->future_->addCallback(
851-
[id](at::ivalue::Future&) {
852-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
855+
[id, reset_epoch](at::ivalue::Future&) {
856+
FlightRecorderXCCL::get()->retire_id(
857+
id, reset_epoch, /*compute_duration*/ false);
853858
},
854859
/*use_future*/ false);
855860
work->blockingWait_ = blockingWait_;
@@ -948,7 +953,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
948953
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
949954
work->outputs_->push_back(tensor);
950955

951-
work->trace_id_ = FlightRecorderXCCL::get()->record(
956+
auto traceId = FlightRecorderXCCL::get()->record(
952957
local_id_,
953958
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
954959
seqCollective_,
@@ -962,6 +967,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
962967
options_->timeout,
963968
pgStatus_,
964969
true);
970+
971+
work->trace_id_ = traceId.id;
972+
work->trace_reset_epoch_ = traceId.reset_epoch;
965973
}
966974

967975
if (enableNanCheck_ && opType == OpType::SEND) {
@@ -999,9 +1007,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
9991007
}
10001008

10011009
auto id = work->trace_id_;
1010+
auto reset_epoch = work->trace_reset_epoch_;
10021011
work->future_->addCallback(
1003-
[id](at::ivalue::Future&) {
1004-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
1012+
[id, reset_epoch](at::ivalue::Future&) {
1013+
FlightRecorderXCCL::get()->retire_id(
1014+
id, reset_epoch, /*compute_duration*/ false);
10051015
},
10061016
/*use_future*/ false);
10071017
setEnqueuedPgStatus(work);

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
121121
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
122122
uint64_t seq_;
123123
bool isP2P_;
124-
std::optional<uint64_t> trace_id_;
124+
std::optional<size_t> trace_id_;
125+
std::optional<size_t> trace_reset_epoch_;
125126
size_t numelIn_ = -1;
126127
size_t numelOut_ = -1;
127128

0 commit comments

Comments
 (0)