diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 8bd771b8a2e..2f0c95da4df 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -51,6 +51,10 @@ inline namespace MIGRAPHX_INLINE_NS { namespace match { +struct supports_dynamic_shapes +{ +}; + struct matcher_context { matcher_context(module& m) : mod(&m) {} @@ -407,10 +411,28 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES_FOR) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_VALIDATE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TIME_MATCHERS) +MIGRAPHX_PRED_MATCHER(not_dynamic_shape, instruction_ref ins) +{ + return not ins->get_shape().dynamic(); +} + +template +auto get_matcher(const Finder& f) +{ + if constexpr(std::is_base_of{}) + { + return f.matcher(); + } + else + { + return not_dynamic_shape(f.matcher()); + } +} + template auto make_match_runner_with_trace(source_location location, Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); const int trace = value_of(MIGRAPHX_TRACE_MATCHES{}); const bool validate = enabled(MIGRAPHX_VALIDATE_MATCHES{}); const auto trace_filter = string_value_of(MIGRAPHX_TRACE_MATCHES_FOR{}); @@ -485,7 +507,7 @@ auto make_match_runner_with_trace(source_location location, Finder& f) template auto make_match_runner(Finder& f) { - auto m = f.matcher(); + auto m = get_matcher(f); return [=, &f](auto& mod, instruction_ref ins) -> bool { match::matcher_result r = match::match_instruction(get_module(mod), ins, m); if(r.result == get_module(mod).end()) diff --git a/src/simplify_dyn_ops.cpp b/src/simplify_dyn_ops.cpp index fb6b41d3b7c..91eaad6c60f 100644 --- a/src/simplify_dyn_ops.cpp +++ b/src/simplify_dyn_ops.cpp @@ -39,7 +39,7 @@ inline namespace MIGRAPHX_INLINE_NS { * into multibroadcast op with a static output shape attribute. * */ -struct find_broadcast_with_dims_static +struct find_broadcast_with_dims_static : match::supports_dynamic_shapes { auto matcher() const { @@ -80,7 +80,7 @@ struct find_broadcast_with_dims_static * At time of writing, Resize allows either 1 or 2 inputs * but the 1-input case is never created by Onnx parsing. */ -struct find_resize_static +struct find_resize_static : match::supports_dynamic_shapes { auto matcher() const @@ -168,7 +168,7 @@ struct find_resize_static * To: * broadcast_op(argument_with_static_shape); broadcast_op.out_lens = constant_output_dims */ -struct find_static_2in_broadcasts +struct find_static_2in_broadcasts : match::supports_dynamic_shapes { auto matcher() const { @@ -201,7 +201,7 @@ struct find_static_2in_broadcasts * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_2in_slice +struct find_const_2in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -255,7 +255,7 @@ struct find_const_2in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_3in_slice +struct find_const_3in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -266,10 +266,10 @@ struct find_const_3in_slice void apply(module& m, const match::matcher_result& mr) const { - auto ins = mr.result; - auto inputs = ins->inputs(); - auto slice_op = any_cast(ins->get_operator()); - auto set_attrs = slice_op.get_set_attributes(); + auto ins = mr.result; + auto inputs = ins->inputs(); + auto slice_op = any_cast(ins->get_operator()); + auto set_attrs = slice_op.get_set_attributes(); std::vector starts_vec; std::vector ends_vec; std::vector axes_vec; @@ -314,7 +314,7 @@ struct find_const_3in_slice * To: * slice(data); slice.starts, slice.ends. slice.axes set */ -struct find_const_4in_slice +struct find_const_4in_slice : match::supports_dynamic_shapes { auto matcher() const { @@ -351,7 +351,7 @@ struct find_const_4in_slice * Simplify dimensions_of to a literal when the input arugment has a static shape * or the dynamic dimensions from `start` to `end` are fixed. */ -struct find_static_dimensions_of +struct find_static_dimensions_of : match::supports_dynamic_shapes { auto matcher() const { return match::name("dimensions_of")(); } @@ -396,7 +396,7 @@ struct find_static_dimensions_of * To: * reshape(data); reshape.dims = constant_output_dims */ -struct find_const_alloc_reshapes +struct find_const_alloc_reshapes : match::supports_dynamic_shapes { auto matcher() const { @@ -430,7 +430,7 @@ struct find_const_alloc_reshapes * To: * literal */ -struct find_const_alloc_fill +struct find_const_alloc_fill : match::supports_dynamic_shapes { auto matcher() const { @@ -454,7 +454,7 @@ struct find_const_alloc_fill * To: * multibroadcast(static_shape_arg); output_lens = static_broadcast_for_doted_shape */ -struct find_static_broadcast_for_dot +struct find_static_broadcast_for_dot : match::supports_dynamic_shapes { auto matcher() const { @@ -496,7 +496,7 @@ struct find_static_broadcast_for_dot * (on_value - off_value) * mask + off_value when we have `fill` working * on the GPU. */ -struct find_static_onehot +struct find_static_onehot : match::supports_dynamic_shapes { auto matcher() const { @@ -530,7 +530,7 @@ struct find_static_onehot depth_ins->eval().visit([&](auto d) { depth_val = d[0]; }); values_ins = onehot_inputs[2]; } - shape values_shape = values_ins->get_shape(); + shape values_shape = values_ins->get_shape(); std::vector static_output_lens = indices_shape.lens(); auto normalized_axis = (onehot_op.axis < 0) ? onehot_op.axis + indices_shape.ndim() + 1 : onehot_op.axis; @@ -574,7 +574,7 @@ struct find_static_onehot * This version ignores dynamic_dimension opt values. * Intended to be run after the other simplify_dyn_ops passes. */ -struct simplify_select_module_output_shape +struct simplify_select_module_output_shape : match::supports_dynamic_shapes { auto matcher() const { return match::name("select_module"); } diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 113cc3c75b0..6328a575bd3 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -89,6 +89,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_REWRITE_LRN) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_SET_GEMM_PROVIDER) +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FULL_DYNAMIC) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -179,7 +180,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // clang-format off return { - split_single_dyn_dim{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), split_single_dyn_dim{}), dead_code_elimination{}, simplify_dyn_ops{}, dead_code_elimination{}, @@ -204,7 +205,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti insert_pad{{"convolution"}}, dead_code_elimination{}, inline_module{}, - rewrite_pooling{.rewrite_lrn = (not MIGRAPHX_USE_MIOPEN or enabled(MIGRAPHX_REWRITE_LRN{}))}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), rewrite_pooling{.rewrite_lrn = (not MIGRAPHX_USE_MIOPEN or enabled(MIGRAPHX_REWRITE_LRN{}))}), dead_code_elimination{}, rewrite_gelu{options.fast_math}, optimize_module{}, @@ -227,13 +228,13 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}), dead_code_elimination{}, optimize_module{}, - fuse_pointwise_reduce{}, + enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), dead_code_elimination{}, #ifndef _WIN32 enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), #endif dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_mlir{&ctx}), + enable_pass(mlir_enabled() and disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_mlir{&ctx}), dead_code_elimination{}, fuse_concat{}, dead_code_elimination{},