Skip to content

Commit c82b67b

Browse files
committed
add support for dist sampling on GPU
This commit adds support for performing distribution sampling on the GPU. It adds a function to the sampler interface for setting input tensors, which will be called after the computation graph has been built and scheduled. For the dist sampler, this allows it to set a random uniform value that is used to sample from the cumulative distribution.
1 parent 8e438ed commit c82b67b

File tree

9 files changed

+398
-193
lines changed

9 files changed

+398
-193
lines changed

common/llguidance.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
106106
}
107107

108108
static llama_sampler_i llama_sampler_llg_i = {
109-
/* .name = */ llama_sampler_llg_name,
110-
/* .accept = */ llama_sampler_llg_accept_impl,
111-
/* .apply = */ llama_sampler_llg_apply,
112-
/* .reset = */ llama_sampler_llg_reset,
113-
/* .clone = */ llama_sampler_llg_clone,
114-
/* .free = */ llama_sampler_llg_free,
115-
/* .apply_ggml = */ NULL,
116-
/* .accept_ggml = */ NULL,
109+
/* .name = */ llama_sampler_llg_name,
110+
/* .accept = */ llama_sampler_llg_accept_impl,
111+
/* .apply = */ llama_sampler_llg_apply,
112+
/* .reset = */ llama_sampler_llg_reset,
113+
/* .clone = */ llama_sampler_llg_clone,
114+
/* .free = */ llama_sampler_llg_free,
115+
/* .apply_ggml = */ NULL,
116+
/* .accept_ggml = */ NULL,
117+
/* .set_input_ggml = */ NULL,
117118
};
118119

119120
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,7 @@ extern "C" {
11821182
struct ggml_cgraph * gf,
11831183
struct ggml_tensor * selected_token);
11841184

1185+
void (*set_input_ggml)(struct llama_sampler * smpl);
11851186

11861187
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
11871188
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
@@ -1210,6 +1211,8 @@ extern "C" {
12101211
struct ggml_context * ctx,
12111212
struct ggml_cgraph * gf,
12121213
struct ggml_tensor * selected_token);
1214+
LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl,
1215+
struct ggml_context * ctx);
12131216

12141217
// llama_sampler_chain
12151218
// a type of llama_sampler that can chain multiple samplers one after another

src/llama-gpu-sampling.cpp

Lines changed: 174 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#include "llama-gpu-sampling.h"
22
#include "ggml.h"
33
#include <cstdio>
4+
#include <chrono>
5+
#include <random>
46

57
static void llama_sampler_gpu_greedy_apply_ggml(
68
struct llama_sampler * smpl,
@@ -27,14 +29,15 @@ static struct llama_sampler * llama_sampler_gpu_greedy_clone(const struct llama_
2729

2830
struct llama_sampler * llama_sampler_gpu_init_greedy() {
2931
static const llama_sampler_i iface = {
30-
/*.name =*/ llama_sampler_gpu_greedy_sampler_name,
31-
/*.accept =*/ nullptr,
32-
/*.apply =*/ nullptr,
33-
/*.reset =*/ nullptr,
34-
/*.clone =*/ llama_sampler_gpu_greedy_clone,
35-
/*.free =*/ nullptr,
36-
/*.apply_ggml =*/ llama_sampler_gpu_greedy_apply_ggml,
37-
/*.accept_ggml =*/ nullptr,
32+
/*.name =*/ llama_sampler_gpu_greedy_sampler_name,
33+
/*.accept =*/ nullptr,
34+
/*.apply =*/ nullptr,
35+
/*.reset =*/ nullptr,
36+
/*.clone =*/ llama_sampler_gpu_greedy_clone,
37+
/*.free =*/ nullptr,
38+
/*.apply_ggml =*/ llama_sampler_gpu_greedy_apply_ggml,
39+
/*.accept_ggml =*/ nullptr,
40+
/*.set_input_ggml =*/ nullptr,
3841
};
3942

4043
auto * sampler = new llama_sampler {
@@ -85,14 +88,15 @@ static struct llama_sampler * llama_sampler_gpu_temp_clone(const struct llama_sa
8588

8689
struct llama_sampler * llama_sampler_gpu_init_temp(float temp) {
8790
static const llama_sampler_i iface = {
88-
/*.name =*/ llama_sampler_gpu_temp_name,
89-
/*.accept =*/ nullptr,
90-
/*.apply =*/ nullptr,
91-
/*.reset =*/ nullptr,
92-
/*.clone =*/ llama_sampler_gpu_temp_clone,
93-
/*.free =*/ llama_sampler_gpu_temp_free,
94-
/*.apply_ggml =*/ llama_sampler_gpu_temp_apply_ggml,
95-
/*.accept_ggml =*/ nullptr,
91+
/*.name =*/ llama_sampler_gpu_temp_name,
92+
/*.accept =*/ nullptr,
93+
/*.apply =*/ nullptr,
94+
/*.reset =*/ nullptr,
95+
/*.clone =*/ llama_sampler_gpu_temp_clone,
96+
/*.free =*/ llama_sampler_gpu_temp_free,
97+
/*.apply_ggml =*/ llama_sampler_gpu_temp_apply_ggml,
98+
/*.accept_ggml =*/ nullptr,
99+
/*.set_input_ggml =*/ nullptr,
96100
};
97101

98102
auto * ctx_data = new llama_sampler_gpu_temp_ctx {
@@ -141,14 +145,15 @@ static struct llama_sampler * llama_sampler_gpu_softmax_clone(const struct llama
141145

142146
struct llama_sampler * llama_sampler_gpu_init_softmax() {
143147
static const llama_sampler_i iface = {
144-
/*.name =*/ llama_sampler_gpu_softmax_name,
145-
/*.accept =*/ nullptr,
146-
/*.apply =*/ nullptr,
147-
/*.reset =*/ nullptr,
148-
/*.clone =*/ llama_sampler_gpu_softmax_clone,
149-
/*.free =*/ llama_sampler_gpu_softmax_free,
150-
/*.apply_ggml =*/ llama_sampler_gpu_softmax_apply_ggml,
151-
/*.accept_ggml =*/ nullptr,
148+
/*.name =*/ llama_sampler_gpu_softmax_name,
149+
/*.accept =*/ nullptr,
150+
/*.apply =*/ nullptr,
151+
/*.reset =*/ nullptr,
152+
/*.clone =*/ llama_sampler_gpu_softmax_clone,
153+
/*.free =*/ llama_sampler_gpu_softmax_free,
154+
/*.apply_ggml =*/ llama_sampler_gpu_softmax_apply_ggml,
155+
/*.accept_ggml =*/ nullptr,
156+
/*.set_input_ggml =*/ nullptr,
152157
};
153158

154159
auto * ctx_data = new llama_sampler_gpu_softmax_ctx {
@@ -204,14 +209,15 @@ static struct llama_sampler * llama_sampler_gpu_top_k_clone(const struct llama_s
204209

205210
struct llama_sampler * llama_sampler_gpu_init_top_k(int32_t k) {
206211
static const llama_sampler_i iface = {
207-
/*.name =*/ llama_sampler_gpu_top_k_name,
208-
/*.accept =*/ nullptr,
209-
/*.apply =*/ nullptr,
210-
/*.reset =*/ nullptr,
211-
/*.clone =*/ llama_sampler_gpu_top_k_clone,
212-
/*.free =*/ llama_sampler_gpu_top_k_free,
213-
/*.apply_ggml =*/ llama_sampler_gpu_top_k_apply_ggml,
214-
/*.accept_ggml =*/ nullptr,
212+
/*.name =*/ llama_sampler_gpu_top_k_name,
213+
/*.accept =*/ nullptr,
214+
/*.apply =*/ nullptr,
215+
/*.reset =*/ nullptr,
216+
/*.clone =*/ llama_sampler_gpu_top_k_clone,
217+
/*.free =*/ llama_sampler_gpu_top_k_free,
218+
/*.apply_ggml =*/ llama_sampler_gpu_top_k_apply_ggml,
219+
/*.accept_ggml =*/ nullptr,
220+
/*.set_input_ggml =*/ nullptr,
215221
};
216222

217223
auto * ctx_data = new llama_sampler_gpu_top_k_ctx {
@@ -274,14 +280,15 @@ static struct llama_sampler * llama_sampler_gpu_top_p_clone(const struct llama_s
274280

275281
struct llama_sampler * llama_sampler_gpu_init_top_p(int32_t k) {
276282
static const llama_sampler_i iface = {
277-
/*.name =*/ llama_sampler_gpu_top_p_name,
278-
/*.accept =*/ nullptr,
279-
/*.apply =*/ nullptr,
280-
/*.reset =*/ nullptr,
281-
/*.clone =*/ llama_sampler_gpu_top_p_clone,
282-
/*.free =*/ llama_sampler_gpu_top_p_free,
283-
/*.apply_ggml =*/ llama_sampler_gpu_top_p_apply_ggml,
284-
/*.accept_ggml =*/ nullptr,
283+
/*.name =*/ llama_sampler_gpu_top_p_name,
284+
/*.accept =*/ nullptr,
285+
/*.apply =*/ nullptr,
286+
/*.reset =*/ nullptr,
287+
/*.clone =*/ llama_sampler_gpu_top_p_clone,
288+
/*.free =*/ llama_sampler_gpu_top_p_free,
289+
/*.apply_ggml =*/ llama_sampler_gpu_top_p_apply_ggml,
290+
/*.accept_ggml =*/ nullptr,
291+
/*.set_input_ggml =*/ nullptr,
285292
};
286293

287294
auto * ctx_data = new llama_sampler_gpu_top_p_ctx {
@@ -295,3 +302,130 @@ struct llama_sampler * llama_sampler_gpu_init_top_p(int32_t k) {
295302

296303
return sampler;
297304
}
305+
306+
static uint32_t get_rng_seed(uint32_t seed) {
307+
if (seed == LLAMA_DEFAULT_SEED) {
308+
// use system clock if std::random_device is not a true RNG
309+
static bool is_rd_prng = std::random_device().entropy() == 0;
310+
if (is_rd_prng) {
311+
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
312+
}
313+
std::random_device rd;
314+
return rd();
315+
}
316+
return seed;
317+
}
318+
319+
struct llama_sampler_gpu_dist_ctx {
320+
const uint32_t seed;
321+
uint32_t seed_cur;
322+
std::mt19937 rng;
323+
324+
struct ggml_tensor * uniform;
325+
};
326+
327+
static void llama_sampler_gpu_dist_apply_ggml(
328+
struct llama_sampler * smpl,
329+
struct ggml_context * ctx,
330+
struct ggml_cgraph * gf,
331+
struct llama_sampler_ggml_data * ggml_data) {
332+
auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
333+
printf("gpu dist: Building sampler with seed=%d\n", sctx->seed);
334+
335+
// Create the uniform random scalar input tensor. This will be set by
336+
// llama_sampler_gpu_dist_set_input_ggml after this graph is built, but
337+
// before it is executed.
338+
struct ggml_tensor * uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
339+
sctx->uniform = uniform;
340+
ggml_set_name(uniform, "uniform");
341+
ggml_set_input(uniform);
342+
ggml_set_output(uniform);
343+
344+
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
345+
ggml_set_name(softmax, "softmax");
346+
347+
struct ggml_tensor * cumsum = ggml_cumsum(ctx, softmax);
348+
ggml_set_name(cumsum, "cumsum");
349+
350+
// Broadcast the random uniform value to match cumsum's shape
351+
struct ggml_tensor * rnd_rep = ggml_repeat(ctx, sctx->uniform, cumsum);
352+
ggml_set_name(rnd_rep, "dist_rand_rep");
353+
354+
// Each entry in rnd_rep has the random value in it so we subtract this
355+
// tensor with the cumsum tensor. Recall that each entry in cumsum is the
356+
// cumulative probability up to that index. While the entry is smaller than
357+
// the random value the difference is positive, but once we exceed the
358+
// random value the difference becomes zero or negative.
359+
struct ggml_tensor * diff = ggml_sub(ctx, rnd_rep, cumsum);
360+
ggml_set_name(diff, "dist_rnd_minus_cumsum");
361+
362+
// The ggml_step function produces a tensor where entries are 1 if the
363+
// corresponding entry in diff is > 0, and 0 otherwise. So all values up to
364+
// the index where the cumulative probability exceeds the random value are 1,
365+
// and all entries after that are 0.
366+
struct ggml_tensor * mask = ggml_step(ctx, diff);
367+
ggml_set_name(mask, "dist_mask");
368+
369+
// Taking the sum of the mask gives us the index entry where the cumulative
370+
// threshold is first exceeded and this is our sampled token index as a float.
371+
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
372+
ggml_set_name(idxf, "dist_index_f32");
373+
374+
// Cast the float index to integer.
375+
struct ggml_tensor * idx = ggml_cast(ctx, idxf, GGML_TYPE_I32);
376+
ggml_set_name(idx, "dist_index_i32");
377+
ggml_set_output(idx);
378+
ggml_data->sampled_token = idx;
379+
}
380+
381+
static const char * llama_sampler_gpu_dist_name(const struct llama_sampler *) {
382+
return "gpu-dist";
383+
}
384+
385+
static void llama_sampler_gpu_dist_free(struct llama_sampler * smpl) {
386+
auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
387+
delete sctx;
388+
}
389+
390+
static struct llama_sampler * llama_sampler_gpu_dist_clone(const struct llama_sampler * smpl) {
391+
auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
392+
return llama_sampler_gpu_init_dist(sctx->seed);
393+
}
394+
395+
static void llama_sampler_gpu_dist_set_input_ggml(struct llama_sampler * smpl) {
396+
auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
397+
GGML_ASSERT(sctx->uniform != nullptr);
398+
399+
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
400+
const float rnd = dist(sctx->rng);
401+
ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(rnd));
402+
}
403+
404+
struct llama_sampler * llama_sampler_gpu_init_dist(uint32_t seed) {
405+
static const llama_sampler_i iface = {
406+
/*.name =*/ llama_sampler_gpu_dist_name,
407+
/*.accept =*/ nullptr,
408+
/*.apply =*/ nullptr,
409+
/*.reset =*/ nullptr,
410+
/*.clone =*/ llama_sampler_gpu_dist_clone,
411+
/*.free =*/ llama_sampler_gpu_dist_free,
412+
/*.apply_ggml =*/ llama_sampler_gpu_dist_apply_ggml,
413+
/*.accept_ggml =*/ nullptr,
414+
/*.set_input_ggml =*/ llama_sampler_gpu_dist_set_input_ggml,
415+
};
416+
417+
auto seed_cur = get_rng_seed(seed);
418+
auto * ctx_data = new llama_sampler_gpu_dist_ctx {
419+
/*.seed =*/ seed,
420+
/*.seed_cur =*/ seed_cur,
421+
/*.rng =*/ std::mt19937(seed_cur),
422+
/*.uniform =*/ nullptr,
423+
};
424+
425+
auto * sampler = new llama_sampler {
426+
/*.iface =*/ &iface,
427+
/*.ctx =*/ ctx_data,
428+
};
429+
430+
return sampler;
431+
}

src/llama-gpu-sampling.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ LLAMA_API struct llama_sampler * llama_sampler_gpu_init_top_k(int32_t k);
2020
// TODO: implement real top-p sampling on GPU.
2121
LLAMA_API struct llama_sampler * llama_sampler_gpu_init_top_p(int32_t k);
2222

23+
LLAMA_API struct llama_sampler * llama_sampler_gpu_init_dist(uint32_t seed);
24+
2325
#ifdef __cplusplus
2426
}
2527
#endif

src/llama-graph.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
464464

465465
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
466466
GGML_UNUSED(ubatch);
467+
for (const auto & [seq_id, sampler] : samplers) {
468+
if (sampler->iface->set_input_ggml) {
469+
printf("llm_graph_input_sampling::set_input: setting sampler input for seq_id %d\n", seq_id);
470+
sampler->iface->set_input_ggml(sampler);
471+
}
472+
}
467473
}
468474

469475
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {

src/llama-graph.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
387387
class llm_graph_input_sampling : public llm_graph_input_i {
388388
public:
389389
llm_graph_input_sampling(int32_t n_vocab, bool sorted, std::unordered_map<llama_seq_id, llama_sampler*> samplers) :
390-
n_vocab(n_vocab), sorted_value(sorted) {
390+
n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) {
391391

392392
sampler_versions.reserve(samplers.size());
393393
for (const auto & [seq_id, sampler] : samplers) {
@@ -406,6 +406,7 @@ class llm_graph_input_sampling : public llm_graph_input_i {
406406

407407
// Track sampler chain version for reuse
408408
std::unordered_map<llama_seq_id, uint64_t> sampler_versions;
409+
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
409410
};
410411

411412
//

0 commit comments

Comments
 (0)