Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/shape_transform_descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,40 @@ shape_transform_descriptor shape_transform_descriptor::create(const std::vector<
return result;
}

// Returns true when src_dims and dst_dims have the same number of dimensions
// and every positional pair of dimensions passes the per-dimension check.
//
// Per dimension: both sides must hold the same number of subdimensions, and a
// source subdimension is accepted when its length is 1 or equals the length of
// its destination counterpart. The first subdimension that fails this test is
// tolerated only if the overall len() of the two dimensions is equal and every
// subdimension after it passes the test again.
static bool is_broadcast_only(const std::vector<dimension>& src_dims,
                              const std::vector<dimension>& dst_dims)
{
    // Source length 1 is accepted against any destination length.
    const auto sub_compatible = [](const dimension::sub& from, const dimension::sub& to) {
        return from.len == 1 or from.len == to.len;
    };

    const auto dim_compatible = [&](const auto& from_dim, const auto& to_dim) {
        const auto& from_subs = from_dim.subdimensions;
        const auto& to_subs   = to_dim.subdimensions;
        if(from_subs.size() != to_subs.size())
            return false;

        // Locate the first pair of subdimensions that is not compatible.
        const auto diverge = std::mismatch(
            from_subs.begin(), from_subs.end(), to_subs.begin(), to_subs.end(), sub_compatible);
        if(diverge.first == from_subs.end())
            return true;

        // One mismatch is fine as long as the dimension is still the same size
        // and everything after the mismatch is compatible. The size check above
        // guarantees both iterators are dereferenceable here, so std::next is
        // valid on each of them.
        return from_dim.len() == to_dim.len() and std::equal(std::next(diverge.first),
                                                             from_subs.end(),
                                                             std::next(diverge.second),
                                                             to_subs.end(),
                                                             sub_compatible);
    };

    return std::equal(
        src_dims.begin(), src_dims.end(), dst_dims.begin(), dst_dims.end(), dim_compatible);
}

shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector<std::size_t>& dims,
bool broadcast) const
{
Expand Down Expand Up @@ -233,6 +267,8 @@ shape_transform_descriptor shape_transform_descriptor::rebase(const std::vector<
}
// TODO: Only simplify if the subs was changed
result.simplify();
if(broadcast and not is_broadcast_only(dimensions, result.dimensions))
return {};

return result;
}
Expand Down
4 changes: 3 additions & 1 deletion src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ struct find_op_shape_transform_op
auto desc = desc1.rebase(x_ins->inputs().front()->get_shape().lens(), true);
if(not desc.empty())
return desc;
if(not is_reduce(x_ins))
if(not is_reduce(x_ins) or any_of(ops, [](const operation& op) {
return contains({"broadcast", "multibroadcast"}, op.name());
}))
return desc1;
// Find a broadcast to append to improve the reduction analysis
auto output_path = get_output_path(input_ins);
Expand Down
36 changes: 36 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,42 @@ TEST_CASE(reduce_broadcast_reshape_pointwise2)
EXPECT(m1.sort() == m2.sort());
}

// A reduce_sum -> transpose -> multibroadcast chain feeding add/relu, where the
// broadcast output (256x256) has a different spatial size than the reduced
// input (128x128). After the pass runs, the module is expected to apply the
// transpose first and reduce over axes {2, 3} instead of {1, 2}, with the
// multibroadcast, add and relu left in place.
TEST_CASE(reduce_transpose_broadcast_pointwise_diff_size)
{
    const auto in_shape  = migraphx::shape{migraphx::shape::float_type, {1, 128, 128, 3}};
    const auto out_shape = migraphx::shape{migraphx::shape::float_type, {1, 3, 256, 256}};

    // Operators that appear unchanged in both the input and the expected module.
    const auto permute_op = migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}});
    const auto bcast_op = migraphx::make_op("multibroadcast", {{"out_lens", out_shape.lens()}});

    // Module handed to the pass: the reduction happens before the transpose.
    migraphx::module m1;
    {
        auto px      = m1.add_parameter("x", in_shape);
        auto py      = m1.add_parameter("y", out_shape);
        auto reduced = m1.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1, 2}}}), px);
        auto permuted = m1.add_instruction(permute_op, reduced);
        auto expanded = m1.add_instruction(bcast_op, permuted);
        auto sum      = m1.add_instruction(migraphx::make_op("add"), expanded, py);
        auto out      = m1.add_instruction(migraphx::make_op("relu"), sum);
        m1.add_return({out});
    }
    run_pass(m1);

    // Expected result: the transpose is applied first, so the reduction axes
    // move from {1, 2} to {2, 3}.
    migraphx::module m2;
    {
        auto px       = m2.add_parameter("x", in_shape);
        auto py       = m2.add_parameter("y", out_shape);
        auto permuted = m2.add_instruction(permute_op, px);
        auto reduced =
            m2.add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 3}}}), permuted);
        auto expanded = m2.add_instruction(bcast_op, reduced);
        auto sum      = m2.add_instruction(migraphx::make_op("add"), expanded, py);
        auto out      = m2.add_instruction(migraphx::make_op("relu"), sum);
        m2.add_return({out});
    }
    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(transpose_contiguous_reshape_binary_packed)
{
migraphx::module m1;
Expand Down
Loading