
Commit d53fa8d

sampling : add support for GPU sampling (wip)
This is a work in progress to add support for GPU sampling. The motivation for this feature is to enable sampling to be performed directly on the GPU as part of the computation graph being executed, allowing some or all of the sampling to be done on the GPU. For example, the GPU sampler chain might select/sample a token directly, in which case only the sampled token needs to be transferred from device memory to host memory. It is also possible for the GPU samplers to perform filtering of the logits, or to compute and filter the probability distribution, in which case only the filtered logits or probabilities need to be transferred back to system memory for further processing by CPU samplers.

Currently, GPU sampling works in a similar manner to pooling: it is a function called by build_graph:
```c++
// add GPU sampling layers (if any)
llm->build_sampling(*this, params);
```

GPU samplers are configured by creating sampler chains, where each sampler chain is associated with a specific sequence id:
```c++
struct llama_sampler_chain_params params = llama_sampler_chain_default_params();
struct llama_sampler * gpu_sampler_chain = llama_sampler_chain_init(params);
llama_sampler_chain_add(gpu_sampler_chain, llama_sampler_gpu_init_greedy());

std::vector<llama_sampler_seq_config> sampler_configs = {
    { 0, gpu_sampler_chain }
};
```

The struct is defined as:
```c++
struct llama_sampler_seq_config {
    llama_seq_id seq_id;
    struct llama_sampler * sampler;
};
```

These sampler configs are then passed as context params:
```c++
llama_context_params cparams = llama_context_default_params();
cparams.samplers   = sampler_configs.data();
cparams.n_samplers = sampler_configs.size();
```

When the graph is built, each configured sampler's apply_ggml function is called, which allows it to add operations/nodes to the computation graph. This enables the sampling to happen fully or partially on the GPU. A GPU sampler chain can sample a single token, in which case only that token is transferred from device memory to host memory after llama_decode has been called. The sampled token can then be retrieved using:
```c++
llama_token id = llama_get_sampled_token_ith(test_ctx.ctx, index);
```

It is also possible to run a GPU sampler that only filters the logits, so that only the filtered logits are transferred back to the host and sampling proceeds on the CPU with the normal (CPU) sampler chain. In this case the CPU samplers are configured as usual, but they now operate on already filtered logits. Similarly, a GPU sampler can compute the full probability distribution and transfer that to the host, and the CPU samplers can then operate on those probabilities.
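To make the filtered-logits path concrete, here is a minimal host-side sketch (not part of the commit) that first checks whether the GPU chain already selected a token and otherwise finishes sampling on the CPU. The getters are the ones this commit adds to include/llama.h; the helper name and the `cpu_chain` sampler chain are assumptions for illustration.

```c++
#include <vector>
#include "llama.h"

// Hypothetical helper: finish sampling on the host for output index `i`.
// `cpu_chain` is a regular CPU sampler chain configured by the caller.
static llama_token sample_output_ith(llama_context * ctx, llama_sampler * cpu_chain, int32_t i) {
    // if the GPU chain already selected a token, just use it
    const llama_token gpu_token = llama_get_sampled_token_ith(ctx, i);
    if (gpu_token != LLAMA_TOKEN_NULL) {
        return gpu_token;
    }

    // otherwise the GPU chain only filtered the logits: fetch them together
    // with the token ids they correspond to
    const uint32_t      n_filtered = llama_get_sampled_logits_count_ith(ctx, i);
    const float       * logits     = llama_get_sampled_logits_ith(ctx, i);
    const llama_token * ids        = llama_get_sampled_token_ids_ith(ctx, i);

    std::vector<llama_token_data> cur;
    cur.reserve(n_filtered);
    for (uint32_t j = 0; j < n_filtered; ++j) {
        cur.push_back({ ids[j], logits[j], 0.0f });
    }

    llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected*/ -1, /*sorted*/ false };

    // the CPU samplers now operate on the already filtered logits
    llama_sampler_apply(cpu_chain, &cur_p);

    return cur_p.data[cur_p.selected].id;
}
```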
Building and running the tests:

Download a model for testing:
```console
$ cd models && wget https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf
```

Build the test:
```console
$ cmake --build build --target test-gpu-sampling -j8
```

Run all tests:
```console
$ env LLAMACPP_TEST_MODELFILE=../models/stories15M-q4_0.gguf \
    ctest --test-dir build -R '^test-gpu-sampling$' -V
```

The following individual tests are available:
```console
$ ctest --test-dir build -N -R test-gpu-sampling-
  Test 35: test-gpu-sampling-greedy
  Test 36: test-gpu-sampling-temp
  Test 37: test-gpu-sampling-softmax
  Test 38: test-gpu-sampling-top_k
  Test 39: test-gpu-sampling-top_p
  Test 40: test-gpu-sampling-mul_seq

Total Tests: 6
```

These can be run individually, for example:
```console
$ env LLAMACPP_TEST_MODELFILE=../models/stories15M-q4_0.gguf \
    ctest --test-dir build -R 'test-gpu-sampling-temp' -V
```

TODO:
- [ ] Allow GPU samplers to pre-allocate state tensors
- [ ] Integrate GPU samplers with llama-server
- [ ] Implement true top-p sampler on GPU
- [ ] Add missing GPU samplers (e.g. typical, mirostat, etc.)
1 parent 66d8ecc commit d53fa8d

14 files changed: +1709 −191 lines

common/llguidance.cpp

Lines changed: 8 additions & 6 deletions
```diff
@@ -106,12 +106,14 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
 }
 
 static llama_sampler_i llama_sampler_llg_i = {
-    /* .name   = */ llama_sampler_llg_name,
-    /* .accept = */ llama_sampler_llg_accept_impl,
-    /* .apply  = */ llama_sampler_llg_apply,
-    /* .reset  = */ llama_sampler_llg_reset,
-    /* .clone  = */ llama_sampler_llg_clone,
-    /* .free   = */ llama_sampler_llg_free,
+    /* .name        = */ llama_sampler_llg_name,
+    /* .accept      = */ llama_sampler_llg_accept_impl,
+    /* .apply       = */ llama_sampler_llg_apply,
+    /* .reset       = */ llama_sampler_llg_reset,
+    /* .clone       = */ llama_sampler_llg_clone,
+    /* .free        = */ llama_sampler_llg_free,
+    /* .apply_ggml  = */ NULL,
+    /* .accept_ggml = */ NULL,
 };
 
 static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
```
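For contrast with the NULL entries above: a sampler that does participate in GPU sampling would fill in apply_ggml and add nodes to the graph. The sketch below is not part of this commit; it uses the llama_sampler_ggml_data and callback signatures introduced in include/llama.h further down, and all names prefixed with `my_` are made up for illustration.

```c++
#include "ggml.h"
#include "llama.h"

static const char * llama_sampler_my_greedy_name(const struct llama_sampler * /*smpl*/) {
    return "my-gpu-greedy";
}

// host-side apply is a no-op for this sketch; all the work happens in the graph
static void llama_sampler_my_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * /*cur_p*/) {
}

// apply_ggml adds nodes to the computation graph instead of touching host data:
// a single argmax over the (possibly already filtered) logits selects the token
static void llama_sampler_my_greedy_apply_ggml(
        struct llama_sampler           * /*smpl*/,
        struct ggml_context            * ctx,
        struct ggml_cgraph             * gf,
        struct llama_sampler_ggml_data * ggml_data) {
    ggml_data->sampled_token = ggml_argmax(ctx, ggml_data->logits);
    ggml_build_forward_expand(gf, ggml_data->sampled_token);
}

static llama_sampler_i llama_sampler_my_greedy_i = {
    /* .name        = */ llama_sampler_my_greedy_name,
    /* .accept      = */ NULL,
    /* .apply       = */ llama_sampler_my_greedy_apply,
    /* .reset       = */ NULL,
    /* .clone       = */ NULL,
    /* .free        = */ NULL,
    /* .apply_ggml  = */ llama_sampler_my_greedy_apply_ggml,
    /* .accept_ggml = */ NULL,
};
```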

include/llama.h

Lines changed: 61 additions & 1 deletion
```diff
@@ -210,6 +210,13 @@ extern "C" {
         bool sorted;  // note: do not assume the data is sorted - always check this flag
     } llama_token_data_array;
 
+    struct llama_sampler_ggml_data {
+        struct ggml_tensor * logits;        // [n_vocab]          - GGML_TYPE_F32
+        struct ggml_tensor * probs;         // [n_vocab, n_vocab] - GGML_TYPE_F32
+        struct ggml_tensor * sampled_token; // [1, n_vocab]       - GGML_TYPE_I32
+        struct ggml_tensor * filtered_ids;  // [k]                - GGML_TYPE_I32
+    };
+
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
     // Input data for llama_encode/llama_decode
@@ -300,6 +307,11 @@ extern "C" {
         bool no_host; // bypass host buffer allowing extra buffers to be used
     };
 
+    struct llama_sampler_seq_config {
+        llama_seq_id seq_id;
+        struct llama_sampler * sampler;
+    };
+
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     //       https://github.com/ggml-org/llama.cpp/pull/7544
     struct llama_context_params {
@@ -348,6 +360,10 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+
+        // GPU sampler chain configuration
+        struct llama_sampler_seq_config * samplers;
+        size_t n_samplers;
     };
 
     // model quantization parameters
@@ -948,6 +964,29 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Get the GPU sampled token for the ith token.
+    // Returns LLAMA_TOKEN_NULL if no token was sampled.
+    LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the GPU sampled probabilites for the ith token
+    // The index matches llama_get_sampled_token_ith().
+    // Returns NULL if no probabilites were generated.
+    LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the GPU sampled logits for the ith token
+    // Returns NULL if no logits were sampled.
+    LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the GPU sampled token ids associated with the sampled logits for the ith token
+    // Returns NULL if no logits were sampled.
+    LLAMA_API llama_token * llama_get_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the number of GPU sampled logits for the ith token.
+    LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the number of GPU sampled probabilites for the ith token.
+    LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
@@ -1133,6 +1172,17 @@ extern "C" {
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
         void                   (*free)  (      struct llama_sampler * smpl); // can be NULL if ctx is NULL
 
+        void (*apply_ggml) ( struct llama_sampler           * smpl,
+                             struct ggml_context            * ctx,
+                             struct ggml_cgraph             * gf,
+                             struct llama_sampler_ggml_data * ggml_data);
+
+        void (*accept_ggml)( struct llama_sampler * smpl,
+                             struct ggml_context  * ctx,
+                             struct ggml_cgraph   * gf,
+                             struct ggml_tensor   * selected_token);
+
         // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
         //void (*apply_ggml) (struct llama_sampler * smpl, ...);
     };
@@ -1150,7 +1200,16 @@ extern "C" {
     LLAMA_API void                   llama_sampler_reset ( struct llama_sampler * smpl);
     LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
-    LLAMA_API void                   llama_sampler_free  (       struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_free  (       struct llama_sampler * smpl);
+    LLAMA_API void llama_sampler_apply_ggml ( struct llama_sampler           * smpl,
+                                              struct ggml_context            * ctx,
+                                              struct ggml_cgraph             * gf,
+                                              struct llama_sampler_ggml_data * ggml_data);
+
+    LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl,
+                                              struct ggml_context  * ctx,
+                                              struct ggml_cgraph   * gf,
+                                              struct ggml_tensor   * selected_token);
 
     // llama_sampler_chain
     // a type of llama_sampler that can chain multiple samplers one after another
@@ -1164,6 +1223,7 @@ extern "C" {
 
     // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
     LLAMA_API struct llama_sampler * llama_sampler_chain_remove(      struct llama_sampler * chain, int32_t i);
+    LLAMA_API uint64_t               llama_sampler_chain_get_version(const struct llama_sampler * chain);
 
     // available samplers:
```
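As a usage sketch for the probabilities path exposed by the getters above (not part of the commit): when a GPU chain transfers a probability distribution to the host, a standard CPU-side draw can consume it directly. The helper name is made up, and it assumes the full distribution is returned in vocabulary order, so the drawn index is the token id.

```c++
#include <random>
#include "llama.h"

// Draw a token from the GPU-computed probabilities for output index `i`.
// Assumes the probabilities cover the whole vocabulary in vocab order.
static llama_token sample_from_gpu_probs(llama_context * ctx, int32_t i, std::mt19937 & rng) {
    const uint32_t n_probs = llama_get_sampled_probs_count_ith(ctx, i);
    const float  * probs   = llama_get_sampled_probs_ith(ctx, i);

    if (probs == NULL || n_probs == 0) {
        return LLAMA_TOKEN_NULL; // no GPU-computed distribution for this output
    }

    // proportional draw over the transferred probabilities
    std::discrete_distribution<int32_t> dist(probs, probs + n_probs);
    return (llama_token) dist(rng);
}
```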

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ add_library(llama
             llama-model.cpp
             llama-quant.cpp
             llama-sampling.cpp
+            llama-gpu-sampling.cpp
             llama-vocab.cpp
             unicode-data.cpp
             unicode.cpp
```
