@@ -532,14 +532,13 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
532532 make_op (" multibroadcast" ,
533533 {{" out_lens" , {batch_size, kv_seq_length}}}),
534534 mask_value_lit);
535-
536- // Gen list of indicies to compare to the exclusive start of right padding
537- std::vector<size_t > indicies_vec (kv_seq_length, 0 );
538- std::iota (indicies_vec.begin (), indicies_vec.end (), 0 );
539- auto indicies = info.add_literal (migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {kv_seq_length}, {1 }}, indicies_vec});
540- auto indicies_bc = info.add_instruction (make_op (" multibroadcast" , {{" out_lens" , {batch_size, kv_seq_length}}}), indicies);
535+ // Gen list of indices to compare to the exclusive start of right padding
536+ std::vector<size_t > indices_vec (kv_seq_length, 0 );
537+ std::iota (indices_vec.begin (), indices_vec.end (), 0 );
538+ auto indices = info.add_literal (migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {kv_seq_length}, {1 }}, indices_vec});
539+ auto indices_bc = info.add_instruction (make_op (" multibroadcast" , {{" out_lens" , {batch_size, kv_seq_length}}}), indices);
541540 auto right_mask_bc = info.add_instruction (make_op (" multibroadcast" , {{" out_lens" , {batch_size, kv_seq_length}}}), right_mask);
542- auto in_bool = info.add_instruction (make_op (" less" ), indicies_bc , right_mask_bc);
541+ auto in_bool = info.add_instruction (make_op (" less" ), indices_bc , right_mask_bc);
543542 auto where = info.add_instruction (make_op (" where" ), in_bool, bc_pass, bc_mask);
544543
545544 return info.add_instruction (make_op (" convert" , {{" target_type" , migraphx::shape::int32_type}}), where);
@@ -558,7 +557,7 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
558557
559558 if ((attention.key_pad_mode == key_mask_mode_t ::direct_2d_pad) or
560559 (attention.key_pad_mode == key_mask_mode_t ::direct_3d_pad))
561- { // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indicies
560+ { // Raw Mask - 0 means mask, 1 means pass through. Apply mask_filter_val to mask indices
562561 // and zero otherwise
563562 // Need to generate from 2 dims or 3 dim cases
564563 return generate_raw_mask_per_batch (info, mask_index, input_shape, attention);
0 commit comments