Skip to content

Commit ff06f86

Browse files
committed
metal : adjust ops API
ggml-ci
1 parent 967037f commit ff06f86

File tree

4 files changed

+18 additions, -17 deletions (lines changed)

4 files changed

+18
-17
lines changed

ggml/src/ggml-metal/ggml-metal-context.m

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1408,11 +1408,11 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
14081408
// src1 is a row
14091409
GGML_ASSERT(ne11 == 1);
14101410

1411-
pipeline = ggml_metal_op_bin_get_pipeline(ctx, dst->op, n_fuse, true);
1411+
pipeline = ggml_metal_op_bin_get_pipeline(node->op, ctx, n_fuse, true);
14121412

14131413
bcast_row = true;
14141414
} else {
1415-
pipeline = ggml_metal_op_bin_get_pipeline(ctx, dst->op, n_fuse, false);
1415+
pipeline = ggml_metal_op_bin_get_pipeline(node->op, ctx, n_fuse, false);
14161416
}
14171417

14181418
if (n_fuse > 1) {
@@ -1602,7 +1602,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
16021602
};
16031603

16041604
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_ADD].pipeline;
1605-
const id<MTLComputePipelineState> pipeline = ggml_metal_op_bin_get_pipeline(ctx, GGML_OP_ADD, 1, false);
1605+
const id<MTLComputePipelineState> pipeline = ggml_metal_op_bin_get_pipeline(GGML_OP_ADD, ctx, 1, false);
16061606

16071607
[encoder setComputePipelineState:pipeline];
16081608
[encoder setBytes:&args length:sizeof(args) atIndex:0];
@@ -3517,7 +3517,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
35173517
}
35183518
}
35193519

3520-
const id<MTLComputePipelineState> pipeline = ggml_metal_op_rms_norm_get_pipeline(ctx, node, n_fuse);
3520+
const id<MTLComputePipelineState> pipeline = ggml_metal_op_rms_norm_get_pipeline(node, ctx, n_fuse);
35213521

35223522
int nth = 32; // SIMD width
35233523

@@ -4257,7 +4257,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
42574257
/*.logit_softcap =*/ logit_softcap,
42584258
};
42594259

4260-
id<MTLComputePipelineState> pipeline = ggml_metal_op_flash_attn_ext_get_pipeline(ctx, node, has_mask, has_sinks, has_bias, has_scap, nsg);
4260+
id<MTLComputePipelineState> pipeline = ggml_metal_op_flash_attn_ext_get_pipeline(node, ctx, has_mask, has_sinks, has_bias, has_scap, nsg);
42614261

42624262
[encoder setComputePipelineState:pipeline];
42634263
[encoder setBytes:&args length:sizeof(args) atIndex:0];
@@ -4372,7 +4372,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
43724372
/*.logit_softcap =*/ logit_softcap,
43734373
};
43744374

4375-
id<MTLComputePipelineState> pipeline = ggml_metal_op_flash_attn_ext_vec_get_pipeline(ctx, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
4375+
id<MTLComputePipelineState> pipeline = ggml_metal_op_flash_attn_ext_vec_get_pipeline(node, ctx, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
43764376

43774377
GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
43784378

@@ -4426,7 +4426,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
44264426
nrows,
44274427
};
44284428

4429-
id<MTLComputePipelineState> pipeline0 = ggml_metal_op_flash_attn_ext_vec_reduce_get_pipeline(ctx, node, ne20, nwg);
4429+
id<MTLComputePipelineState> pipeline0 = ggml_metal_op_flash_attn_ext_vec_reduce_get_pipeline(node, ctx, ne20, nwg);
44304430

44314431
[encoder setComputePipelineState:pipeline0];
44324432
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
4848
}
4949

5050
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_get_pipeline(
51-
ggml_metal_t ctx,
5251
ggml_tensor * op,
52+
ggml_metal_t ctx,
5353
bool has_mask,
5454
bool has_sinks,
5555
bool has_bias,
@@ -107,8 +107,8 @@ ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_get_pipeline(
107107
}
108108

109109
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_get_pipeline(
110-
ggml_metal_t ctx,
111110
ggml_tensor * op,
111+
ggml_metal_t ctx,
112112
bool has_mask,
113113
bool has_sinks,
114114
bool has_bias,
@@ -168,8 +168,8 @@ ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_get_pipeline(
168168
}
169169

170170
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_reduce_get_pipeline(
171-
ggml_metal_t ctx,
172171
ggml_tensor * op,
172+
ggml_metal_t ctx,
173173
int32_t dv,
174174
int32_t nwg) {
175175
char base[256];
@@ -198,8 +198,8 @@ ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_reduce_get_pipeline(
198198
}
199199

200200
ggml_metal_pipeline_t ggml_metal_op_bin_get_pipeline(
201-
ggml_metal_t ctx,
202201
enum ggml_op op,
202+
ggml_metal_t ctx,
203203
int32_t n_fuse,
204204
bool row) {
205205
char base[256];
@@ -231,8 +231,8 @@ ggml_metal_pipeline_t ggml_metal_op_bin_get_pipeline(
231231
}
232232

233233
ggml_metal_pipeline_t ggml_metal_op_rms_norm_get_pipeline(
234-
ggml_metal_t ctx,
235234
ggml_tensor * op,
235+
ggml_metal_t ctx,
236236
int32_t n_fuse) {
237237
char base[256];
238238
char name[256];

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
1818
size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
1919

2020
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_get_pipeline(
21-
ggml_metal_t ctx,
2221
struct ggml_tensor * op,
22+
ggml_metal_t ctx,
2323
bool has_mask,
2424
bool has_sinks,
2525
bool has_bias,
2626
bool has_scap,
2727
int32_t nsg);
2828

2929
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_get_pipeline(
30-
ggml_metal_t ctx,
3130
struct ggml_tensor * op,
31+
ggml_metal_t ctx,
3232
bool has_mask,
3333
bool has_sinks,
3434
bool has_bias,
@@ -37,20 +37,20 @@ ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_get_pipeline(
3737
int32_t nwg);
3838

3939
ggml_metal_pipeline_t ggml_metal_op_flash_attn_ext_vec_reduce_get_pipeline(
40-
ggml_metal_t ctx,
4140
struct ggml_tensor * op,
41+
ggml_metal_t ctx,
4242
int32_t dv,
4343
int32_t nwg);
4444

4545
ggml_metal_pipeline_t ggml_metal_op_bin_get_pipeline(
46-
ggml_metal_t ctx,
4746
enum ggml_op op,
47+
ggml_metal_t ctx,
4848
int32_t n_fuse,
4949
bool row);
5050

5151
ggml_metal_pipeline_t ggml_metal_op_rms_norm_get_pipeline(
52-
ggml_metal_t ctx,
5352
struct ggml_tensor * op,
53+
ggml_metal_t ctx,
5454
int32_t n_fuse);
5555

5656
#ifdef __cplusplus

ggml/src/ggml-metal/ggml-metal.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,7 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const
689689

690690
GGML_UNUSED(reg);
691691
}
692+
692693
static ggml_backend_reg_i ggml_backend_metal_reg_i = {
693694
/* .get_name = */ ggml_backend_metal_reg_get_name,
694695
/* .device_count = */ ggml_backend_metal_reg_device_count,

0 commit comments

Comments
 (0)