Commit 4f922ce

vulkan: fuse mul_mat_id + mul
This comes up in qwen3 moe.
1 parent 8c0d6bb commit 4f922ce
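
For orientation, here is a minimal C++ sketch (illustrative only, not part of the commit) of the op pair being fused, mirroring the usual ggml MoE FFN construction and the new test case below: a MUL_MAT_ID routes each token through its selected experts, and the MUL that immediately follows scales each expert's output row by its routing weight (the ffn_moe_weighted step in qwen3 moe). All names and shapes here are assumptions.

#include "ggml.h"

// Sketch of the fused pattern; tensor names and shapes are illustrative.
static ggml_tensor * moe_weighted(ggml_context * ctx,
                                  ggml_tensor * experts,   // [k, m, n_expert]      stacked expert matrices
                                  ggml_tensor * x,         // [k, 1, n_tokens]      activations
                                  ggml_tensor * ids,       // [n_used, n_tokens]    selected expert ids
                                  ggml_tensor * weights) { // [1, n_used, n_tokens] routing weights
    // one mat-vec per (token, selected expert): out is [m, n_used, n_tokens]
    ggml_tensor * out = ggml_mul_mat_id(ctx, experts, x, ids);
    // broadcast multiply over dim 0: one scalar weight per expert slot.
    // This MUL is the op the Vulkan backend now folds into the mat-vec shader.
    return ggml_mul(ctx, out, weights);
}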

3 files changed (+93, -9 lines)
ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 53 additions & 5 deletions
@@ -802,6 +802,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -824,6 +825,7 @@ struct vk_mat_vec_id_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t nei0;
     uint32_t ne11;
 };
@@ -6796,7 +6798,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7617,13 +7619,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t enable_bias = 0;
+    uint32_t enable_scale = 0;
+    if (ctx->num_additional_fused_ops > 0) {
+        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+            enable_scale = 1;
+        } else {
+            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
+            enable_bias = 1;
+        }
+    }
 
     vk_buffer d_B = d_D;
     size_t b_buf_offset = 0;
     uint64_t b_sz = 0;
 
-    if (enable_bias) {
+    if (enable_bias || enable_scale) {
         const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
 
         bool b_uma = false;
@@ -7645,7 +7656,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
 
-        enable_bias,
+        enable_bias, enable_scale,
 
         (uint32_t)nei0, (uint32_t)ne11,
     };
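
Note that the fused MUL's src[1] travels through the same plumbing as the fused ADD_ID bias: the host code above fetches it as cgraph->nodes[node_idx + 1]->src[1] and binds it as d_B, and the shader reads it through the existing data_bias buffer, with enable_bias / enable_scale in the push constants selecting which epilogue runs.
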
@@ -12671,6 +12682,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
         }
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *mmid = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        const ggml_tensor *scale = mul->src[1];
+
+        if (mmid != mul->src[0]) {
+            return false;
+        }
+        // mat-vec only
+        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+            return false;
+        }
+        // shaders assume the types match
+        if (mmid->type != scale->type) {
+            return false;
+        }
+        // shaders assume the bias is contiguous
+        if (!ggml_is_contiguous(scale)) {
+            return false;
+        }
+        // unaligned bias isn't handled
+        if (get_misalign_bytes(ctx, scale) != 0) {
+            return false;
+        }
+        // shader only indexes by expert index
+        if (scale->ne[0] != 1 ||
+            scale->ne[1] != mul->ne[1] ||
+            scale->ne[2] != 1 ||
+            scale->ne[3] != 1) {
+            return false;
+        }
+    }
+
     return true;
 }
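
The shape check at the end mirrors what the shader epilogue below can handle: reduce_result indexes the scale buffer solely by gl_GlobalInvocationID.y (the expert slot), so the multiplier must be a single scalar per output row, i.e. shape [1, ne1, 1, 1].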

@@ -12917,6 +12962,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             ctx->num_additional_fused_ops = 1;
         } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
             ctx->num_additional_fused_ops = 1;
+        } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
         } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                    ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                    ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -13142,7 +13189,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                 is_src_of(graph->nodes[j], graph->nodes[c]) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
-                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+                !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
                 ok = false;
                 break;
             }
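
This last change keeps graph reordering from splitting the pair: the fusion requires the MUL to be the immediate next node (node_idx + 1), so ggml_vk_graph_optimize must not move another node between a MUL_MAT_ID and its following MUL, just as it already avoided doing for the MUL_MAT_ID + ADD_ID fusion.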

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

Lines changed: 19 additions & 0 deletions
@@ -49,6 +49,7 @@ layout (push_constant) uniform parameter
     uint batch_stride_d;
 
     uint enable_bias;
+    uint enable_scale;
 
 #ifdef MUL_MAT_ID
     uint nei0;
@@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
                 temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
             }
+#ifdef MUL_MAT_ID
+            if (p.enable_scale != 0) {
+                const uint expert_idx = gl_GlobalInvocationID.y;
+                temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+            }
+#endif
             data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
         }
     }
@@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                 temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
             }
+#ifdef MUL_MAT_ID
+            if (p.enable_scale != 0) {
+                const uint expert_idx = gl_GlobalInvocationID.y;
+                temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+            }
+#endif
             data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
         }
     }
@@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                 tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
             }
+#ifdef MUL_MAT_ID
+            if (p.enable_scale != 0) {
+                const uint expert_idx = gl_GlobalInvocationID.y;
+                tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
+            }
+#endif
             data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
         }
     }
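
The same six-line epilogue is added to all three reduce_result variants (the inout-accumulator, register-accumulator, and shared-memory tmpsh reductions), so every mat-vec path applies the per-expert scale just before the final store, reusing the data_bias binding to carry the scale values.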

tests/test-backend-ops.cpp

Lines changed: 21 additions & 4 deletions
@@ -3495,9 +3495,10 @@ struct test_mul_mat_id : public test_case {
     const int64_t n;
     const int64_t k;
     const uint32_t o; // number of outputs
+    const bool mul;
 
     std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
+        return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
     }
 
     double max_nmse_err() override {
@@ -3511,9 +3512,9 @@
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int n_mats = 8, int n_used = 2, bool b = false,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
+            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
-          m(m), n(n), k(k), o(o) {
+          m(m), n(n), k(k), o(o), mul(mul) {
         GGML_ASSERT(n_used <= n_mats);
     }
 
@@ -3542,6 +3543,13 @@
             out = ggml_add(ctx, out, out2);
         }
 
+        if (mul) {
+            std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+            ne[0] = 1;
+            ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());
+            out = ggml_mul(ctx, out, m);
+        }
+
         return out;
     }
 
@@ -3566,7 +3574,7 @@
         }
     }
 
-    bool run_whole_graph() override { return o > 1; }
+    bool run_whole_graph() override { return o > 1 || mul; }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -6978,6 +6986,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (int bs : {1, 4, 512}) {
+        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                // test with mul after (ffn_moe_weighted)
+                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
+            }
+        }
+    }
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             for (int n : {1, 16}) {
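
These new cases enable run_whole_graph() so the backend sees the MUL_MAT_ID -> MUL pair intact and can take the fused path. Presumably they can also be exercised in isolation with the harness's usual op filter, something like `test-backend-ops test -o MUL_MAT_ID` (check the binary's help output for the exact flags).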
