Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/include/migraphx/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ inline namespace MIGRAPHX_INLINE_NS {

namespace match {

struct supports_dynamic_shapes
{
};

struct matcher_context
{
matcher_context(module& m) : mod(&m) {}
Expand Down Expand Up @@ -407,10 +411,28 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS)

MIGRAPHX_PRED_MATCHER(not_dynamic_shape, instruction_ref ins)
{
return not ins->get_shape().dynamic();
}

template <class Finder>
auto get_matcher(const Finder& f)
{
if constexpr(std::is_base_of<supports_dynamic_shapes, Finder>{})
{
return f.matcher();
}
else
{
return not_dynamic_shape(f.matcher());
}
}

template <class Finder>
auto make_match_runner_with_trace(source_location location, Finder& f)
{
auto m = f.matcher();
auto m = get_matcher(f);
const int trace = value_of(MIGRAPHX_TRACE_MATCHES{});
const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{});
const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{});
Expand Down Expand Up @@ -485,7 +507,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f)
template <class Finder>
auto make_match_runner(Finder& f)
{
auto m = f.matcher();
auto m = get_matcher(f);
return [=, &f](auto& mod, instruction_ref ins) -> bool {
match::matcher_result r = match::match_instruction(get_module(mod), ins, m);
if(r.result == get_module(mod).end())
Expand Down
34 changes: 17 additions & 17 deletions src/simplify_dyn_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ inline namespace MIGRAPHX_INLINE_NS {
* into multibroadcast op with a static output shape attribute.
*
*/
struct find_broadcast_with_dims_static
struct find_broadcast_with_dims_static : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -80,7 +80,7 @@ struct find_broadcast_with_dims_static
* At time of writing, Resize allows either 1 or 2 inputs
* but the 1-input case is never created by Onnx parsing.
*/
struct find_resize_static
struct find_resize_static : match::supports_dynamic_shapes
{

auto matcher() const
Expand Down Expand Up @@ -168,7 +168,7 @@ struct find_resize_static
* To:
* broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims
*/
struct find_static_2in_broadcasts
struct find_static_2in_broadcasts : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -201,7 +201,7 @@ struct find_static_2in_broadcasts
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_2in_slice
struct find_const_2in_slice : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -255,7 +255,7 @@ struct find_const_2in_slice
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_3in_slice
struct find_const_3in_slice : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand All @@ -266,10 +266,10 @@ struct find_const_3in_slice

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
auto ins = mr.result;
auto inputs = ins->inputs();
auto slice_op = any_cast<op::slice>(ins->get_operator());
auto set_attrs = slice_op.get_set_attributes();
std::vector<int64_t> starts_vec;
std::vector<int64_t> ends_vec;
std::vector<int64_t> axes_vec;
Expand Down Expand Up @@ -314,7 +314,7 @@ struct find_const_3in_slice
* To:
* slice(data); slice.starts, slice.ends. slice.axes set
*/
struct find_const_4in_slice
struct find_const_4in_slice : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -351,7 +351,7 @@ struct find_const_4in_slice
* Simplify dimensions_of to a literal when the input arugment has a static shape
* or the dynamic dimensions from `start` to `end` are fixed.
*/
struct find_static_dimensions_of
struct find_static_dimensions_of : match::supports_dynamic_shapes
{
auto matcher() const { return match::name("dimensions_of")(); }

Expand Down Expand Up @@ -396,7 +396,7 @@ struct find_static_dimensions_of
* To:
* reshape(data); reshape.dims = constant_output_dims
*/
struct find_const_alloc_reshapes
struct find_const_alloc_reshapes : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -430,7 +430,7 @@ struct find_const_alloc_reshapes
* To:
* literal
*/
struct find_const_alloc_fill
struct find_const_alloc_fill : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand All @@ -454,7 +454,7 @@ struct find_const_alloc_fill
* To:
* multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape
*/
struct find_static_broadcast_for_dot
struct find_static_broadcast_for_dot : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -496,7 +496,7 @@ struct find_static_broadcast_for_dot
* (on_value - off_value) * mask + off_value when we have `fill` working
* on the GPU.
*/
struct find_static_onehot
struct find_static_onehot : match::supports_dynamic_shapes
{
auto matcher() const
{
Expand Down Expand Up @@ -530,7 +530,7 @@ struct find_static_onehot
depth_ins->eval().visit([&](auto d) { depth_val = d[0]; });
values_ins = onehot_inputs[2];
}
shape values_shape = values_ins->get_shape();
shape values_shape = values_ins->get_shape();
std::vector<std::size_t> static_output_lens = indices_shape.lens();
auto normalized_axis =
(onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis;
Expand Down Expand Up @@ -574,7 +574,7 @@ struct find_static_onehot
* This version ignores dynamic_dimension opt values.
* Intended to be run after the other simplify_dyn_ops passes.
*/
struct simplify_select_module_output_shape
struct simplify_select_module_output_shape : match::supports_dynamic_shapes
{
auto matcher() const { return match::name("select_module"); }

Expand Down
9 changes: 5 additions & 4 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK)
#endif
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC)

std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
{
Expand Down Expand Up @@ -179,7 +180,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
// clang-format off
return
{
split_single_dyn_dim{},
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), split_single_dyn_dim{}),
dead_code_elimination{},
simplify_dyn_ops{},
dead_code_elimination{},
Expand All @@ -204,7 +205,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
insert_pad{{"convolution"}},
dead_code_elimination{},
inline_module{},
rewrite_pooling{.rewrite_lrn = (not MIGRAPHX_USE_MIOPEN or enabled(MIGRAPHX_REWRITE_LRN{}))},
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), rewrite_pooling{.rewrite_lrn = (not MIGRAPHX_USE_MIOPEN or enabled(MIGRAPHX_REWRITE_LRN{}))}),
dead_code_elimination{},
rewrite_gelu{options.fast_math},
optimize_module{},
Expand All @@ -227,13 +228,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}),
dead_code_elimination{},
optimize_module{},
fuse_pointwise_reduce{},
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}),
dead_code_elimination{},
#ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif
dead_code_elimination{},
enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
enable_pass(mlir_enabled() and disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_mlir{&ctx}),
dead_code_elimination{},
fuse_concat{},
dead_code_elimination{},
Expand Down
Loading