From 13424e926fc825222c70b7c59fcfe55b5dcaa16e Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 25 Oct 2025 17:32:07 -0500 Subject: [PATCH 1/4] Fix issue when broadcasting a reduction with different dimensions as the input --- src/shape_transform_descriptor.cpp | 28 ++++++++++++++++++++++ src/simplify_reshapes.cpp | 2 +- test/simplify_reshapes_test.cpp | 37 ++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 209e13ade3d..0fb13e948d7 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -233,6 +233,34 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector< } // TODO: Only simplify if the subs was changed result.simplify(); + if(broadcast) + { + const bool broadcast_only_dims = std::equal(dimensions.begin(), dimensions.end(), result.dimensions.begin(), result.dimensions.end(), [](const auto& src_dim, const auto& dst_dim) { + if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) + return false; + auto match_sub_dim = [](const dimension::sub& src_sub, const dimension::sub& dst_sub) { + if(src_sub.len == 1) + return true; + return src_sub.len == dst_sub.len; + }; + auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), + src_dim.subdimensions.end(), + dst_dim.subdimensions.begin(), + dst_dim.subdimensions.end(), + match_sub_dim); + if(src_it == src_dim.subdimensions.end()) + return true; + // One mismatch is fine as long as the dimension is still the same size + if(src_dim.len() != dst_dim.len()) + return false; + return std::equal(std::next(src_it), src_dim.subdimensions.end(), + std::next(dst_it), + dst_dim.subdimensions.end(), + match_sub_dim); + }); + if(not broadcast_only_dims) + return {}; + } return result; } diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index b98b51b74ce..4c6041a5706 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -263,7 +263,7 @@ struct find_op_shape_transform_op auto desc = desc1.rebase(x_ins->inputs().front()->get_shape().lens(), true); if(not desc.empty()) return desc; - if(not is_reduce(x_ins)) + if(not is_reduce(x_ins) or any_of(ops, [](const operation& op) { return contains({"broadcast", "multibroadcast"}, op.name()); })) return desc1; // Find a broadcast to append to improve the reduction analysis auto output_path = get_output_path(input_ins); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 7191dee8700..cfee8258091 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2666,6 +2666,43 @@ TEST_CASE(reduce_broadcast_reshape_pointwise2) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reduce_transpose_broadcast_pointwise_diff_size) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {1, 128, 128, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {1, 3, 256, 256}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto reduce_sum = + m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x); + auto transpose = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), reduce_sum); + auto broadcast = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), transpose); + auto add = m1.add_instruction(migraphx::make_op("add"), broadcast, y); + auto relu = m1.add_instruction(migraphx::make_op("relu"), add); + m1.add_return({relu}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto transpose = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), x); + auto reduce_sum = + m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), transpose); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), reduce_sum); + auto add = m2.add_instruction(migraphx::make_op("add"), broadcast, y); + auto relu = m2.add_instruction(migraphx::make_op("relu"), add); + m2.add_return({relu}); + } + EXPECT(m1.sort() == m2.sort()); +} + + TEST_CASE(transpose_contiguous_reshape_binary_packed) { migraphx::module m1; From fdb5604dd42823f8844897abefdc95ab318ddac5 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 25 Oct 2025 17:32:17 -0500 Subject: [PATCH 2/4] Format --- src/shape_transform_descriptor.cpp | 53 +++++++++++++++++------------- src/simplify_reshapes.cpp | 4 ++- test/simplify_reshapes_test.cpp | 9 +++-- 3 files changed, 37 insertions(+), 29 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 0fb13e948d7..aeaa5a9b6f4 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -235,29 +235,36 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector< result.simplify(); if(broadcast) { - const bool broadcast_only_dims = std::equal(dimensions.begin(), dimensions.end(), result.dimensions.begin(), result.dimensions.end(), [](const auto& src_dim, const auto& dst_dim) { - if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) - return false; - auto match_sub_dim = [](const dimension::sub& src_sub, const dimension::sub& dst_sub) { - if(src_sub.len == 1) - return true; - return src_sub.len == dst_sub.len; - }; - auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), - src_dim.subdimensions.end(), - dst_dim.subdimensions.begin(), - dst_dim.subdimensions.end(), - match_sub_dim); - if(src_it == src_dim.subdimensions.end()) - return true; - // One mismatch is fine as long as the dimension is still the same size - if(src_dim.len() != dst_dim.len()) - return false; - return std::equal(std::next(src_it), src_dim.subdimensions.end(), - std::next(dst_it), - dst_dim.subdimensions.end(), - match_sub_dim); - }); + const bool broadcast_only_dims = + std::equal(dimensions.begin(), + dimensions.end(), + result.dimensions.begin(), + result.dimensions.end(), + [](const auto& src_dim, const auto& dst_dim) { + if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) + return false; + auto match_sub_dim = [](const dimension::sub& src_sub, + const dimension::sub& dst_sub) { + if(src_sub.len == 1) + return true; + return src_sub.len == dst_sub.len; + }; + auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), + src_dim.subdimensions.end(), + dst_dim.subdimensions.begin(), + dst_dim.subdimensions.end(), + match_sub_dim); + if(src_it == src_dim.subdimensions.end()) + return true; + // One mismatch is fine as long as the dimension is still the same size + if(src_dim.len() != dst_dim.len()) + return false; + return std::equal(std::next(src_it), + src_dim.subdimensions.end(), + std::next(dst_it), + dst_dim.subdimensions.end(), + match_sub_dim); + }); if(not broadcast_only_dims) return {}; } diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 4c6041a5706..8d5df817d5b 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -263,7 +263,9 @@ struct find_op_shape_transform_op auto desc = desc1.rebase(x_ins->inputs().front()->get_shape().lens(), true); if(not desc.empty()) return desc; - if(not is_reduce(x_ins) or any_of(ops, [](const operation& op) { return contains({"broadcast", "multibroadcast"}, op.name()); })) + if(not is_reduce(x_ins) or any_of(ops, [](const operation& op) { + return contains({"broadcast", "multibroadcast"}, op.name()); + })) return desc1; // Find a broadcast to append to improve the reduction analysis auto output_path = get_output_path(input_ins); diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index cfee8258091..62606e3af25 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -2676,8 +2676,8 @@ TEST_CASE(reduce_transpose_broadcast_pointwise_diff_size) auto y = m1.add_parameter("y", s2); auto reduce_sum = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), x); - auto transpose = - m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), reduce_sum); + auto transpose = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), reduce_sum); auto broadcast = m1.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), transpose); auto add = m1.add_instruction(migraphx::make_op("add"), broadcast, y); @@ -2695,14 +2695,13 @@ TEST_CASE(reduce_transpose_broadcast_pointwise_diff_size) m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), transpose); auto broadcast = m2.add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), reduce_sum); - auto add = m2.add_instruction(migraphx::make_op("add"), broadcast, y); - auto relu = m2.add_instruction(migraphx::make_op("relu"), add); + auto add = m2.add_instruction(migraphx::make_op("add"), broadcast, y); + auto relu = m2.add_instruction(migraphx::make_op("relu"), add); m2.add_return({relu}); } EXPECT(m1.sort() == m2.sort()); } - TEST_CASE(transpose_contiguous_reshape_binary_packed) { migraphx::module m1; From 33cdfe68bd0776d785186cfcc4600d859885d316 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 28 Oct 2025 15:32:42 -0500 Subject: [PATCH 3/4] Refactor function --- src/shape_transform_descriptor.cpp | 70 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index aeaa5a9b6f4..c3d236fb683 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -185,6 +185,39 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector< return result; } +static bool is_broadcast_only(const std::vector& src_dims, const std::vector& dst_dims) +{ + return std::equal(src_dims.begin(), + src_dims.end(), + dst_dims.begin(), + dst_dims.end(), + [](const auto& src_dim, const auto& dst_dim) { + if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) + return false; + auto match_sub_dim = [](const dimension::sub& src_sub, + const dimension::sub& dst_sub) { + if(src_sub.len == 1) + return true; + return src_sub.len == dst_sub.len; + }; + auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), + src_dim.subdimensions.end(), + dst_dim.subdimensions.begin(), + dst_dim.subdimensions.end(), + match_sub_dim); + if(src_it == src_dim.subdimensions.end()) + return true; + // One mismatch is fine as long as the dimension is still the same size + if(src_dim.len() != dst_dim.len()) + return false; + return std::equal(std::next(src_it), + src_dim.subdimensions.end(), + std::next(dst_it), + dst_dim.subdimensions.end(), + match_sub_dim); + }); +} + shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims, bool broadcast) const { @@ -233,41 +266,8 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector< } // TODO: Only simplify if the subs was changed result.simplify(); - if(broadcast) - { - const bool broadcast_only_dims = - std::equal(dimensions.begin(), - dimensions.end(), - result.dimensions.begin(), - result.dimensions.end(), - [](const auto& src_dim, const auto& dst_dim) { - if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) - return false; - auto match_sub_dim = [](const dimension::sub& src_sub, - const dimension::sub& dst_sub) { - if(src_sub.len == 1) - return true; - return src_sub.len == dst_sub.len; - }; - auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), - src_dim.subdimensions.end(), - dst_dim.subdimensions.begin(), - dst_dim.subdimensions.end(), - match_sub_dim); - if(src_it == src_dim.subdimensions.end()) - return true; - // One mismatch is fine as long as the dimension is still the same size - if(src_dim.len() != dst_dim.len()) - return false; - return std::equal(std::next(src_it), - src_dim.subdimensions.end(), - std::next(dst_it), - dst_dim.subdimensions.end(), - match_sub_dim); - }); - if(not broadcast_only_dims) - return {}; - } + if(broadcast and not is_broadcast_only(dimensions, result.dimensions)) + return {}; return result; } From 3d11fdffae20944aab9e06c20ec311fa9b269fc8 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 28 Oct 2025 15:32:45 -0500 Subject: [PATCH 4/4] Format --- src/shape_transform_descriptor.cpp | 59 +++++++++++++++--------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index c3d236fb683..804b9bc9e46 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -185,37 +185,38 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector< return result; } -static bool is_broadcast_only(const std::vector& src_dims, const std::vector& dst_dims) +static bool is_broadcast_only(const std::vector& src_dims, + const std::vector& dst_dims) { return std::equal(src_dims.begin(), - src_dims.end(), - dst_dims.begin(), - dst_dims.end(), - [](const auto& src_dim, const auto& dst_dim) { - if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) - return false; - auto match_sub_dim = [](const dimension::sub& src_sub, - const dimension::sub& dst_sub) { - if(src_sub.len == 1) - return true; - return src_sub.len == dst_sub.len; - }; - auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), - src_dim.subdimensions.end(), - dst_dim.subdimensions.begin(), - dst_dim.subdimensions.end(), - match_sub_dim); - if(src_it == src_dim.subdimensions.end()) - return true; - // One mismatch is fine as long as the dimension is still the same size - if(src_dim.len() != dst_dim.len()) - return false; - return std::equal(std::next(src_it), - src_dim.subdimensions.end(), - std::next(dst_it), - dst_dim.subdimensions.end(), - match_sub_dim); - }); + src_dims.end(), + dst_dims.begin(), + dst_dims.end(), + [](const auto& src_dim, const auto& dst_dim) { + if(src_dim.subdimensions.size() != dst_dim.subdimensions.size()) + return false; + auto match_sub_dim = [](const dimension::sub& src_sub, + const dimension::sub& dst_sub) { + if(src_sub.len == 1) + return true; + return src_sub.len == dst_sub.len; + }; + auto [src_it, dst_it] = std::mismatch(src_dim.subdimensions.begin(), + src_dim.subdimensions.end(), + dst_dim.subdimensions.begin(), + dst_dim.subdimensions.end(), + match_sub_dim); + if(src_it == src_dim.subdimensions.end()) + return true; + // One mismatch is fine as long as the dimension is still the same size + if(src_dim.len() != dst_dim.len()) + return false; + return std::equal(std::next(src_it), + src_dim.subdimensions.end(), + std::next(dst_it), + dst_dim.subdimensions.end(), + match_sub_dim); + }); } shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector& dims,