Skip to content

Commit d51a5b4

Browse files
committed
metal : use C++ to store pipelines
ggml-ci
1 parent ff06f86 commit d51a5b4

File tree

6 files changed

+417
-374
lines changed

6 files changed

+417
-374
lines changed

ggml/src/ggml-metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ggml_add_backend_library(ggml-metal
1010
ggml-metal-device.cpp
1111
ggml-metal-common.cpp
1212
ggml-metal-context.m
13+
ggml-metal-context.cpp
1314
ggml-metal-ops.cpp
1415
)
1516

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include "ggml-metal-context.h"
2+
3+
#include <string>
4+
#include <unordered_map>
5+
6+
struct ggml_metal_pipelines {
7+
std::unordered_map<std::string, ggml_metal_pipeline_t> data;
8+
};
9+
10+
// Allocate a new, empty pipeline collection.
// The returned handle is heap-allocated with `new`; ggml_metal_pipelines_free()
// is the function in this file that deletes it.
ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
    return new ggml_metal_pipelines();
}
15+
16+
// Free every pipeline stored in the collection, then the collection itself.
//
// ppls: collection handle; a null handle is accepted and treated as a no-op,
//       mirroring the behavior of `delete` on a null pointer.
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
    if (ppls == nullptr) {
        return;
    }

    // release each stored pipeline before the map that references them goes away
    for (const auto & entry : ppls->data) {
        ggml_metal_pipeline_free(entry.second);
    }

    delete ppls;
}
23+
24+
// Store `pipeline` in the collection under `name`.
//
// If an entry with the same name already exists, its handle is replaced;
// the previous handle is not freed by this function.
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {
    // operator[] creates the slot when the name is not present yet
    auto & slot = ppls->data[name];
    slot = pipeline;
}
27+
28+
// Look up a pipeline by name.
//
// ppls: collection to search
// name: pipeline name (NUL-terminated string)
//
// Returns the stored pipeline handle, or nullptr when no pipeline is
// registered under `name`.
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
    // single lookup: reuse the iterator instead of hashing the key a second
    // time through operator[] (which would also build another std::string)
    const auto it = ppls->data.find(name);
    if (it == ppls->data.end()) {
        return nullptr;
    }

    return it->second;
}

ggml/src/ggml-metal/ggml-metal-context.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv);
1818
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
1919
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
2020

21+
//
22+
// MTLComputePipelineState wrapper
23+
//
24+
25+
typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
26+
27+
ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
28+
void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
29+
30+
void * ggml_metal_pipeline_get_obj(ggml_metal_pipeline_t pipeline);
31+
32+
// a collection of pipelines
33+
typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
34+
35+
ggml_metal_pipelines_t ggml_metal_pipelines_init(void);
36+
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
37+
38+
void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
39+
ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
40+
2141
//
2242
// backend
2343
//
@@ -27,20 +47,16 @@ typedef struct ggml_metal * ggml_metal_t;
2747
ggml_metal_t ggml_metal_init(ggml_metal_device_t ctx_dev);
2848
void ggml_metal_free(ggml_metal_t ctx);
2949

30-
typedef void * ggml_metal_pipeline_t;
31-
32-
ggml_metal_pipeline_t ggml_metal_get_pipeline(ggml_metal_t ctx, const char * name);
33-
50+
ggml_metal_pipeline_t ggml_metal_get_pipeline (ggml_metal_t ctx, const char * name);
3451
ggml_metal_pipeline_t ggml_metal_compile_pipeline(ggml_metal_t ctx, const char * base, const char * name, ggml_metal_cv_t cv);
3552

3653
void ggml_metal_synchronize(ggml_metal_t ctx);
3754

3855
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
3956
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
4057

41-
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf);
42-
43-
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);
58+
enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
59+
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);
4460

4561
void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb);
4662
void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);

0 commit comments

Comments
 (0)