Skip to content

Commit d70bf85

Browse files
committed
update flight recorder usage
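Summary: We're making changes to Flight Recorder's record and retire_id APIs (record now returns both a trace id and a reset epoch, and retire_id takes the reset epoch as an additional argument). Update the usage accordingly.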
1 parent ab48005 commit d70bf85

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/xccl/ProcessGroupXCCL.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
483483
: std::nullopt);
484484

485485
if (record) {
486-
r->trace_id_ = FlightRecorderXCCL::get()->record(
486+
auto traceId = FlightRecorderXCCL::get()->record(
487487
local_id_,
488488
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
489489
seqCollective_,
@@ -497,6 +497,9 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
497497
options_->timeout,
498498
pgStatus_,
499499
isP2P);
500+
501+
r->trace_id_ = traceId.id;
502+
r->trace_reset_epoch_ = traceId.reset_epoch;
500503
}
501504
return r;
502505
}
@@ -803,9 +806,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
803806
c10::ListType::create(c10::TensorType::get()), devices);
804807
work->future_->markCompleted(at::IValue(*work->outputs_));
805808
auto id = work->trace_id_;
809+
auto reset_epoch = work->trace_reset_epoch_;
806810
work->future_->addCallback(
807-
[id](at::ivalue::Future&) {
808-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
811+
[id, reset_epoch](at::ivalue::Future&) {
812+
FlightRecorderXCCL::get()->retire_id(
813+
id, reset_epoch, /*compute_duration*/ false);
809814
},
810815
/*use_future*/ false);
811816
work->blockingWait_ = blockingWait_;
@@ -891,7 +896,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
891896
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
892897
work->outputs_->push_back(tensor);
893898

894-
work->trace_id_ = FlightRecorderXCCL::get()->record(
899+
auto traceId = FlightRecorderXCCL::get()->record(
895900
local_id_,
896901
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
897902
seqCollective_,
@@ -905,6 +910,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
905910
options_->timeout,
906911
pgStatus_,
907912
true);
913+
work->trace_id_ = traceId.id;
914+
work->trace_reset_epoch_ = traceId.reset_epoch;
908915

909916
c10::OptionalDeviceGuard gpuGuard(device);
910917

@@ -922,9 +929,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
922929
c10::ListType::create(c10::TensorType::get()), devices);
923930
work->future_->markCompleted(at::IValue(*work->outputs_));
924931
auto id = work->trace_id_;
932+
auto reset_epoch = work->trace_reset_epoch_;
925933
work->future_->addCallback(
926-
[id](at::ivalue::Future&) {
927-
FlightRecorderXCCL::get()->retire_id(id, /*compute_duration*/ false);
934+
[id, reset_epoch](at::ivalue::Future&) {
935+
FlightRecorderXCCL::get()->retire_id(
936+
id, reset_epoch, /*compute_duration*/ false);
928937
},
929938
/*use_future*/ false);
930939

@@ -2059,8 +2068,8 @@ c10::DeviceIndex ProcessGroupXCCL::guessDeviceId() const {
20592068
} else if (!usedDeviceIdxs_.empty()) {
20602069
return *usedDeviceIdxs_.begin();
20612070
}
2062-
int devIdx =
2063-
static_cast<int16_t>(globalRank() % at::detail::getXPUHooks().getNumGPUs());
2071+
int devIdx = static_cast<int16_t>(
2072+
globalRank() % at::detail::getXPUHooks().getNumGPUs());
20642073
LOG(WARNING)
20652074
<< logPrefix()
20662075
<< c10::str(

src/xccl/ProcessGroupXCCL.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
110110
uint64_t seq_;
111111
bool isP2P_;
112112
std::optional<uint64_t> trace_id_;
113+
std::optional<uint64_t> trace_reset_epoch_;
113114
size_t numelIn_ = -1;
114115
size_t numelOut_ = -1;
115116

0 commit comments

Comments
 (0)