From 0d38985cce97af4a9e4afc5afba99495b8a60d58 Mon Sep 17 00:00:00 2001
From: Han Chao
Date: Thu, 21 Aug 2025 15:53:59 +0800
Subject: [PATCH] Support premul_sum

---
 src/xccl/ProcessGroupXCCL.cpp              | 143 +++++++++++++++++----
 src/xccl/ProcessGroupXCCL.hpp              |  29 ++++-
 test/xpu/distributed/test_c10d_ops_xccl.py |  41 +++++-
 3 files changed, 180 insertions(+), 33 deletions(-)

diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 240a54cea..bb502cb9d 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -1,7 +1,7 @@
 #ifdef USE_C10D_XCCL
-#include 
+#include 
 #include 
 #include 
@@ -17,6 +17,12 @@ namespace {
 #define XCCL_HAS_AVG 1
 #endif // oneCCL version >= 2021.15
 
+#if defined(CCL_MAJOR_VERSION) && \
+    ((CCL_MAJOR_VERSION > 2021) || \
+     ((CCL_MAJOR_VERSION == 2021) && (CCL_MINOR_VERSION >= 17)))
+#define ENABLE_XCCL_PREMUL_SUM_SUPPORT
+#endif // oneCCL version >= 2021.17
+
 const std::map<ReduceOp, ccl::reduction> xcclOps = {
     {ReduceOp::MIN, ccl::reduction::min},
     {ReduceOp::MAX, ccl::reduction::max},
@@ -44,6 +50,33 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
 };
 
+struct xcclRedOpRAII {
+  xcclRedOpRAII() = default;
+  xcclRedOpRAII(ccl::reduction op) : op_(op) {}
+  xcclRedOpRAII(ccl::reduction op, const xcclComm_t* comm)
+      : op_(op), comm_(comm), premul_sum_(true) {}
+  xcclRedOpRAII(const xcclRedOpRAII&) = delete;
+  xcclRedOpRAII& operator=(const xcclRedOpRAII&) = delete;
+  xcclRedOpRAII(xcclRedOpRAII&& tmp) noexcept : xcclRedOpRAII() {
+    std::swap(tmp.op_, this->op_);
+    std::swap(tmp.comm_, this->comm_);
+    std::swap(tmp.premul_sum_, this->premul_sum_);
+  }
+#if defined(ENABLE_XCCL_PREMUL_SUM_SUPPORT)
+  ~xcclRedOpRAII() {
+    if (premul_sum_) {
+      ccl::reduction_destroy(op_, *comm_);
+    }
+  }
+#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+  operator ccl::reduction() const {
+    return op_;
+  }
+  ccl::reduction op_{};
+  const xcclComm_t* comm_ = nullptr;
+  bool premul_sum_ = false;
+};
+
 bool computeLengthsAndCheckAndGetFlat(
     const std::vector<at::Tensor>& tensors,
     std::vector<size_t>& lengths,
@@ -152,7 +185,37 @@ ccl::datatype getXcclDataType(
   return it->second;
 }
 
-ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
+#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
+template <typename T, ccl::datatype dataType>
+xcclRedOpRAII unpackPreMulSum(
+    const ReduceOp& reduceOp,
+    const xcclComm_t& comm) {
+  const auto* preMulSupplement =
+      reinterpret_cast<XCCLPreMulSumSupplement*>(reduceOp.supplement_.get());
+  ccl::reduction preMulSum{};
+  bool has_tensor = preMulSupplement->tensor_factor.defined();
+  auto residence = has_tensor
+      ? ccl::scalar_residence_type::scalar_device
+      : ccl::scalar_residence_type::scalar_host_immediate;
+  const T* ptr_factor = has_tensor
+      ? preMulSupplement->tensor_factor.const_data_ptr<T>()
+      : nullptr;
+  T scalar_factor = T(preMulSupplement->double_factor);
+  ccl::reduction_create_pre_mul_sum(
+      &preMulSum,
+      /*scalar=*/has_tensor ? const_cast<T*>(ptr_factor) : &scalar_factor,
+      dataType,
+      residence,
+      comm);
+  return xcclRedOpRAII(preMulSum, &comm);
+}
+#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+
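+// Note on semantics: PREMUL_SUM reduces as out = sum_over_ranks(factor * in),
+// i.e. the factor scales every rank's contribution before the sum. A host
+// scalar is forwarded by value (scalar_host_immediate) while a one-element
+// device tensor is forwarded by pointer (scalar_device); the returned
+// xcclRedOpRAII releases the handle via ccl::reduction_destroy().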
+xcclRedOpRAII getXcclReduceOp(
+    const ReduceOp& reduceOp,
+    at::Tensor& input,
+    const ccl::datatype& dataType,
+    xcclComm_t& comm) {
   try {
     if (input.scalar_type() == at::kBool) {
       if (reduceOp == ReduceOp::SUM) {
@@ -171,6 +234,30 @@
       return ccl::reduction::sum;
     }
 #endif
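+    // PREMUL_SUM cannot come from the static xcclOps table: the scaling
+    // factor lives in the ReduceOp's supplement, so a per-call oneCCL
+    // reduction handle is created and owned by the returned xcclRedOpRAII.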
+    if (reduceOp == ReduceOp::PREMUL_SUM) {
+#ifdef ENABLE_XCCL_PREMUL_SUM_SUPPORT
+      switch (dataType) {
+        case ccl::datatype::float16:
+          return unpackPreMulSum<at::Half, ccl::datatype::float16>(
+              reduceOp, comm);
+        case ccl::datatype::float32:
+          return unpackPreMulSum<float, ccl::datatype::float32>(reduceOp, comm);
+        case ccl::datatype::bfloat16:
+          return unpackPreMulSum<at::BFloat16, ccl::datatype::bfloat16>(
+              reduceOp, comm);
+        case ccl::datatype::float64:
+          return unpackPreMulSum<double, ccl::datatype::float64>(
+              reduceOp, comm);
+        default:
+          C10_THROW_ERROR(
+              TypeError,
+              "PreMulSum data type must be half, float, bfloat16, or double");
+          return ccl::reduction{};
+      }
+#else
+      C10_THROW_ERROR(ValueError, "PreMulSum requires oneCCL >= 2021.17");
+#endif // ENABLE_XCCL_PREMUL_SUM_SUPPORT
+    }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
     C10_THROW_ERROR(
@@ -207,11 +294,11 @@ std::string dump_xccl_trace(
     bool includeCollectives,
     bool includeStackTraces,
     bool onlyActive) {
-    auto xcclDumpMap = std::unordered_map<
-        std::string,
-        std::unordered_map<std::string, std::string>>();
-    return FlightRecorderXCCL::get()->dump(
-        xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
+  auto xcclDumpMap = std::unordered_map<
+      std::string,
+      std::unordered_map<std::string, std::string>>();
+  return FlightRecorderXCCL::get()->dump(
+      xcclDumpMap, includeCollectives, includeStackTraces, onlyActive);
 }
 
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
@@ -317,9 +404,7 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
-ProcessGroupXCCL::Options::Options()
-    : Backend::Options(XCCL_BACKEND_NAME) {}
-
+ProcessGroupXCCL::Options::Options() : Backend::Options(XCCL_BACKEND_NAME) {}
 
 static std::atomic<int> process_group_id = 0;
@@ -362,7 +447,8 @@ ProcessGroupXCCL::ProcessGroupXCCL(
   traceBufferSize_ = getCvarInt({"TORCH_FR_BUFFER_SIZE"}, 2000);
   this->setGroupUid(options_->group_name);
-  // In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just do it here.
+  // In PGNCCL, the pg ranks are recorded on comm setup in each op, but we just
+  // do it here.
   const auto XcclVersion = getXcclVersion();
   FlightRecorderXCCL::get()->record_pg_ranks(
       std::make_tuple(pg_uid_, pg_desc_), groupRanks());
@@ -435,7 +521,8 @@ void ProcessGroupXCCL::setCompletedPgStatus(
   pgStatus_->lastCompletedNumelIn = work->numelIn_;
   pgStatus_->lastCompletedNumelOut = work->numelOut_;
   // To avoid complexity, we're not computing duration.
-  FlightRecorderXCCL::get()->retire_id(work->trace_id_, /*compute_duration*/false);
+  FlightRecorderXCCL::get()->retire_id(
+      work->trace_id_, /*compute_duration*/ false);
 }
 
 void ProcessGroupXCCL::setSequenceNumberForGroup() {}
@@ -768,9 +855,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
       c10::ListType::create(c10::TensorType::get()), devices);
   work->future_->markCompleted(at::IValue(*work->outputs_));
   work->future_->addCallback(
-      [this, work](at::ivalue::Future&) {
-        this->setCompletedPgStatus(work);
-      });
+      [this, work](at::ivalue::Future&) { this->setCompletedPgStatus(work); });
 
   work->blockingWait_ = blockingWait_;
   work->numelIn_ = 0;
@@ -881,10 +966,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
   work->future_ = c10::make_intrusive<at::ivalue::Future>(
       c10::ListType::create(c10::TensorType::get()), devices);
   work->future_->markCompleted(at::IValue(*work->outputs_));
-  work->future_->addCallback(
-      [this, work](at::ivalue::Future&) {
-        this->setCompletedPgStatus(work);
-      });
+  work->future_->addCallback([this, work](at::ivalue::Future&) {
+    this->setCompletedPgStatus(work);
+  });
 
   work->numelIn_ = work->numelOut_ = tensor.numel();
   setEnqueuedPgStatus(work);
@@ -1269,7 +1353,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
          at::xpu::XPUStream& stream,
          ccl::stream& xcclStream) {
        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
@@ -1366,7 +1451,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
          at::xpu::XPUStream& stream,
          ccl::stream& xcclStream) {
        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
@@ -1528,7 +1614,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
          ccl::stream& xcclStream) {
        const int root = opts.rootRank + opts.rootTensor;
        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       const auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::reduce(
            input.data_ptr(),
            output.data_ptr(),
@@ -1572,7 +1659,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
          ccl::stream& xcclStream) {
        const int root = opts.rootRank + opts.rootTensor;
        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       const auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::reduce(
            input.data_ptr(),
            output.data_ptr(),
@@ -1832,7 +1920,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
          at::xpu::XPUStream& stream,
          ccl::stream& xcclStream) {
        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::reduce_scatter(
            input.data_ptr(),
            output.data_ptr(),
@@ -1927,7 +2016,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
          at::xpu::XPUStream& stream,
          ccl::stream& xcclStream) {
        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::reduce_scatter(
            input.data_ptr(),
            output.data_ptr(),
@@ -1986,7 +2076,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
          at::xpu::XPUStream& stream,
          ccl::stream& xcclStream) {
        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-       auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+       auto xcclReduceOp =
+           getXcclReduceOp(opts.reduceOp, input, xcclDataType, comm);
        ccl::reduce_scatter(
            input.data_ptr(),
            output.data_ptr(),
diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index 42c199735..06a1d63a7 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -129,10 +129,10 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   };
 
   ProcessGroupXCCL(
-        const c10::intrusive_ptr<Store>& store,
-        int rank,
-        int size,
-        c10::intrusive_ptr<Options> options = Options::create());
+      const c10::intrusive_ptr<Store>& store,
+      int rank,
+      int size,
+      c10::intrusive_ptr<Options> options = Options::create());
 
   C10_DEPRECATED ProcessGroupXCCL(
       const c10::intrusive_ptr<Store>& store,
@@ -415,7 +415,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   const std::vector<uint64_t>& groupRanks() const;
   void setEnqueuedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
-  void setCompletedPgStatus(c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
+  void setCompletedPgStatus(
+      c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work);
   bool dumpDebuggingInfo(bool includeStackTrace = true);
 
  protected:
@@ -493,6 +494,24 @@ TORCH_API std::string dump_xccl_trace(
     bool onlyActive);
 
 TORCH_API std::string getXcclVersion();
+
+struct XCCLPreMulSumSupplement : _SupplementBase {
+  double double_factor{0.0};
+  at::Tensor tensor_factor;
+  XCCLPreMulSumSupplement(double f) : double_factor{f} {}
+  XCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} {
+    TORCH_CHECK_EQ(tensor_factor.numel(), 1);
+  }
+};
+
+template <typename T>
+ReduceOp makeXCCLPreMulSum(const T& factor) {
+  ReduceOp rop;
+  rop.op_ = ReduceOp::PREMUL_SUM;
+  rop.supplement_ = c10::make_intrusive<XCCLPreMulSumSupplement>(factor);
+  return rop;
+}
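+// Hypothetical usage sketch (values are illustrative, not from this patch):
+//   ReduceOp op = makeXCCLPreMulSum(0.5);           // host scalar factor
+//   ReduceOp op = makeXCCLPreMulSum(factor_tensor); // one-element tensor
+// The tests below presumably reach this helper through the
+// c10d._make_xccl_premul_sum Python binding.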
+
 } // namespace c10d
 
 namespace {
diff --git a/test/xpu/distributed/test_c10d_ops_xccl.py b/test/xpu/distributed/test_c10d_ops_xccl.py
index 95c577de1..341f3360b 100644
--- a/test/xpu/distributed/test_c10d_ops_xccl.py
+++ b/test/xpu/distributed/test_c10d_ops_xccl.py
@@ -23,7 +23,7 @@
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from test_c10d_xccl import init_multigpu_helper, requires_xccl
 
-from torch.testing._internal.common_distributed import MultiProcContinousTest
+from torch.testing._internal.common_distributed import MultiProcContinuousTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -42,7 +42,7 @@
 TEST_MULTIGPU = TEST_XPU and torch.xpu.device_count() >= 2
 
 
-class ProcessGroupXCCLOpTest(MultiProcContinousTest):
+class ProcessGroupXCCLOpTest(MultiProcContinuousTest):
     @classmethod
     def backend_str(cls) -> str:
         return "xccl"
@@ -260,6 +260,28 @@ def reduce(xs, rootRank, rootTensor, op=None):
             ):
                 reduce(tensors, self.rank, rt, op)
 
+            for factor in (3.0, torch.tensor([5.0], device=local_device_id)):
+                if isinstance(factor, torch.Tensor):
+                    factor_ref = factor.cpu().item()
+                else:
+                    factor_ref = factor
+                float_tensors = [
+                    torch.tensor(
+                        [self.rank + 1.0], device=f"xpu:{local_device_id}"
+                    )
+                ]
+                float_tensors_ref = [
+                    torch.tensor(
+                        [(self.rank + 1.0) * factor_ref],
+                        device=f"xpu:{local_device_id}",
+                    )
+                ]
+
+                reduce(float_tensors_ref, rt, 0)
+                reduce(float_tensors, rt, 0, c10d._make_xccl_premul_sum(factor))
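+                # e.g. with world_size 2 and factor 3.0, both paths yield
+                # [3.0 * (1.0 + 2.0)] = [9.0] on the root rank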
+                if self.rank == rt:
+                    self.assertEqual(float_tensors_ref[0], float_tensors[0])
+
     @requires_xccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
     def test_allgather_ops(self):
@@ -713,6 +735,21 @@ def perm(n, k):
             expected = torch.tensor(prod_val)
             self.assertEqual(expected, output_tensor)
 
+        for factor in (3.0, torch.tensor([5.0], device=self.rank)):
+            if isinstance(factor, torch.Tensor):
+                factor_ref = factor.cpu().item()
+            else:
+                factor_ref = factor
+            output = [t.float() for t in output]
+            tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
+            output_ref = [t.float() for t in output]
+            tensor_lists_ref = [
+                [t.float() * factor_ref for t in tl] for tl in tensor_lists
+            ]
+            reduce_scatter(output, tensor_lists, c10d._make_xccl_premul_sum(factor))
+            reduce_scatter(output_ref, tensor_lists_ref, c10d.ReduceOp.SUM)
+            self.assertEqual(output_ref, output)
+
     @requires_xccl()
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "XCCL test requires 2+ GPUs")
     def test_reduce_scatter_base_ops(self):