143 changes: 117 additions & 26 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -1,7 +1,7 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/FlightRecorderDetail.hpp>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <xccl/NanCheck_XPU.hpp>
#include <xccl/ProcessGroupXCCL.hpp>

@@ -17,6 +17,12 @@ namespace {
#define XCCL_HAS_AVG 1
#endif // oneCCL version >= 2021.15

#if defined(CCL_MAJOR_VERSION) && \
((CCL_MAJOR_VERSION > 2021) || \
(CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 17))
#define ENABLE_XCCL_PREMUL_SUM_SUPPORT
#endif // oneCCL version >= 2021.17

const std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
{ReduceOp::MIN, ccl::reduction::min},
{ReduceOp::MAX, ccl::reduction::max},
@@ -44,6 +50,33 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
};

struct xcclRedOpRAII {
xcclRedOpRAII() = default;
xcclRedOpRAII(ccl::reduction op) : op_(op) {}
xcclRedOpRAII(ccl::reduction op, const xcclComm_t* comm)
: op_(op), comm_(comm), premul_sum_(true) {}
xcclRedOpRAII(const xcclRedOpRAII&) = delete;
xcclRedOpRAII& operator=(const xcclRedOpRAII&) = delete;
xcclRedOpRAII(xcclRedOpRAII&& tmp) noexcept : xcclRedOpRAII() {
std::swap(tmp.op_, this->op_);
std::swap(tmp.comm_, this->comm_);
std::swap(tmp.premul_sum_, this->premul_sum_);
}
#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
~xcclRedOpRAII() {
if (premul_sum_) {
ccl::reduction_destroy(op_, *comm_);
}
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
Copilot AI commented on Aug 21, 2025:

The destructor is only defined when ENABLE_XCCL_PREMUL_SUM_SUPPORT is defined, but the class is used regardless of this macro. This will cause linking errors when the macro is not defined. The destructor should be defined unconditionally with appropriate conditional logic inside.

Suggested change:
-#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+  ~xcclRedOpRAII() {
+#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
+    if (premul_sum_) {
+      ccl::reduction_destroy(op_, *comm_);
+    }
+#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+  }

operator ccl::reduction() const {
return op_;
}
ccl::reduction op_{};
const xcclComm_t* comm_ = nullptr;
bool premul_sum_ = false;
};
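
For context, a minimal standalone sketch of the move-only RAII idiom used by xcclRedOpRAII. The int "op" and printf cleanup are stand-ins for ccl::reduction and ccl::reduction_destroy (assumptions for illustration, not the oneCCL API); the point is that the move constructor swaps with a default-constructed, non-owning state, so cleanup runs exactly once, in the last owner's destructor.

#include <cstdio>
#include <utility>

struct RedOpRAII {
  RedOpRAII() = default;
  explicit RedOpRAII(int op) : op_(op), owns_(true) {}
  RedOpRAII(const RedOpRAII&) = delete;              // no copies: single owner
  RedOpRAII& operator=(const RedOpRAII&) = delete;
  RedOpRAII(RedOpRAII&& tmp) noexcept : RedOpRAII() {
    std::swap(tmp.op_, op_);                         // take over the resource
    std::swap(tmp.owns_, owns_);                     // source becomes non-owning
  }
  ~RedOpRAII() {
    if (owns_) std::printf("destroy op %d\n", op_);  // stand-in for reduction_destroy
  }
  operator int() const { return op_; }               // mirrors operator ccl::reduction()
  int op_{};
  bool owns_ = false;
};

int main() {
  RedOpRAII a(42);
  RedOpRAII b(std::move(a));                         // 'a' is left non-owning
  std::printf("use op %d\n", static_cast<int>(b));
}                                                    // only 'b' runs the cleanup
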

bool computeLengthsAndCheckAndGetFlat(
const std::vector<at::Tensor>& tensors,
std::vector<size_t>& lengths,
@@ -152,7 +185,37 @@ ccl::datatype getXcclDataType(
return it->second;
}

ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
template <typename T, ccl::datatype dataType>
xcclRedOpRAII unpackPreMulSum(
const ReduceOp& reduceOp,
const xcclComm_t& comm) {
const auto* preMulSupplement =
reinterpret_cast<XCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
ccl::reduction preMulSum{};
bool has_tensor = preMulSupplement->tensor_factor.defined();
auto residence = has_tensor
? ccl::scalar_residence_type::scalar_device
: ccl::scalar_residence_type::scalar_host_immediate;
const T* ptr_factor = has_tensor
? preMulSupplement->tensor_factor.const_data_ptr<T>()
: nullptr;
T scalar_factor = T(preMulSupplement->double_factor);
ccl::reduction_create_pre_mul_sum(
&preMulSum,
/*scalar=*/has_tensor ? const_cast<T*>(ptr_factor) : &scalar_factor,
dataType,
residence,
comm);
return xcclRedOpRAII(preMulSum, &comm);
}
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
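
What PREMUL_SUM computes, for reference: each rank's contribution is multiplied by the factor before the sum, i.e. sum over ranks of factor * x_rank. A host-side sketch of these semantics (plain C++, no oneCCL; the function name is illustrative):

#include <cstdio>
#include <vector>

// Reference semantics of a premul-sum reduction across N ranks.
double premul_sum(const std::vector<double>& rank_inputs, double factor) {
  double acc = 0.0;
  for (double x : rank_inputs) {
    acc += factor * x;  // scale each contribution, then accumulate
  }
  return acc;
}

int main() {
  // 4 ranks contributing rank+1 with factor 3.0 -> 3 * (1+2+3+4) = 30
  std::printf("%g\n", premul_sum({1.0, 2.0, 3.0, 4.0}, 3.0));
}

This mirrors the expectation used in the tests below, where the reference result is computed by pre-scaling the inputs and reducing with SUM.
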

xcclRedOpRAII getXcclReduceOp(
const ReduceOp& reduceOp,
at::Tensor& input,
const ccl::datatype& dataType,
xcclComm_t& comm) {
try {
if (input.scalar_type() == at::kBool) {
if (reduceOp == ReduceOp::SUM) {
Expand All @@ -171,6 +234,30 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
return ccl::reduction::sum;
}
#endif
if (reduceOp == ReduceOp::PREMUL_SUM) {
#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
switch (dataType) {
case ccl::datatype::float16:
return unpackPreMulSum<at::Half, ccl::datatype::float16>(
reduceOp, comm);
case ccl::datatype::float32:
return unpackPreMulSum<float, ccl::datatype::float32>(reduceOp, comm);
case ccl::datatype::bfloat16:
return unpackPreMulSum<float, ccl::datatype::bfloat16>(
reduceOp, comm);

Copilot AI commented on Aug 21, 2025:

For bfloat16 data type, the template should use at::BFloat16 instead of float. Using float for bfloat16 data will cause type mismatch issues when accessing the tensor data.

Suggested change:
-        return unpackPreMulSum<float, ccl::datatype::bfloat16>(
+        return unpackPreMulSum<at::BFloat16, ccl::datatype::bfloat16>(

case ccl::datatype::float64:
return unpackPreMulSum<double, ccl::datatype::float64>(
reduceOp, comm);
default:
C10_THROW_ERROR(
TypeError,
"PreMulSum Data type must be half, float, bfloat16 or double");
return ccl::reduction{};
}
#else
C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL>=2021.17");
#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
C10_THROW_ERROR(
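
On the bfloat16 review comment in the hunk above: bfloat16 elements are 2 bytes wide, so a tensor's buffer must be traversed with a 2-byte element type; reading it through a float* misinterprets the storage. A standalone sketch with a minimal bfloat16 decoder (assumption: at::BFloat16 stores its bits the same way, as the top half of a float32):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 float32.
float bf16_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  uint16_t buf[2] = {0x4040, 0x3F80};  // bfloat16 encodings of 3.0 and 1.0
  for (uint16_t v : buf) {
    std::printf("%g\n", bf16_to_float(v));  // prints 3 then 1
  }
  // Reading these same 4 bytes through a float* would yield one garbage
  // float instead of two values -- the mismatch the comment points at.
}
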
@@ -207,11 +294,11 @@ std::string dump_xccl_trace(
bool includeCollectives,
bool includeStackTraces,
bool onlyActive) {
auto xcclDumpMap = std::unordered_map<
std::string,
std::unordered_map<std::string, std::string>>();
return FlightRecorderXCCL::get()->dump(
xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
auto xcclDumpMap = std::unordered_map<
std::string,
std::unordered_map<std::string, std::string>>();
return FlightRecorderXCCL::get()->dump(
xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
}

constexpr int64_t kSynchronizeBusyWaitMillis = 10;
@@ -317,9 +404,7 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

ProcessGroupXCCL::Options::Options()
: Backend::Options(XCCL_BACKEND_NAME) {}

ProcessGroupXCCL::Options::Options() : Backend::Options(XCCL_BACKEND_NAME) {}

static std::atomic<size_t> process_group_id = 0;

@@ -362,7 +447,8 @@ ProcessGroupXCCL::ProcessGroupXCCL(
traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);

this->setGroupUid(options_->group_name);
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
// In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just
// do it here.
const auto XcclVersion = getXcclVersion();
FlightRecorderXCCL::get()->record_pg_ranks(
std::make_tuple(pg_uid_, pg_desc_), groupRanks());
@@ -435,7 +521,8 @@ void ProcessGroupXCCL::setCompletedPgStatus(
pgStatus_->lastCompletedNumelIn = work->numelIn_;
pgStatus_->lastCompletedNumelOut = work->numelOut_;
// To avoid complexity, we're not computing duration.
FlightRecorderXCCL::get()->retire_id(work->trace_id_, /*compute_duration*/false);
FlightRecorderXCCL::get()->retire_id(
work->trace_id_, /*compute_duration*/ false);
}

void ProcessGroupXCCL::setSequenceNumberForGroup() {}
@@ -768,9 +855,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});
[this, work](at::ivalue::Future&) { this->setCompletedPgStatus(work); });
work->blockingWait_ = blockingWait_;

work->numelIn_ = 0;
@@ -881,10 +966,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
work->future_->markCompleted(at::IValue(*work->outputs_));
work->future_->addCallback(
[this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});
work->future_->addCallback([this, work](at::ivalue::Future&) {
this->setCompletedPgStatus(work);
});

work->numelIn_ = work->numelOut_ = tensor.numel();
setEnqueuedPgStatus(work);
@@ -1269,7 +1353,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
@@ -1366,7 +1451,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
@@ -1528,7 +1614,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
const auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
@@ -1572,7 +1659,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
ccl::stream& xcclStream) {
const int root = opts.rootRank + opts.rootTensor;
const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
const auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce(
input.data_ptr(),
output.data_ptr(),
@@ -1832,7 +1920,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
@@ -1927,7 +2016,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
@@ -1986,7 +2076,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
at::xpu::XPUStream& stream,
ccl::stream& xcclStream) {
auto xcclDataType = getXcclDataType(input.scalar_type(), true);
auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
auto xcclReduceOp =
getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
ccl::reduce_scatter(
input.data_ptr(),
output.data_ptr(),
29 changes: 24 additions & 5 deletions src/xccl/ProcessGroupXCCL.hpp
@@ -129,10 +129,10 @@ class TORCH_API ProcessGroupXCCL : public Backend {
};

ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options = Options::create());
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options = Options::create());

C10_DEPRECATED ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
@@ -415,7 +415,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {

const std::vector<uint64_t>& groupRanks() const;
void setEnqueuedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
void setCompletedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
void setCompletedPgStatus(
c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
bool dumpDebuggingInfo(bool includeStackTrace = true);

protected:
@@ -493,6 +494,24 @@ TORCH_API std::string dump_xccl_trace(
bool onlyActive);

TORCH_API std::string getXcclVersion();

struct XCCLPreMulSumSupplement : _SupplementBase {
double double_factor{0.0};
at::Tensor tensor_factor;
XCCLPreMulSumSupplement(double f) : double_factor{f} {}
XCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
TORCH_CHECK_EQ(tensor_factor.numel(), 1);
}
};

template <typename T>
ReduceOp makeXCCLPreMulSum(const T& factor) {
ReduceOp rop;
rop.op_ = ReduceOp::PREMUL_SUM;
rop.supplement_ = c10::make_intrusive<XCCLPreMulSumSupplement>(factor);
return rop;
}
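
A hedged caller-side sketch of how this helper is intended to be used (it assumes an already-initialized ProcessGroupXCCL and an XPU tensor, so it is not runnable without a distributed setup; the function name is illustrative):

#include <xccl/ProcessGroupXCCL.hpp>

void allreduce_premul(c10d::ProcessGroupXCCL& pg, at::Tensor t) {
  c10d::AllreduceOptions opts;
  opts.reduceOp = c10d::makeXCCLPreMulSum(3.0);  // scale every input by 3.0 before summing
  std::vector<at::Tensor> tensors{t};
  pg.allreduce(tensors, opts)->wait();
}
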

} // namespace c10d

namespace {
41 changes: 39 additions & 2 deletions test/xpu/distributed/test_c10d_ops_xccl.py
@@ -23,7 +23,7 @@

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from test_c10d_xccl import init_multigpu_helper, requires_xccl
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_distributed import MultiProcContinuousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
@@ -42,7 +42,7 @@
TEST_MULTIGPU = TEST_XPU and torch.xpu.device_count() >= 2


class ProcessGroupXCCLOpTest(MultiProcContinousTest):
class ProcessGroupXCCLOpTest(MultiProcContinuousTest):
@classmethod
def backend_str(cls) -> str:
return "xccl"
@@ -260,6 +260,28 @@ def reduce(xs, rootRank, rootTensor, op=None):
):
reduce(tensors, self.rank, rt, op)

for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
float_tensors = [
torch.tensor(
[self.rank + 1.0], device=f"xpu:{local_device_id}"
)
]
float_tensors_ref = [
torch.tensor(
[(self.rank + 1.0) * factor_ref],
device=f"xpu:{local_device_id}",
)
]

reduce(float_tensors_ref, rt, 0)
reduce(float_tensors, rt, 0, c10d._make_xccl_premul_sum(factor))
if self.rank == rt:
self.assertEqual(float_tensors_ref[0], float_tensors[0])

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_allgather_ops(self):
@@ -713,6 +735,21 @@ def perm(n, k):
expected = torch.tensor(prod_val)
self.assertEqual(expected, output_tensor)

for factor in (3.0, torch.tensor([5.0], device=self.rank)):
if isinstance(factor, torch.Tensor):
factor_ref = factor.cpu().item()
else:
factor_ref = factor
output = [t.float() for t in output]
tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
output_ref = [t.float() for t in output]
tensor_lists_ref = [
[t.float() * factor_ref for t in tl] for tl in tensor_lists
]
reduce_scatter(output, tensor_lists, c10d._make_xccl_premul_sum(factor))
reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
self.assertEqual(output_ref, output)

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
def test_reduce_scatter_base_ops(self):