Skip to content

Commit 716a3d3

Browse files
committed
metal : migrate ops to separate functions
ggml-ci
1 parent d51a5b4 commit 716a3d3

File tree

9 files changed

+4624
-4496
lines changed

9 files changed

+4624
-4496
lines changed

ggml/include/ggml.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,19 +284,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexc
284284
// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
285285
//
286286
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
287-
const type prefix##0 = (pointer)->array[0]; \
287+
const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
288288
GGML_UNUSED(prefix##0);
289289
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
290290
GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
291-
const type prefix##1 = (pointer)->array[1]; \
291+
const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
292292
GGML_UNUSED(prefix##1);
293293
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
294294
GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
295-
const type prefix##2 = (pointer)->array[2]; \
295+
const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
296296
GGML_UNUSED(prefix##2);
297297
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
298298
GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
299-
const type prefix##3 = (pointer)->array[3]; \
299+
const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
300300
GGML_UNUSED(prefix##3);
301301

302302
#define GGML_TENSOR_UNARY_OP_LOCALS \

ggml/src/ggml-metal/ggml-metal-context.h

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,19 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
2727
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
2828
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
2929

30-
void * ggml_metal_pipeline_get_obj(ggml_metal_pipeline_t pipeline);
30+
void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
31+
int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
32+
33+
void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
34+
int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
35+
36+
void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
37+
int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
38+
39+
void ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
40+
size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
41+
42+
int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
3143

3244
// a collection of pipelines
3345
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
@@ -38,6 +50,37 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
3850
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
3951
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
4052

53+
//
54+
// MTLCommandBuffer wrapper
55+
//
56+
57+
typedef void * ggml_metal_cmd_buf_t;
58+
59+
//
60+
// MTLComputeCommandEncoder wrapper
61+
//
62+
63+
typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
64+
65+
ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);
66+
void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
67+
68+
void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
69+
void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
70+
71+
void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
72+
73+
void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
74+
void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
75+
76+
void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);
77+
78+
void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);
79+
80+
void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);
81+
82+
void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
83+
4184
//
4285
// backend
4386
//
@@ -63,6 +106,39 @@ void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort
63106
bool ggml_metal_supports_family (ggml_metal_t ctx, int family);
64107
void ggml_metal_capture_next_compute(ggml_metal_t ctx);
65108

109+
//
110+
// graph encoder
111+
//
112+
113+
typedef struct ggml_metal_graph_encoder * ggml_metal_graph_encoder_t;
114+
115+
// TODO: tmp
116+
#include "ggml-metal-common.h"
117+
118+
// TODO: tmp
119+
struct ggml_metal_graph_encoder {
120+
ggml_metal_t ctx;
121+
122+
const struct ggml_metal_device_props * props_dev;
123+
124+
ggml_metal_encoder_t encoder;
125+
126+
ggml_mem_ranges_t mem_ranges;
127+
128+
struct ggml_cgraph * gf;
129+
130+
int idx_start;
131+
int idx_end;
132+
133+
bool use_fusion;
134+
135+
int debug_fusion;
136+
};
137+
138+
bool ggml_metal_graph_encoder_concurrency_reset(ggml_metal_graph_encoder_t ctx);
139+
bool ggml_metal_graph_encoder_concurrency_check(ggml_metal_graph_encoder_t ctx, const struct ggml_tensor * node);
140+
bool ggml_metal_graph_encoder_concurrency_add (ggml_metal_graph_encoder_t ctx, const struct ggml_tensor * node);
141+
66142
#ifdef __cplusplus
67143
}
68144
#endif

0 commit comments

Comments
 (0)