
Commit 0320ac5

metal : refactor + optimize v2 (#15995)

* metal : improve naming
* metal : refactor device
* cont : props
* metal : apply ggml_mem_ranges_t
* metal : remove GGML_METAL_USE_BF16
* metal : refactor device buffer
* cont : fix naming
* metal : sync before destroying the backend
* metal : refactor context
* metal : migrate ggml-metal.m to ggml-metal.cpp
* metal : adjust ops API
* metal : use C++ to store pipelines
* metal : migrate ops to separate functions
* metal : add ggml_metal_library_t
* metal : improve naming
* metal : cleanup
* metal : add support for GGML_OP_LOG
* metal : fix error handling

1 parent a7a98e0 commit 0320ac5

19 files changed: 7,870 additions, 7,181 deletions

ci/run.sh

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ SRC=`pwd`
 CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
 
 if [ ! -z ${GG_BUILD_METAL} ]; then
-    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
 fi
 
 if [ ! -z ${GG_BUILD_CUDA} ]; then

ggml/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -190,7 +190,6 @@ option(GGML_WEBGPU "ggml: use WebGPU"
 option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
 option(GGML_ZDNN "ggml: use zDNN" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
-option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
 option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
 option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})

ggml/include/ggml-metal.h

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ extern "C" {
 // user-code should use only these functions
 //
 
+// TODO: remove in the future
 GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
 
 GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);

ggml/include/ggml.h

Lines changed: 4 additions & 4 deletions

@@ -284,19 +284,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept
 // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
 //
 #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
-    const type prefix##0 = (pointer)->array[0]; \
+    const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
     GGML_UNUSED(prefix##0);
 #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
     GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
-    const type prefix##1 = (pointer)->array[1]; \
+    const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
     GGML_UNUSED(prefix##1);
 #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
     GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
-    const type prefix##2 = (pointer)->array[2]; \
+    const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
     GGML_UNUSED(prefix##2);
 #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
     GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
-    const type prefix##3 = (pointer)->array[3]; \
+    const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
     GGML_UNUSED(prefix##3);
 
 #define GGML_TENSOR_UNARY_OP_LOCALS \
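
The change guards each operand access with a null check, so locals derived from a missing optional operand default to 0 instead of dereferencing a null pointer. A minimal standalone sketch of the pattern (demo-only types and names, not from the commit):

#include <cstdio>

// demo-only stand-in for ggml_tensor, holding just the dims array
struct tensor_demo {
    long ne[4];
};

// same pattern as the patched GGML_TENSOR_LOCALS_1
#define DEMO_LOCALS_1(type, prefix, pointer, array) \
    const type prefix##0 = (pointer) ? (pointer)->array[0] : 0;

int main() {
    tensor_demo   t    = { {8, 4, 2, 1} };
    tensor_demo * src0 = &t;
    tensor_demo * src1 = nullptr; // absent optional operand

    DEMO_LOCALS_1(long, ne0, src0, ne); // ne00 == 8
    DEMO_LOCALS_1(long, ne1, src1, ne); // ne10 == 0, no null dereference

    printf("ne00 = %ld, ne10 = %ld\n", ne00, ne10);
    return 0;
}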

ggml/src/ggml-metal/CMakeLists.txt

Lines changed: 5 additions & 5 deletions

@@ -5,8 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
 message(STATUS "Metal framework found")
 
 ggml_add_backend_library(ggml-metal
-                         ggml-metal.m
+                         ggml-metal.cpp
+                         ggml-metal-device.m
+                         ggml-metal-device.cpp
                          ggml-metal-common.cpp
+                         ggml-metal-context.m
+                         ggml-metal-ops.cpp
                          )
 
 target_link_libraries(ggml-metal PRIVATE
@@ -19,10 +23,6 @@ if (GGML_METAL_NDEBUG)
     add_compile_definitions(GGML_METAL_NDEBUG)
 endif()
 
-if (GGML_METAL_USE_BF16)
-    add_compile_definitions(GGML_METAL_USE_BF16)
-endif()
-
 # copy metal files to bin directory
 configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
 configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

ggml/src/ggml-metal/ggml-metal-common.cpp

Lines changed: 16 additions & 16 deletions

@@ -22,7 +22,7 @@ struct ggml_mem_ranges {
     int debug = 0;
 };
 
-struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
+ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
     auto * res = new ggml_mem_ranges;
 
     res->ranges.reserve(256);
@@ -31,15 +31,15 @@ ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
     return res;
 }
 
-void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
+void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
     delete mrs;
 }
 
-void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
+void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
     mrs->ranges.clear();
 }
 
-static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
+static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
     mrs->ranges.push_back(mr);
 
     return true;
@@ -87,7 +87,7 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
     return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
 }
 
-static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
     ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
@@ -99,7 +99,7 @@ static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     return ggml_mem_ranges_add(mrs, mr);
 }
 
-static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
     ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
@@ -111,7 +111,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     return ggml_mem_ranges_add(mrs, mr);
 }
 
-bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (tensor->src[i]) {
             ggml_mem_ranges_add_src(mrs, tensor->src[i]);
@@ -121,7 +121,7 @@ bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     return ggml_mem_ranges_add_dst(mrs, tensor);
 }
 
-static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
+static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
     for (size_t i = 0; i < mrs->ranges.size(); i++) {
         const auto & cmp = mrs->ranges[i];
 
@@ -152,7 +152,7 @@ static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
     return true;
 }
 
-static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);
 
     ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
@@ -162,7 +162,7 @@ static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     return res;
 }
 
-static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
     ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
@@ -172,7 +172,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     return res;
 }
 
-bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (tensor->src[i]) {
             if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
@@ -222,7 +222,7 @@ struct node_info {
 
 static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
     // helper to add node src and dst ranges
-    const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
+    const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
         for (int i = 0; i < GGML_MAX_SRC; i++) {
             if (node.node->src[i]) {
                 if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
@@ -246,7 +246,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
     };
 
     // helper to check if a node can run concurrently with the existing set of nodes
-    const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
+    const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
         for (int i = 0; i < GGML_MAX_SRC; i++) {
             if (node.node->src[i]) {
                 if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
@@ -301,10 +301,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
     std::vector<bool> used(n, false);
 
     // the memory ranges for the set of currently concurrent nodes
-    ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+    ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);
 
     // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
-    ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
+    ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);
 
     for (int i0 = 0; i0 < n; i0++) {
         if (used[i0]) {
@@ -375,7 +375,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
     return res;
 }
 
-void ggml_metal_graph_optimize(ggml_cgraph * gf) {
+void ggml_graph_optimize(ggml_cgraph * gf) {
     constexpr int MAX_FUSE = 16;
 
     const int n = gf->n_nodes;
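
To make the role of these helpers concrete, here is a simplified sketch of the greedy grouping idea, written against only the ggml-metal-common.h API shown in this commit; it ignores fusion and the look-ahead reordering that the real ggml_metal_graph_optimize_reorder performs, so it is an illustration rather than the actual algorithm:

// simplified sketch: partition graph nodes into batches whose members have
// no overlapping memory ranges, so each batch could be encoded concurrently
#include <vector>

#include "ggml.h"
#include "ggml-metal-common.h"

static std::vector<std::vector<ggml_tensor *>> group_concurrent(const std::vector<ggml_tensor *> & nodes) {
    std::vector<std::vector<ggml_tensor *>> batches;

    ggml_mem_ranges_t mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    for (ggml_tensor * node : nodes) {
        // a conflict with the current batch forces a new batch
        if (batches.empty() || !ggml_mem_ranges_check(mrs, node)) {
            batches.emplace_back();
            ggml_mem_ranges_reset(mrs);
        }

        // track this node's src/dst ranges and keep it in the current batch
        ggml_mem_ranges_add(mrs, node);
        batches.back().push_back(node);
    }

    ggml_mem_ranges_free(mrs);

    return batches;
}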

ggml/src/ggml-metal/ggml-metal-common.h

Lines changed: 7 additions & 7 deletions

@@ -25,27 +25,27 @@ enum ggml_mem_range_type {
 // can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
 // tasks already in the set)
 //
-struct ggml_mem_ranges;
+typedef struct ggml_mem_ranges * ggml_mem_ranges_t;
 
-struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
-void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
+ggml_mem_ranges_t ggml_mem_ranges_init(int debug);
+void ggml_mem_ranges_free(ggml_mem_ranges_t mrs);
 
 // remove all ranges from the set
-void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
+void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);
 
 // add src or dst ranges to track
-bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
 
 // return false if:
 // - new src range overlaps with any existing dst range
 // - new dst range overlaps with any existing range (src or dst)
-bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
 
 // reorder the nodes in the graph to improve concurrency, while respecting fusion
 //
 // note: this implementation is generic and not specific to metal
 // if it proves to work well, we can start using it for other backends in the future
-void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
+void ggml_graph_optimize(struct ggml_cgraph * gf);
 
 #ifdef __cplusplus
 }
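
As a caller-side illustration (not code from the commit), a minimal sketch of the new handle-based API, assuming two graph nodes a and b already exist:

// hedged sketch: can node b be encoded concurrently after node a?
#include "ggml.h"
#include "ggml-metal-common.h"

static bool can_run_concurrently(const struct ggml_tensor * a, const struct ggml_tensor * b) {
    ggml_mem_ranges_t mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    // track the src/dst ranges of the first node
    ggml_mem_ranges_add(mrs, a);

    // false if b reads what a writes, or b writes a range already in use
    const bool res = ggml_mem_ranges_check(mrs, b);

    ggml_mem_ranges_free(mrs);

    return res;
}

The typedef'd opaque handle keeps the struct definition private to the .cpp file while letting both C and C++ callers pass it around without the struct keyword.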
ggml/src/ggml-metal/ggml-metal-context.h

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+#pragma once
+
+#include "ggml-metal-device.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend context
+//
+
+typedef struct ggml_metal * ggml_metal_t;
+
+ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
+void         ggml_metal_free(ggml_metal_t ctx);
+
+void ggml_metal_synchronize(ggml_metal_t ctx);
+
+void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+
+enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
+void             ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);
+
+void ggml_metal_set_n_cb            (ggml_metal_t ctx, int n_cb);
+void ggml_metal_set_abort_callback  (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);
+bool ggml_metal_supports_family     (ggml_metal_t ctx, int family);
+void ggml_metal_capture_next_compute(ggml_metal_t ctx);
+
+#ifdef __cplusplus
+}
+#endif
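
For orientation, a hedged lifecycle sketch of how a backend wrapper might drive this context API; the flow below, and the assumption that ggml_metal_init can return NULL on failure, are illustrative rather than code from the commit:

// hypothetical lifecycle sketch for the Metal backend context
#include "ggml-metal-context.h"

static enum ggml_status run_graph(ggml_metal_device_t dev, struct ggml_cgraph * gf) {
    ggml_metal_t ctx = ggml_metal_init(dev);
    if (!ctx) {
        return GGML_STATUS_FAILED;
    }

    // reorder nodes for concurrency before encoding
    ggml_metal_graph_optimize(ctx, gf);

    const enum ggml_status status = ggml_metal_graph_compute(ctx, gf);

    // make sure all queued work is done before tearing down,
    // matching the "sync before destroying the backend" commit above
    ggml_metal_synchronize(ctx);
    ggml_metal_free(ctx);

    return status;
}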
