@@ -802,6 +802,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -824,6 +825,7 @@ struct vk_mat_vec_id_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t nei0;
     uint32_t ne11;
 };
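
Both push-constant structs grow by one 32-bit word, so the shader-side declarations have to add enable_scale at the same offset for the host and device layouts to keep matching. Below is a minimal sketch of the kind of compile-time check that keeps such a block inside Vulkan's guaranteed 128-byte push-constant budget; the struct is abridged to the fields visible in this hunk and the omitted leading members are an assumption, not the real definition.

#include <cstdint>

// Abridged sketch, not the real struct: only the fields visible in this hunk,
// with the leading members of vk_mat_vec_push_constants omitted.
struct mat_vec_push_constants_sketch {
    uint32_t batch_stride_b;
    uint32_t batch_stride_d;
    uint32_t enable_bias;
    uint32_t enable_scale;   // new flag: request the fused per-expert multiply
    uint32_t ne02;
    uint32_t ne12;
    uint32_t broadcast2;
    // ... remaining fields omitted
};

// Vulkan only guarantees maxPushConstantsSize >= 128 bytes, so growing the
// block is safe as long as the full struct stays under that limit.
static_assert(sizeof(mat_vec_push_constants_sketch) <= 128,
              "push constants would exceed the guaranteed Vulkan budget");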
@@ -6796,7 +6798,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7617,13 +7619,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }

-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t enable_bias = 0;
+    uint32_t enable_scale = 0;
+    if (ctx->num_additional_fused_ops > 0) {
+        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+            enable_scale = 1;
+        } else {
+            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
+            enable_bias = 1;
+        }
+    }

     vk_buffer d_B = d_D;
     size_t b_buf_offset = 0;
     uint64_t b_sz = 0;

-    if (enable_bias) {
+    if (enable_bias || enable_scale) {
         const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];

         bool b_uma = false;
@@ -7645,7 +7656,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),

-        enable_bias,
+        enable_bias, enable_scale,

         (uint32_t)nei0, (uint32_t)ne11,
     };
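
With the flags in place, the dispatch for ggml_vk_mul_mat_vec_id_q_f16 binds src[1] of the following node (a bias for GGML_OP_ADD_ID, a per-expert scale for GGML_OP_MUL) and passes both flags through the push constants. Roughly, the epilogue the shader is being asked to perform behaves like the host-side sketch below; the names and indexing are illustrative assumptions, not the actual GLSL.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative host-side model of the fused epilogue selected by the flags.
// expert_idx is the expert chosen for this output row; fused_src points at the
// data of the fused operand (bias or scale). Not the actual shader code.
static void fused_epilogue(std::vector<float> & dst_row,
                           const float * fused_src,
                           uint32_t expert_idx,
                           uint32_t enable_bias,
                           uint32_t enable_scale) {
    if (enable_scale) {
        // fused GGML_OP_MUL: one scalar per expert row (scale->ne[0] == 1),
        // so the whole row is multiplied by a single value
        const float s = fused_src[expert_idx];
        for (float & v : dst_row) {
            v *= s;
        }
    } else if (enable_bias) {
        // fused GGML_OP_ADD_ID: a full bias vector selected per expert row
        for (size_t i = 0; i < dst_row.size(); ++i) {
            dst_row[i] += fused_src[expert_idx * dst_row.size() + i];
        }
    }
}

The two flags mirror the selection logic added above: at most one of the fused epilogues is requested per dispatch.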
@@ -12671,6 +12682,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
         }
     }

+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *mmid = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        const ggml_tensor *scale = mul->src[1];
+
+        if (mmid != mul->src[0]) {
+            return false;
+        }
+        // mat-vec only
+        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+            return false;
+        }
+        // shaders assume the types match
+        if (mmid->type != scale->type) {
+            return false;
+        }
+        // shaders assume the scale is contiguous
+        if (!ggml_is_contiguous(scale)) {
+            return false;
+        }
+        // unaligned scale isn't handled
+        if (get_misalign_bytes(ctx, scale) != 0) {
+            return false;
+        }
+        // shader only indexes by expert index
+        if (scale->ne[0] != 1 ||
+            scale->ne[1] != mul->ne[1] ||
+            scale->ne[2] != 1 ||
+            scale->ne[3] != 1) {
+            return false;
+        }
+    }
+
     return true;
 }

@@ -12917,6 +12962,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             ctx->num_additional_fused_ops = 1;
         } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
             ctx->num_additional_fused_ops = 1;
+        } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
         } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                    ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                    ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -13142,7 +13189,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                 is_src_of(graph->nodes[j], graph->nodes[c]) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
-                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
                 ok = false;
                 break;
             }