Skip to content

Commit 6704d47

Browse files
Copilot and formatting changes
1 parent ed7522a commit 6704d47

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

src/onnx/parse_multi_head_attention.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -532,14 +532,13 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
532532
make_op("multibroadcast",
533533
{{"out_lens", {batch_size, kv_seq_length}}}),
534534
mask_value_lit);
535-
536-
// Gen list of indicies to compare to the exclusive start of right padding
537-
std::vector<size_t> indicies_vec(kv_seq_length, 0);
538-
std::iota(indicies_vec.begin(), indicies_vec.end(), 0);
539-
auto indicies = info.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {kv_seq_length}, {1}}, indicies_vec});
540-
auto indicies_bc = info.add_instruction(make_op("multibroadcast", {{"out_lens", {batch_size, kv_seq_length}}}), indicies);
535+
// Gen list of indices to compare to the exclusive start of right padding
536+
std::vector<size_t> indices_vec(kv_seq_length, 0);
537+
std::iota(indices_vec.begin(), indices_vec.end(), 0);
538+
auto indices = info.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {kv_seq_length}, {1}}, indices_vec});
539+
auto indices_bc = info.add_instruction(make_op("multibroadcast", {{"out_lens", {batch_size, kv_seq_length}}}), indices);
541540
auto right_mask_bc = info.add_instruction(make_op("multibroadcast", {{"out_lens", {batch_size, kv_seq_length}}}), right_mask);
542-
auto in_bool = info.add_instruction(make_op("less"), indicies_bc, right_mask_bc);
541+
auto in_bool = info.add_instruction(make_op("less"), indices_bc, right_mask_bc);
543542
auto where = info.add_instruction(make_op("where"), in_bool, bc_pass, bc_mask);
544543

545544
return info.add_instruction(make_op("convert", {{"target_type", migraphx::shape::int32_type}}), where);
@@ -558,7 +557,7 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
558557

559558
if((attention.key_pad_mode == key_mask_mode_t::direct_2d_pad) or
560559
(attention.key_pad_mode == key_mask_mode_t::direct_3d_pad))
561-
{ // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indicies
560+
{ // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indices
562561
// and zero otherwise
563562
// Need to generate from 2 dims or 3 dim cases
564563
return generate_raw_mask_per_batch(info, mask_index, input_shape, attention);

test/onnx/verify/mha_double_head_bias_asym_mask_scale_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ TEST_CASE(mha_double_head_bias_asym_mask_scale_test)
6767
0.818296f};
6868

6969
migraphx::shape mask_shape{migraphx::shape::int32_type, {2, 3}};
70-
std::vector<float> mask_data = {0, 0, 1, 1, 1, 1};
70+
std::vector<int32_t> mask_data = {0, 0, 1, 1, 1, 1};
7171

7272
migraphx::literal query{q_shape, query_data};
7373
migraphx::literal key{k_shape, key_data};

test/onnx/verify/mha_double_head_bias_key_padding_mask_test.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_passthrough_mask_test)
5353
std::vector<float> bias_data(12, 0.0f);
5454

5555
migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
56-
std::vector<float> mask_data = {1, 1};
56+
std::vector<int32_t> mask_data = {1, 1};
5757

5858
migraphx::literal query{q_shape, query_data};
5959
migraphx::literal key{k_shape, key_data};
@@ -113,7 +113,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_last_mask_test)
113113

114114
// 0 = mask,1 = pass through
115115
migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
116-
std::vector<float> mask_data = {1, 0};
116+
std::vector<int32_t> mask_data = {1, 0};
117117

118118
migraphx::literal query{q_shape, query_data};
119119
migraphx::literal key{k_shape, key_data};
@@ -172,7 +172,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_first_mask_test)
172172
0.85529f};
173173

174174
migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
175-
std::vector<float> mask_data = {0, 1};
175+
std::vector<int32_t> mask_data = {0, 1};
176176

177177
migraphx::literal query{q_shape, query_data};
178178
migraphx::literal key{k_shape, key_data};
@@ -231,7 +231,7 @@ TEST_CASE(mha_double_head_bias_mask_batch1_all_mask_test)
231231
0.84921235f};
232232

233233
migraphx::shape mask_shape{migraphx::shape::int32_type, {1, 2}};
234-
std::vector<float> mask_data = {0, 0};
234+
std::vector<int32_t> mask_data = {0, 0};
235235

236236
migraphx::literal query{q_shape, query_data};
237237
migraphx::literal key{k_shape, key_data};

0 commit comments

Comments
 (0)