diff --git a/cinn/common/target.cc b/cinn/common/target.cc index c2b26601b8..8646bd4d7e 100644 --- a/cinn/common/target.cc +++ b/cinn/common/target.cc @@ -11,13 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#include "cinn/common/target.h" +#ifdef CINN_WITH_CUDA +#include +#include +#endif #include #include +#include "cinn/common/target.h" #include "cinn/runtime/cinn_runtime.h" namespace cinn { @@ -49,6 +52,24 @@ int Target::max_num_threads() const { return 1024; } +int Target::get_multi_processor_count() const { + CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get multi processor count"; + int num_sm = 0; +#ifdef CINN_WITH_CUDA + cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); +#endif + return num_sm; +} + +int Target::get_max_threads_per_sm() const { + CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get max threads per stream processor"; + int max_thread = 0; +#ifdef CINN_WITH_CUDA + cudaDeviceGetAttribute(&max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); +#endif + return max_thread; +} + std::vector Target::get_target_libs() const { return libs; } int Target::get_target_bits() const { diff --git a/cinn/common/target.h b/cinn/common/target.h index 33dacefe29..f9fe56efa7 100755 --- a/cinn/common/target.h +++ b/cinn/common/target.h @@ -80,6 +80,10 @@ struct Target { int max_num_threads() const; + int get_multi_processor_count() const; + + int get_max_threads_per_sm() const; + int get_target_bits() const; std::vector get_target_libs() const; diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 8b62d25a5a..c142282a13 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -117,7 +117,14 @@ Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const return Reshape(x, new_shape); } } - return CustomInstr(op_type, {x}, {{"dim", dim}, {"keep_dim", keep_dim}}).front(); + // Convert the negative dim to a positive number + std::vector reduce_dim(dim.begin(), dim.end()); + for (int i = 0; i < dim.size(); i++) { + if (reduce_dim[i] < 0) { + reduce_dim[i] = x->shape.size() + reduce_dim[i]; + } + } + return CustomInstr(op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}).front(); } #define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \ diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 53a555d8a9..8c1b5452ff 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -1226,8 +1226,9 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } } + auto masters = GetMasters(node, nodes_inline, nodes_set); // node can be inline. - if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) { + if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); ir::ComputeInlineChecker checker(ir_sch, block); if (!checker.Check()) { @@ -1327,7 +1328,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map); + SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group); VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); } diff --git a/cinn/hlir/framework/op_lowering_test.cc b/cinn/hlir/framework/op_lowering_test.cc index 3b3601055a..336d46ee9b 100644 --- a/cinn/hlir/framework/op_lowering_test.cc +++ b/cinn/hlir/framework/op_lowering_test.cc @@ -1171,6 +1171,75 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) { Compile(net_builder); } +TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold - 10; + int w = 256; + NetBuilder net_builder("Block_Reduce_Fuse_Broadcast"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1}, true); + auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold - 10; + int w = 256; + NetBuilder net_builder("Block_Reduce_Fuse_Elementwise"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + } + + Compile(net_builder); +} +TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold + 10; + int w = 256; + NetBuilder net_builder("Warp_Reduce_Fuse_Broadcast"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1}, true); + auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Warp_Reduce_Fuse_Elementwise) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold + 10; + int w = 256; + NetBuilder net_builder("Warp_Reduce_Fuse_Elementwise"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + } + + Compile(net_builder); +} + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index dd0a30e183..9f4c3b9fd4 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -168,11 +168,10 @@ bool IsConstOp(const framework::Node* node) { } std::vector GetInputShape(const Node* node, const absl::flat_hash_map& shape_dict) { - auto producers = GetProducers(node); - CHECK(producers.size()); + auto input_data = GetInputNodeData(node); + CHECK(input_data.size()); - auto producer_data = GetNodeData(producers.front()); - return shape_dict.at(producer_data->id()); + return shape_dict.at(input_data.front()->id()); } std::vector GetOutputShape(const Node* node, const absl::flat_hash_map& shape_dict) { @@ -577,10 +576,25 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, const std::vector& inshape, const std::vector& axes, const common::Target& target) { + // If the number of current device SM is smaller than the number of SM + // required by Warp Reduce, the performance of Warp Reduce is better. + // Otherwise, use Block Reduce. + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int need_reduce_last_count = 1; + for (int i = 0; i < inshape.size(); i++) { + if (find(axes.begin(), axes.end(), i) == axes.end()) { + need_reduce_last_count *= inshape[i]; + } + } + int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / float(target.get_max_threads_per_sm())); + // Set Num_max_threads to 32 is Warp Reduce + if (target.get_multi_processor_count() < warp_reduce_need_sm_count) { + max_num_threads = 32; + } // find first reduce and second reduce axis. - int lane = 1; - int index = static_cast(axes.size()) - 1; - auto max_num_threads = target.max_num_threads(); + int lane = 1; + int index = static_cast(axes.size()) - 1; + for (; index >= 0; --index) { if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { break; @@ -639,7 +653,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, bool CanbeInline(Node* node, const std::vector consumers, const Node* reducer, - const Node* laster, + const std::unordered_set masters, const GroupPtr& group, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict) { @@ -681,10 +695,14 @@ bool CanbeInline(Node* node, return false; } else { auto node_shape = GetOutputShape(node, shape_dict); - auto last_shape = GetOutputShape(laster, shape_dict); - if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != - std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { - return true; + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + + for (auto master : masters) { + auto master_shape = GetOutputShape(master, shape_dict); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + if (node_size != master_size) { + return true; + } } return false; @@ -1316,7 +1334,7 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); if (!group->output_nodes.count(node)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); - ir_sch.SetBuffer(block, "local", true); + ir_sch.SetBuffer(block, "local"); } if (op_pattern_dict[node->op()] == framework::kReduction) { @@ -1373,11 +1391,14 @@ std::unordered_map GetNodeDataSet(const std::unordered_s return node_data_set; } -Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set) { +std::unordered_set GetMasters(Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set) { // find consumer std::unordered_set visited; std::queue candidates; candidates.push(node); + std::unordered_set masters; while (!candidates.empty()) { auto candidate = candidates.front(); @@ -1392,19 +1413,20 @@ Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const candidates.push(consumer); visited.insert(consumer); } else { - return consumer; + masters.insert(consumer); } } } - return nullptr; + return masters; } void SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { + const std::unordered_map& tensor_map, + const GroupPtr& group) { auto exprs_inorder = ir_sch.GetAllBlocks(); auto node_data_set = GetNodeDataSet(nodes_set); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); @@ -1441,34 +1463,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, auto node = node_data->source_node.get(); auto node_shape = shape_dict.at(node_data->id()); - auto master = GetMaster(node, nodes_inline, nodes_set); - if (!master) { + auto masters = GetMasters(node, nodes_inline, nodes_set); + if (masters.empty()) { continue; } - auto master_data = GetNodeData(master); - auto master_shape = shape_dict.at(master_data->id()); - if (op_pattern_dict[master->op()] == framework::kReduction) { - master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); - } + bool do_set_buffer_to_shared = false; + for (auto master : masters) { + auto master_data = GetNodeData(master); + auto master_shape = shape_dict.at(master_data->id()); + if (op_pattern_dict[master->op()] == framework::kReduction) { + master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + } - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); - if (node_size == master_size) { - continue; + if (node_size != master_size) { + if (check_sync_mark(idx, master_data->id())) { + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SyncThreads(loops.back(), false); + sync_mark.insert(master_data->id()); + } + do_set_buffer_to_shared = true; + } } - - { + if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) { auto block = ir_sch.GetBlock(node_data->id()); ir_sch.SetBuffer(block, "shared", true); } - - if (check_sync_mark(idx, master_data->id())) { - auto loops = ir_sch.GetLoops(master_data->id()); - ir_sch.SyncThreads(loops.back(), false); - sync_mark.insert(master_data->id()); - } } } diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index 01a33ae876..db92b74c68 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -60,7 +60,7 @@ Node* FindNearestReducer(const Node* node, const std::unordered_set& node bool CanbeInline(Node* node, const std::vector consumers, const Node* reducer, - const Node* laster, + const std::unordered_set masters, const GroupPtr& group, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict); @@ -72,6 +72,10 @@ Node* GetMasterToComputeAt(Node* node, const std::unordered_map& virtual_consumers, const absl::flat_hash_map& shape_dict); +std::unordered_set GetMasters(Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set); + void LoopAssignReduce(ir::IRSchedule& ir_sch, const Node* node, const Node* reducer, @@ -90,7 +94,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map); + const std::unordered_map& tensor_map, + const GroupPtr& group); } // namespace framework } // namespace hlir diff --git a/cinn/hlir/op/reduction_test.cc b/cinn/hlir/op/reduction_test.cc index b1986be20f..870dda7d5d 100644 --- a/cinn/hlir/op/reduction_test.cc +++ b/cinn/hlir/op/reduction_test.cc @@ -465,6 +465,53 @@ TEST(Operator, Operator_Reduction_Case_11) { GenReduceCode(shape, dim, "Operator_Reduction_Case_11"); } +TEST(Operator, Operator_Reduction_Case_Warp_Reduce) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {warp_reduce_threshold + 10, 256}; + std::vector dim = {1}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce"); + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Block_Reduce) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {warp_reduce_threshold - 10, 33}; + std::vector dim = {1}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce"); + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {(warp_reduce_threshold + 32) / 2, 2, 10, 256}; + std::vector dim = {2, 3}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {(warp_reduce_threshold - 32) / 2, 2, 10, 33}; + std::vector dim = {2, 3}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); +} } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pass/fusion_helper_base.h b/cinn/hlir/pass/fusion_helper_base.h index 94ef3460aa..7658cc0792 100644 --- a/cinn/hlir/pass/fusion_helper_base.h +++ b/cinn/hlir/pass/fusion_helper_base.h @@ -112,6 +112,17 @@ class FusionHelperBase { return producer_node; } + std::vector GetConsumerNode(const Node* node) const { + std::vector consumer_nodes; + auto node_data = GetNodeData(node); + for (auto& link : node_data->outlinks()) { + auto consumer = link->sink()->safe_as(); + CHECK(consumer); + consumer_nodes.push_back(consumer); + } + return consumer_nodes; + } + bool WithoutLastDimInReduce(const std::vector& inshape, const std::vector& axes) const { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 88f54dc566..0121f8f056 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -617,10 +617,6 @@ class FusionMergePassHelper : public FusionHelperBase { void RecomputeWithCostModel(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind == framework::kReduction) { - CHECK_EQ(fusionable_consumers.size(), 1) << "Find more than one consumer can fuse to " << producer->group_id; - } - // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; @@ -818,14 +814,23 @@ class FusionMergePassHelper : public FusionHelperBase { auto& consumers = input_consumers.second; std::unordered_set updated_consumers; for (auto& consumer : consumers) { - // if group is sub group - if (consumer->belong_groups.size()) { - // inset belong group to consumers. - for (auto& belong_group : consumer->belong_groups) { - updated_consumers.insert(belong_group); + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(belong_group); + } else { + fused_groups.push(belong_group); + } + } } - } else { - updated_consumers.insert(consumer); } } consumers = updated_consumers; @@ -976,7 +981,7 @@ class FusionMergePassHelper : public FusionHelperBase { relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. {OpPatternKind::kElementWise, reduce_fuse_elementwise}, // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, // reduce and injective op must be horizontal relation. {OpPatternKind::kInjective, horizontal_with_injective}, // reduce and reduce must be horizontal relation. diff --git a/cinn/hlir/pass/fusion_merge_pass_test.cc b/cinn/hlir/pass/fusion_merge_pass_test.cc index e834da510c..544f86019c 100755 --- a/cinn/hlir/pass/fusion_merge_pass_test.cc +++ b/cinn/hlir/pass/fusion_merge_pass_test.cc @@ -401,7 +401,7 @@ TEST(FusionMergePass, Reduce_Test_2) { auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 4); + CHECK_EQ(graph->fusion_groups.size(), 3); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); CHECK_EQ(graph->fusion_groups.size(), 2); } diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index 696a55f1cf..82bbabd20f 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -285,11 +285,109 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { return elementwise_fuse_reduce(helper, first, second); } -CONDITION_FUNC(reduce_fuse_reduce) { - // check reduce horizontal with reduce. - if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) { - return false; +CONDITION_FUNC(reduce_fuse_broadcast) { + // if same shape with horizontal relation + if (is_same_size(helper, first, second)) { + return true; } + + // Traversing all reducers in all producers requires two types of conditions to be met. + // The first type is the condition that the reducer itself needs to meet, + // and the second type is the condition that the relationship between each reducer and its consumers with type of + // Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after broadcast as + // before reduce. + for (auto& node_in_master : first->master_nodes) { + if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) { + continue; + } + Node* reducer = node_in_master; + // First type conditions + // Get some reduce infomation + auto reducer_input_shape = helper->GetNodeInputShape(reducer); + auto reducer_output_shape = helper->GetNodeDataShape(reducer); + auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(reducer->attrs.attr_store.at("keep_dim")); + for (auto& axis : reduce_axes) { + if (axis == -1) { + axis = reducer_input_shape.size() - 1; + } + } + // Check if the reduce axes are continuous + int reduce_size = reducer_input_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= reducer_input_shape[idx - 1]; + } + // Check if the reduce size exceeds the hardware limit + if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) { + return false; + } + + // Second type conditions + // Find directly or indirectly consumers with type of Broadcast in the second group + auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set { + std::queue candidates; + std::unordered_set visited_set; + std::unordered_set broadcasters; + candidates.push(producer); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : helper->GetConsumerNode(candidate)) { + if (helper->GetOpKind(consumer) == OpPatternKind::kBroadcast && + second->NodeSet().find(consumer) != second->NodeSet().end()) { + broadcasters.insert(consumer); + } else if (!visited_set.count(consumer)) { + visited_set.insert(consumer); + candidates.push(consumer); + } + } + } + + return broadcasters; + }; + + // Check if each broadcast node meets the conditions + std::unordered_set broadcasters_in_consumers = find_broadcasters_in_descendants(reducer); + for (auto broadcaster : broadcasters_in_consumers) { + auto broadcaster_output_shape = absl::get>(broadcaster->attrs.attr_store.at("out_shape")); + auto broadcast_axes = absl::get>(broadcaster->attrs.attr_store.at("broadcast_axes")); + for (auto& axis : broadcast_axes) { + if (axis == -1) { + axis = broadcaster_output_shape.size() - 1; + } + } + + if (reducer_input_shape != broadcaster_output_shape) { + return false; + } + + if (keep_dim) { + continue; + } else { + // if reducer_output_shape = [1] + if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) { + continue; + } + // check union [reduce_axes, broadcast_axes] = reducer_input_shape + for (int idx = 0; idx < reducer_input_shape.size(); ++idx) { + if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + return false; + } + } + } + } + } + + return true; +} + +CONDITION_FUNC(reduce_fuse_reduce) { if (!limit_args(helper, first, second)) { return false; } diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 026f2c6195..021e66e9d3 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase { // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce. - {framework::kElementWise, without_last_dimension_in_reduce}, + {framework::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape and without last dimension in reduce. {framework::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr. diff --git a/cinn/hlir/pe/reduction.cc b/cinn/hlir/pe/reduction.cc index f51121a850..b3a4f82241 100644 --- a/cinn/hlir/pe/reduction.cc +++ b/cinn/hlir/pe/reduction.cc @@ -686,10 +686,25 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, BlockReduceFunc block_reduce_func, ir::Expr initial) { CHECK(!WithoutLastDimInReduce(A->shape, axes)) << "Can't find last axis in reduce!"; + // If the number of current device SM is smaller than the number of SM + // required by Warp Reduce, the performance of Warp Reduce is better. + // Otherwise, use Block Reduce. + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int need_reduce_last_count = 1; + for (int i = 0; i < A->shape.size(); i++) { + if (find(axes.begin(), axes.end(), i) == axes.end()) { + need_reduce_last_count *= A->shape[i].as_int32(); + } + } + int warp_reduce_need_sm_count = + ceil((need_reduce_last_count * 32) / float(common::DefaultNVGPUTarget().get_max_threads_per_sm())); + // Set Num_max_threads to 32 is Warp Reduce + if (common::DefaultNVGPUTarget().get_multi_processor_count() < warp_reduce_need_sm_count) { + max_num_threads = 32; + } - int lane = A->shape[axes.back()].as_int32(); - int index = static_cast(axes.size()) - 2; - auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int lane = A->shape[axes.back()].as_int32(); + int index = static_cast(axes.size()) - 2; for (; index >= 0; --index) { if (lane >= max_num_threads / 2) { break;