Skip to content

Commit 7c0f315

Browse files
frost-intelmarkc-614
authored andcommitted
[fr] [xpu] Add FlightRecorder support for ProcessGroupXCCL (pytorch#158568)
Adds support for FlightRecorder in ProcessGroupXCCL. See intel/torch-xpu-ops#1867 for XCCL implementation and more details. Pull Request resolved: pytorch#158568 Approved by: https://github.com/guangyey, https://github.com/fduwjj
1 parent 1af6e04 commit 7c0f315

File tree

14 files changed

+80
-43
lines changed

14 files changed

+80
-43
lines changed

test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
386386
ASSERT_TRUE(
387387
setenv(c10d::TORCH_NCCL_ENABLE_MONITORING[0].c_str(), "1", 1) == 0);
388388
auto tempFilename = c10::str(
389-
std::filesystem::temp_directory_path().string(), "/nccl_trace_rank_");
389+
std::filesystem::temp_directory_path().string(), "/comm_lib_trace_rank_");
390390
ASSERT_TRUE(
391391
setenv("TORCH_NCCL_DEBUG_INFO_TEMP_FILE", tempFilename.c_str(), 1) == 0);
392392
// Enable nccl flight recorder.
@@ -401,7 +401,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
401401
// The only difference is that we are storing traces also in memory for
402402
// validation.
403403
std::string fileNamePrefix = c10d::getCvarString(
404-
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
404+
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/comm_lib_trace_rank_");
405405
std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
406406
std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
407407
std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();

test/distributed/test_c10d_gloo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2449,7 +2449,7 @@ def tearDown(self) -> None:
24492449

24502450
def _verify_trace(self, t, is_json):
24512451
ver = t["version"]
2452-
self.assertEqual(ver, "2.9")
2452+
self.assertEqual(ver, "2.10")
24532453
pg_config = t["pg_config"]
24542454
self.assertEqual(len(pg_config), 1)
24552455
default_pg_info = pg_config["0"]

test/distributed/test_c10d_nccl.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4361,10 +4361,12 @@ def started_or_scheduled(self, timing_enabled):
43614361
class NCCLTraceTest(NCCLTraceTestBase):
43624362
def _verify_trace(self, t, include_collectives, timing_enabled, is_json):
43634363
ver = t["version"]
4364-
self.assertEqual(ver, "2.9")
4365-
nccl_version = t["nccl_version"]
4366-
torch_nccl_version = torch.cuda.nccl.version()
4367-
self.assertEqual(nccl_version, ".".join(str(v) for v in torch_nccl_version))
4364+
self.assertEqual(ver, "2.10")
4365+
comm_lib_version = t["comm_lib_version"]
4366+
torch_comm_lib_version = torch.cuda.nccl.version()
4367+
self.assertEqual(
4368+
comm_lib_version, ".".join(str(v) for v in torch_comm_lib_version)
4369+
)
43684370
pg_config = t["pg_config"]
43694371
self.assertEqual(len(pg_config), 1)
43704372
default_pg_info = pg_config["0"]

third_party/xpu.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
77cc792cd265179745d335579d233e6d4f9a2667
1+
77cc792cd265179745d335579d233e6d4f9a2667

tools/flight_recorder/components/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,10 @@ def __init__(
388388
self, event: dict[Any, Any], memberships: dict[str, set[Any]], pg_name: str
389389
):
390390
self.profiling_name = event["profiling_name"]
391-
nccl, name = self.profiling_name.split(":")
392-
assert nccl == "nccl", f"name formatting error? {nccl} != 'nccl'"
391+
comm_lib_backend, name = self.profiling_name.split(":")
392+
assert comm_lib_backend in ["nccl", "xccl"], (
393+
f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'"
394+
)
393395
parts = name.split(" ")
394396
type = parts[0]
395397
meta = parts[1] if len(parts) == 2 else None

torch/_C/_distributed_c10d.pyi

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ class Backend:
298298
def _timeout(self) -> timedelta: ...
299299
@_timeout.setter
300300
def _timeout(self, val: timedelta) -> None: ...
301+
global_ranks_in_group: list[int]
302+
group_name: str
301303

302304
def __init__(
303305
self,
@@ -608,8 +610,6 @@ class ProcessGroupGloo(Backend):
608610
class Options(Backend.Options):
609611
devices: list[ProcessGroupGloo.Device]
610612
threads: int
611-
global_ranks_in_group: list[int]
612-
group_name: str
613613

614614
def __init__(self): ...
615615

@@ -651,8 +651,6 @@ class ProcessGroupNCCL(Backend):
651651
is_high_priority_stream: bool
652652
split_from: ProcessGroupNCCL
653653
split_color: int
654-
global_ranks_in_group: list[int]
655-
group_name: str
656654

657655
def __init__(self, is_high_priority_stream: bool = False): ...
658656

@@ -830,12 +828,18 @@ class _SymmetricMemory:
830828
def signal_pad_size(self) -> int: ...
831829

832830
class ProcessGroupXCCL(Backend):
831+
class Options(Backend.Options):
832+
def __init__(self): ...
833+
833834
def __init__(
834835
self,
835836
store: Store,
836837
rank: int,
837838
size: int,
838-
): ...
839+
options: Options,
840+
) -> None: ...
841+
@property
842+
def options(self) -> Options: ... # type: ignore[override]
839843

840844
def _set_process_group(pg: ProcessGroup) -> None: ...
841845
def _current_process_group() -> ProcessGroup: ...

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
4747
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
4848
const std::string backend;
4949
std::string group_name;
50+
std::vector<uint64_t> global_ranks_in_group;
5051
};
5152

5253
explicit Backend(int rank, int size);

torch/csrc/distributed/c10d/FlightRecorder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
3939
auto cacheDirPath = std::filesystem::path(homeDir + "/.cache/torch");
4040
// Create the .cache directory if it doesn't exist
4141
std::filesystem::create_directories(cacheDirPath);
42-
auto defaultLocation = cacheDirPath / "nccl_trace_rank_";
42+
auto defaultLocation = cacheDirPath / "comm_lib_trace_rank_";
4343

4444
// For internal bc compatibility, we keep the old the ENV check.
4545
std::string fileNamePrefix = getCvarString(

torch/csrc/distributed/c10d/FlightRecorder.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ namespace c10d {
2020
// (minor when adding fields, major when changing existing fields)
2121
// Also update both JSON and Pickle dumps to make use of the newly defined
2222
// field(s).
23-
DEFINE_CONSTANT(version_val, "2.9")
23+
DEFINE_CONSTANT(version_val, "2.10")
2424
DEFINE_CONSTANT(entries_key, "entries")
2525
DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state")
26-
DEFINE_CONSTANT(nccl_version_key, "nccl_version")
26+
DEFINE_CONSTANT(comm_lib_version_key, "comm_lib_version")
2727
DEFINE_CONSTANT(version_key, "version")
2828
DEFINE_CONSTANT(pg_config_key, "pg_config")
2929
DEFINE_CONSTANT(pg_status_key, "pg_status")
@@ -179,7 +179,7 @@ struct FlightRecorder {
179179
std::map<size_t, std::shared_ptr<ProcessGroupStatus>> all_pg_status_ = {};
180180
std::map<std::tuple<std::string, std::string>, std::vector<uint64_t>>
181181
pg_name_to_ranks_ = {};
182-
std::string nccl_version_;
182+
std::string comm_lib_version_;
183183

184184
std::optional<size_t> record(
185185
size_t pg_id,
@@ -200,7 +200,7 @@ struct FlightRecorder {
200200
const std::tuple<std::string, std::string>& pg_name,
201201
std::vector<uint64_t> ranks);
202202

203-
void record_accelerator_version(const std::string nccl_version);
203+
void record_accelerator_version(const std::string comm_lib_version);
204204

205205
void update_state(Entry& r);
206206

torch/csrc/distributed/c10d/FlightRecorderDetail.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ void FlightRecorder<EventType>::record_pg_ranks(
128128

129129
template <typename EventType>
130130
void FlightRecorder<EventType>::record_accelerator_version(
131-
const std::string nccl_version) {
131+
const std::string comm_lib_version) {
132132
if (!enabled_) {
133133
return;
134134
}
135135
std::lock_guard<std::mutex> guard(mutex_);
136-
nccl_version_ = std::move(nccl_version);
136+
comm_lib_version_ = std::move(comm_lib_version);
137137
}
138138

139139
template <typename EventType>
@@ -425,7 +425,7 @@ std::string FlightRecorder<EventType>::dump_json(
425425
bool onlyActive) {
426426
json result;
427427
result[version_key_str] = version_val_str;
428-
result[nccl_version_key_str] = nccl_version_;
428+
result[comm_lib_version_key_str] = comm_lib_version_;
429429
result[pg_config_key_str] = getPgConfigJson();
430430
result[pg_status_key_str] = getPgStatusJson();
431431

@@ -522,7 +522,7 @@ std::string FlightRecorder<EventType>::dump(
522522
// common values
523523
result.insert(version_key, version_val);
524524
result.insert(pg_config_key, getPgConfig());
525-
result.insert(nccl_version_key_str, nccl_version_);
525+
result.insert(comm_lib_version_key_str, comm_lib_version_);
526526
result.insert(pg_status_key, getPgStatus());
527527

528528
// collective trace

0 commit comments

Comments
 (0)