35 changes: 33 additions & 2 deletions src/onnx/parse_multi_head_attention.cpp
@@ -466,7 +466,7 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
const multi_head_attention_parameters& attention) const
{
auto batch_size = attention.batch_size;
auto total_seq_len = attention.q_sequence_length;
auto total_seq_len = attention.kv_sequence_length;
auto num_heads = attention.num_heads;

// Other two cases require us to generate masks from sequence or total sequence length pads.
@@ -509,6 +509,32 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
return info.add_instruction(make_op("where"), in_bool, bc_mask, bc_pass);
}

// Convert per-batch right padding values into a raw mask.
// Used so we can leverage the existing raw mask path (generate_raw_mask_per_batch).
instruction_ref
get_raw_mask_from_right_padding(const onnx_parser::node_info& info,
const instruction_ref right_mask,
const multi_head_attention_parameters& attention) const
{
auto batch_size = attention.batch_size;
auto kv_seq_length = attention.kv_sequence_length;

// Generate a list of indices to compare against the exclusive start of the right padding
std::vector<size_t> indices_vec(kv_seq_length, 0);
std::iota(indices_vec.begin(), indices_vec.end(), 0);
auto indices = info.add_literal(migraphx::literal{
migraphx::shape{migraphx::shape::int32_type, {static_cast<size_t>(kv_seq_length)}, {1}},
indices_vec});
auto indices_bc = info.add_instruction(
make_op("multibroadcast", {{"out_lens", {batch_size, kv_seq_length}}}), indices);
auto right_mask_bc = info.add_instruction(
make_op("multibroadcast", {{"out_lens", {batch_size, kv_seq_length}}}), right_mask);
auto in_bool = info.add_instruction(make_op("less"), indices_bc, right_mask_bc);

return info.add_instruction(
make_op("convert", {{"target_type", migraphx::shape::int32_type}}), in_bool);
}

std::optional<instruction_ref>
create_input_mask(const onnx_parser::node_info& info,
const instruction_ref mask_index,
@@ -522,11 +548,16 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>

if((attention.key_pad_mode == key_mask_mode_t::direct_2d_pad) or
(attention.key_pad_mode == key_mask_mode_t::direct_3d_pad))
{ // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indicies
{ // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indices
// and zero otherwise
// Need to generate from 2 dims or 3 dim cases
return generate_raw_mask_per_batch(info, mask_index, input_shape, attention);
}
else if(attention.key_pad_mode == key_mask_mode_t::right_pad)
{
auto right_mask = get_raw_mask_from_right_padding(info, mask_index, attention);
return generate_raw_mask_per_batch(info, right_mask, input_shape, attention);
}

return nullopt;
}
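For context, a minimal NumPy sketch of the conversion the new `get_raw_mask_from_right_padding` helper performs; the function name and shapes below are illustrative, not part of the parser. Each per-batch value is the exclusive start of the right padding (i.e. the number of valid keys), and an iota/less comparison expands it into a per-batch 0/1 raw mask:

```python
import numpy as np

def raw_mask_from_right_padding(right_pad, kv_seq_len):
    """Illustrative mirror of the parser's iota/less/convert sequence.

    right_pad: int32 array of shape (batch,); each entry is the exclusive
    start of the right padding, i.e. the number of valid keys per batch.
    Returns an int32 mask of shape (batch, kv_seq_len), 1 = pass, 0 = mask.
    """
    indices = np.arange(kv_seq_len, dtype=np.int32)       # iota over key positions
    mask = indices[None, :] < right_pad[:, None]          # broadcast compare per batch
    return mask.astype(np.int32)

# Same values as the new verify test: batch 0 keeps both keys,
# batch 1 keeps only the first key.
print(raw_mask_from_right_padding(np.array([2, 1], dtype=np.int32), 2))
# [[1 1]
#  [1 0]]
```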
18 changes: 18 additions & 0 deletions test/onnx/gen_onnx.py
@@ -9414,6 +9414,24 @@ def mha_double_head_bias_mask_batch1_test():
return ([node], [query, key, value, bias, key_padding_mask], [out])


@onnx_test()
def mha_double_head_bias_mask_right_batch2_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [2, 2, 4])
key = helper.make_tensor_value_info("k", TensorProto.FLOAT, [2, 2, 4])
value = helper.make_tensor_value_info("v", TensorProto.FLOAT, [2, 2, 4])
bias = helper.make_tensor_value_info("bias", TensorProto.FLOAT, [12])
key_padding_mask = helper.make_tensor_value_info("key_padding_mask", TensorProto.INT32, [2])
out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [2, 2, 4])

node = helper.make_node('MultiHeadAttention',
inputs=['q', 'k', 'v', 'bias', 'key_padding_mask'],
outputs=['out'],
num_heads=2,
domain='com.microsoft')

return ([node], [query, key, value, bias, key_padding_mask], [out])


@onnx_test()
def mha_bias_asym_mask_2d_scale_test():
query = helper.make_tensor_value_info("q", TensorProto.FLOAT, [2, 3, 4])
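As a rough cross-check (not how the gold data in the verify test was produced, which comes from ONNX Runtime's attention_op_test.cc), the generated model can be run directly with ONNX Runtime, assuming a build that ships the com.microsoft contrib ops; the input values here are placeholders:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("mha_double_head_bias_mask_right_batch2_test.onnx")
feeds = {
    "q": np.ones((2, 2, 4), dtype=np.float32),
    "k": np.ones((2, 2, 4), dtype=np.float32),
    "v": np.ones((2, 2, 4), dtype=np.float32),
    "bias": np.zeros(12, dtype=np.float32),
    # One value per batch: the number of valid (non right-padded) keys.
    "key_padding_mask": np.array([2, 1], dtype=np.int32),
}
out = sess.run(None, feeds)[0]
print(out.shape)  # (2, 2, 4)
```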
36 changes: 36 additions & 0 deletions test/onnx/mha_double_head_bias_mask_right_batch2_test.onnx
@@ -0,0 +1,36 @@
(36 added lines of serialized ONNX protobuf: the mha_double_head_bias_mask_right_batch2_test graph with inputs q, k, v, bias, and key_padding_mask, a single MultiHeadAttention node with a num_heads attribute in the com.microsoft domain, and output out.)
@@ -67,7 +67,7 @@ TEST_CASE(mha_double_head_bias_asym_mask_scale_test)
0.818296f};

migraphx::shape mask_shape{migraphx::shape::int32_type, {2, 3}};
std::vector<float> mask_data = {0, 0, 1, 1, 1, 1};
std::vector<int32_t> mask_data = {0, 0, 1, 1, 1, 1};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
123 changes: 119 additions & 4 deletions test/onnx/verify/mha_double_head_bias_key_padding_mask_test.cpp
@@ -53,7 +53,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_passthrough_mask_test)
std::vector<float> bias_data(12, 0.0f);

migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<float> mask_data = {1, 1};
std::vector<int32_t> mask_data = {1, 1};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
@@ -113,7 +113,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_last_mask_test)

// 0 = mask,1 = pass through
migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<float> mask_data = {1, 0};
std::vector<int32_t> mask_data = {1, 0};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
@@ -172,7 +172,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_first_mask_test)
0.85529f};

migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<float> mask_data = {0, 1};
std::vector<int32_t> mask_data = {0, 1};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
@@ -231,7 +231,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_all_mask_test)
0.84921235f};

migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
std::vector<float> mask_data = {0, 0};
std::vector<int32_t> mask_data = {0, 0};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
@@ -255,3 +255,118 @@

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

TEST_CASE(mha_double_head_bias_mask_batch2_right_mask_test)
{
auto p = optimize_onnx("mha_double_head_bias_mask_right_batch2_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::parameter_map pp;

migraphx::shape q_shape{migraphx::shape::float_type, {2, 2, 4}};
std::vector<float> query_data = {1.46175,
1.4676,
1.05493,
0.900047,
1.67605,
1.30483,
1.21247,
0.897198,
1.46175,
1.4676,
1.05493,
0.900047,
1.67605,
1.30483,
1.21247,
0.897198};

migraphx::shape k_shape{migraphx::shape::float_type, {2, 2, 4}};
std::vector<float> key_data = {1.71781,
2.04228,
1.88613,
1.76649,
1.62908,
2.07181,
1.79497,
2.00843,
1.71781,
2.04228,
1.88613,
1.76649,
1.62908,
2.07181,
1.79497,
2.00843};

migraphx::shape value_shape{migraphx::shape::float_type, {2, 2, 4}};
std::vector<float> value_data = {1.06769,
1.36994,
1.26663,
1.35326,
1.18959,
1.56367,
1.01132,
1.55191,
1.06769,
1.36994,
1.26663,
1.35326,
1.18959,
1.56367,
1.01132,
1.55191};

migraphx::shape bias_shape{migraphx::shape::float_type, {12}};
std::vector<float> bias_data = {0.751496f,
0.557292f,
0.6720010f,
0.1879267f,
0.352546f,
0.600021f,
0.0552079f,
0.5959239f,
0.0404032f,
0.1882552f,
0.2718655f,
0.84921235f};

migraphx::shape mask_shape{migraphx::shape::int32_type, {2}};
std::vector<int32_t> mask_data = {2, 1};

migraphx::literal query{q_shape, query_data};
migraphx::literal key{k_shape, key_data};
migraphx::literal value{value_shape, value_data};
migraphx::literal bias{bias_shape, bias_data};
migraphx::literal mask{mask_shape, mask_data};

pp["q"] = query.get_argument();
pp["k"] = key.get_argument();
pp["v"] = value.get_argument();
pp["bias"] = bias.get_argument();
pp["key_padding_mask"] = mask.get_argument();

auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

// Gold data from the AttentionNoMaskIndex case in attention_op_test.cc from ONNX Runtime
std::vector<float> gold = {1.10809f,
1.5582f,
1.5385f,
2.20247f,
1.10809f,
1.5582f,
1.5385f,
2.20247f,
1.10809f,
1.5582f,
1.5385f,
2.20247f,
1.10809f,
1.5582f,
1.5385f,
2.20247f};

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
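A minimal sketch of why a right-padded key ends up ignored: the parser applies the mask filter value (the contrib op's default mask_filter_val is -10000) to masked score positions before softmax, which drives those attention weights to roughly zero. The score values below are made up for illustration:

```python
import numpy as np

scores = np.array([0.3, 0.7])              # raw QK^T scores for one query, two keys
mask = np.array([1, 0])                    # batch 1 of the test: second key is right padding
filtered = scores + (1 - mask) * -10000.0  # masked position gets the large negative filter value
weights = np.exp(filtered) / np.exp(filtered).sum()
print(weights)                             # ~[1.0, 0.0]: all attention goes to the valid key
```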