Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"

if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
fi

if [ ! -z ${GG_BUILD_CUDA} ]; then
Expand Down
1 change: 0 additions & 1 deletion ggml/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ option(GGML_WEBGPU "ggml: use WebGPU"
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
option(GGML_ZDNN "ggml: use zDNN" OFF)
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
Expand Down
1 change: 1 addition & 0 deletions ggml/include/ggml-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ extern "C" {
// user-code should use only these functions
//

// TODO: remove in the future
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);

GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
Expand Down
8 changes: 4 additions & 4 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,19 +284,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexc
// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
//
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
const type prefix##0 = (pointer)->array[0]; \
const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
GGML_UNUSED(prefix##0);
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
const type prefix##1 = (pointer)->array[1]; \
const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
GGML_UNUSED(prefix##1);
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
const type prefix##2 = (pointer)->array[2]; \
const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
GGML_UNUSED(prefix##2);
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
const type prefix##3 = (pointer)->array[3]; \
const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
GGML_UNUSED(prefix##3);

#define GGML_TENSOR_UNARY_OP_LOCALS \
Expand Down
10 changes: 5 additions & 5 deletions ggml/src/ggml-metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
message(STATUS "Metal framework found")

ggml_add_backend_library(ggml-metal
ggml-metal.m
ggml-metal.cpp
ggml-metal-device.m
ggml-metal-device.cpp
ggml-metal-common.cpp
ggml-metal-context.m
ggml-metal-ops.cpp
)

target_link_libraries(ggml-metal PRIVATE
Expand All @@ -19,10 +23,6 @@ if (GGML_METAL_NDEBUG)
add_compile_definitions(GGML_METAL_NDEBUG)
endif()

if (GGML_METAL_USE_BF16)
add_compile_definitions(GGML_METAL_USE_BF16)
endif()

# copy metal files to bin directory
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
Expand Down
32 changes: 16 additions & 16 deletions ggml/src/ggml-metal/ggml-metal-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct ggml_mem_ranges {
int debug = 0;
};

struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
auto * res = new ggml_mem_ranges;

res->ranges.reserve(256);
Expand All @@ -31,15 +31,15 @@ struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
return res;
}

void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
delete mrs;
}

void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
mrs->ranges.clear();
}

static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
mrs->ranges.push_back(mr);

return true;
Expand Down Expand Up @@ -87,7 +87,7 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}

static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);

ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
Expand All @@ -99,7 +99,7 @@ static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * t
return ggml_mem_ranges_add(mrs, mr);
}

static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);

ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
Expand All @@ -111,7 +111,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * t
return ggml_mem_ranges_add(mrs, mr);
}

bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (tensor->src[i]) {
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
Expand All @@ -121,7 +121,7 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
return ggml_mem_ranges_add_dst(mrs, tensor);
}

static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
for (size_t i = 0; i < mrs->ranges.size(); i++) {
const auto & cmp = mrs->ranges[i];

Expand Down Expand Up @@ -152,7 +152,7 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
return true;
}

static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);

ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
Expand All @@ -162,7 +162,7 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te
return res;
}

static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);

ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
Expand All @@ -172,7 +172,7 @@ static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_te
return res;
}

bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (tensor->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
Expand Down Expand Up @@ -222,7 +222,7 @@ struct node_info {

static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
// helper to add node src and dst ranges
const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node.node->src[i]) {
if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
Expand All @@ -246,7 +246,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
};

// helper to check if a node can run concurrently with the existing set of nodes
const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node.node->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
Expand Down Expand Up @@ -301,10 +301,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
std::vector<bool> used(n, false);

// the memory ranges for the set of currently concurrent nodes
ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);

// the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);

for (int i0 = 0; i0 < n; i0++) {
if (used[i0]) {
Expand Down Expand Up @@ -375,7 +375,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
return res;
}

void ggml_metal_graph_optimize(ggml_cgraph * gf) {
void ggml_graph_optimize(ggml_cgraph * gf) {
constexpr int MAX_FUSE = 16;

const int n = gf->n_nodes;
Expand Down
14 changes: 7 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,27 @@ enum ggml_mem_range_type {
// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
// tasks already in the set)
//
struct ggml_mem_ranges;
typedef struct ggml_mem_ranges * ggml_mem_ranges_t;

struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
ggml_mem_ranges_t ggml_mem_ranges_init(int debug);
void ggml_mem_ranges_free(ggml_mem_ranges_t mrs);

// remove all ranges from the set
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);

// add src or dst ranges to track
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);

// return false if:
// - new src range overlaps with any existing dst range
// - new dst range overlaps with any existing range (src or dst)
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);

// reorder the nodes in the graph to improve concurrency, while respecting fusion
//
// note: this implementation is generic and not specific to metal
// if it proves to work well, we can start using it for other backends in the future
void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
void ggml_graph_optimize(struct ggml_cgraph * gf);

#ifdef __cplusplus
}
Expand Down
33 changes: 33 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once

#include "ggml-metal-device.h"

#ifdef __cplusplus
extern "C" {
#endif

//
// backend context
//

typedef struct ggml_metal * ggml_metal_t;

ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
void ggml_metal_free(ggml_metal_t ctx);

void ggml_metal_synchronize(ggml_metal_t ctx);

void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);

void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb);
void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);
bool ggml_metal_supports_family (ggml_metal_t ctx, int family);
void ggml_metal_capture_next_compute(ggml_metal_t ctx);

#ifdef __cplusplus
}
#endif
Loading
Loading