 #include "llama-gpu-sampling.h"
 #include "ggml.h"
 #include <cstdio>
+#include <chrono>
+#include <random>

 static void llama_sampler_gpu_greedy_apply_ggml(
         struct llama_sampler * smpl,
@@ -27,14 +29,15 @@ static struct llama_sampler * llama_sampler_gpu_greedy_clone(const struct llama_

 struct llama_sampler * llama_sampler_gpu_init_greedy() {
     static const llama_sampler_i iface = {
-        /* .name        =*/ llama_sampler_gpu_greedy_sampler_name,
-        /* .accept      =*/ nullptr,
-        /* .apply       =*/ nullptr,
-        /* .reset       =*/ nullptr,
-        /* .clone       =*/ llama_sampler_gpu_greedy_clone,
-        /* .free        =*/ nullptr,
-        /* .apply_ggml  =*/ llama_sampler_gpu_greedy_apply_ggml,
-        /* .accept_ggml =*/ nullptr,
+        /* .name           =*/ llama_sampler_gpu_greedy_sampler_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_greedy_clone,
+        /* .free           =*/ nullptr,
+        /* .apply_ggml     =*/ llama_sampler_gpu_greedy_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ nullptr,
     };

     auto * sampler = new llama_sampler {
@@ -85,14 +88,15 @@ static struct llama_sampler * llama_sampler_gpu_temp_clone(const struct llama_sa

 struct llama_sampler * llama_sampler_gpu_init_temp(float temp) {
     static const llama_sampler_i iface = {
-        /* .name        =*/ llama_sampler_gpu_temp_name,
-        /* .accept      =*/ nullptr,
-        /* .apply       =*/ nullptr,
-        /* .reset       =*/ nullptr,
-        /* .clone       =*/ llama_sampler_gpu_temp_clone,
-        /* .free        =*/ llama_sampler_gpu_temp_free,
-        /* .apply_ggml  =*/ llama_sampler_gpu_temp_apply_ggml,
-        /* .accept_ggml =*/ nullptr,
+        /* .name           =*/ llama_sampler_gpu_temp_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_temp_clone,
+        /* .free           =*/ llama_sampler_gpu_temp_free,
+        /* .apply_ggml     =*/ llama_sampler_gpu_temp_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ nullptr,
     };

     auto * ctx_data = new llama_sampler_gpu_temp_ctx {
@@ -141,14 +145,15 @@ static struct llama_sampler * llama_sampler_gpu_softmax_clone(const struct llama

 struct llama_sampler * llama_sampler_gpu_init_softmax() {
     static const llama_sampler_i iface = {
-        /* .name        =*/ llama_sampler_gpu_softmax_name,
-        /* .accept      =*/ nullptr,
-        /* .apply       =*/ nullptr,
-        /* .reset       =*/ nullptr,
-        /* .clone       =*/ llama_sampler_gpu_softmax_clone,
-        /* .free        =*/ llama_sampler_gpu_softmax_free,
-        /* .apply_ggml  =*/ llama_sampler_gpu_softmax_apply_ggml,
-        /* .accept_ggml =*/ nullptr,
+        /* .name           =*/ llama_sampler_gpu_softmax_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_softmax_clone,
+        /* .free           =*/ llama_sampler_gpu_softmax_free,
+        /* .apply_ggml     =*/ llama_sampler_gpu_softmax_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ nullptr,
     };

     auto * ctx_data = new llama_sampler_gpu_softmax_ctx {
@@ -204,14 +209,15 @@ static struct llama_sampler * llama_sampler_gpu_top_k_clone(const struct llama_s

 struct llama_sampler * llama_sampler_gpu_init_top_k(int32_t k) {
     static const llama_sampler_i iface = {
-        /* .name        =*/ llama_sampler_gpu_top_k_name,
-        /* .accept      =*/ nullptr,
-        /* .apply       =*/ nullptr,
-        /* .reset       =*/ nullptr,
-        /* .clone       =*/ llama_sampler_gpu_top_k_clone,
-        /* .free        =*/ llama_sampler_gpu_top_k_free,
-        /* .apply_ggml  =*/ llama_sampler_gpu_top_k_apply_ggml,
-        /* .accept_ggml =*/ nullptr,
+        /* .name           =*/ llama_sampler_gpu_top_k_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_top_k_clone,
+        /* .free           =*/ llama_sampler_gpu_top_k_free,
+        /* .apply_ggml     =*/ llama_sampler_gpu_top_k_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ nullptr,
     };

     auto * ctx_data = new llama_sampler_gpu_top_k_ctx {
@@ -274,14 +280,15 @@ static struct llama_sampler * llama_sampler_gpu_top_p_clone(const struct llama_s

 struct llama_sampler * llama_sampler_gpu_init_top_p(int32_t k) {
     static const llama_sampler_i iface = {
-        /* .name        =*/ llama_sampler_gpu_top_p_name,
-        /* .accept      =*/ nullptr,
-        /* .apply       =*/ nullptr,
-        /* .reset       =*/ nullptr,
-        /* .clone       =*/ llama_sampler_gpu_top_p_clone,
-        /* .free        =*/ llama_sampler_gpu_top_p_free,
-        /* .apply_ggml  =*/ llama_sampler_gpu_top_p_apply_ggml,
-        /* .accept_ggml =*/ nullptr,
+        /* .name           =*/ llama_sampler_gpu_top_p_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_top_p_clone,
+        /* .free           =*/ llama_sampler_gpu_top_p_free,
+        /* .apply_ggml     =*/ llama_sampler_gpu_top_p_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ nullptr,
     };

     auto * ctx_data = new llama_sampler_gpu_top_p_ctx {
@@ -295,3 +302,131 @@ struct llama_sampler * llama_sampler_gpu_init_top_p(int32_t k) {

     return sampler;
 }
+
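+// Resolve LLAMA_DEFAULT_SEED to a fresh seed (from std::random_device, or the
+// system clock when random_device is only a PRNG); any explicit seed is
+// returned unchanged so runs stay reproducible.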
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
+    }
+    return seed;
+}
+
+struct llama_sampler_gpu_dist_ctx {
+    const uint32_t seed;
+    uint32_t       seed_cur;
+    std::mt19937   rng;
+
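+    // Non-owning pointer to the scalar "uniform" input tensor; it is
+    // (re)created in the graph context on every apply_ggml call and filled
+    // in by set_input_ggml before the graph runs.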
+    struct ggml_tensor * uniform;
+};
+
+static void llama_sampler_gpu_dist_apply_ggml(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_ggml_data * ggml_data) {
+    GGML_UNUSED(gf);
+    auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
+    printf("gpu dist: Building sampler with seed=%u\n", sctx->seed);
+
+    // Create the uniform random scalar input tensor. This will be set by
+    // llama_sampler_gpu_dist_set_input_ggml after this graph is built, but
+    // before it is executed.
+    struct ggml_tensor * uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+    sctx->uniform = uniform;
+    ggml_set_name(uniform, "uniform");
+    ggml_set_input(uniform);
+    ggml_set_output(uniform);
+
+    struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
+    ggml_set_name(softmax, "softmax");
+
+    struct ggml_tensor * cumsum = ggml_cumsum(ctx, softmax);
+    ggml_set_name(cumsum, "cumsum");
+
+    // Broadcast the random uniform value to match cumsum's shape.
+    struct ggml_tensor * rnd_rep = ggml_repeat(ctx, sctx->uniform, cumsum);
+    ggml_set_name(rnd_rep, "dist_rand_rep");
+
+    // Each entry in rnd_rep holds the random value, so we subtract the cumsum
+    // tensor from it. Recall that each entry in cumsum is the cumulative
+    // probability up to that index. While that entry is smaller than the
+    // random value the difference is positive, but once it reaches the random
+    // value the difference becomes zero or negative.
+    struct ggml_tensor * diff = ggml_sub(ctx, rnd_rep, cumsum);
+    ggml_set_name(diff, "dist_rnd_minus_cumsum");
+
+    // ggml_step produces a tensor whose entries are 1 where the corresponding
+    // entry in diff is > 0, and 0 otherwise. So all values up to the index
+    // where the cumulative probability exceeds the random value are 1, and
+    // all entries after that are 0.
+    struct ggml_tensor * mask = ggml_step(ctx, diff);
+    ggml_set_name(mask, "dist_mask");
+
+    // Taking the sum of the mask gives us the index at which the cumulative
+    // probability first reaches the random value, which is our sampled token
+    // index as a float.
+    struct ggml_tensor * idxf = ggml_sum(ctx, mask);
+    ggml_set_name(idxf, "dist_index_f32");
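+    // Worked example (illustrative values): softmax = [0.1, 0.5, 0.3, 0.1]
+    // gives cumsum = [0.1, 0.6, 0.9, 1.0]. For a uniform draw of 0.7 the
+    // diff is [0.6, 0.1, -0.2, -0.3], the mask is [1, 1, 0, 0], and the sum
+    // is 2, i.e. the first index whose cumulative probability reaches 0.7.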
+
+    // Cast the float index to an integer.
+    struct ggml_tensor * idx = ggml_cast(ctx, idxf, GGML_TYPE_I32);
+    ggml_set_name(idx, "dist_index_i32");
+    ggml_set_output(idx);
+    ggml_data->sampled_token = idx;
+}
+
+static const char * llama_sampler_gpu_dist_name(const struct llama_sampler *) {
+    return "gpu-dist";
+}
+
+static void llama_sampler_gpu_dist_free(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
+    delete sctx;
+}
+
+static struct llama_sampler * llama_sampler_gpu_dist_clone(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
+    return llama_sampler_gpu_init_dist(sctx->seed);
+}
+
+static void llama_sampler_gpu_dist_set_input_ggml(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_gpu_dist_ctx *) smpl->ctx;
+    GGML_ASSERT(sctx->uniform != nullptr);
+
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    const float rnd = dist(sctx->rng);
+    ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(rnd));
+}
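+// Expected call order for the dist sampler (a sketch inferred from the
+// comments above; the code that drives the graph lives outside this file):
+//   1. apply_ggml      - adds the sampling ops and the "uniform" input tensor
+//                        while the batch graph is being built
+//   2. graph allocation by the backend
+//   3. set_input_ggml  - draws a value in [0, 1) from the RNG and uploads it
+//                        into the "uniform" tensor
+//   4. graph compute   - leaves the sampled token id in ggml_data->sampled_token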
+
+struct llama_sampler * llama_sampler_gpu_init_dist(uint32_t seed) {
+    static const llama_sampler_i iface = {
+        /* .name           =*/ llama_sampler_gpu_dist_name,
+        /* .accept         =*/ nullptr,
+        /* .apply          =*/ nullptr,
+        /* .reset          =*/ nullptr,
+        /* .clone          =*/ llama_sampler_gpu_dist_clone,
+        /* .free           =*/ llama_sampler_gpu_dist_free,
+        /* .apply_ggml     =*/ llama_sampler_gpu_dist_apply_ggml,
+        /* .accept_ggml    =*/ nullptr,
+        /* .set_input_ggml =*/ llama_sampler_gpu_dist_set_input_ggml,
+    };
+
+    auto seed_cur = get_rng_seed(seed);
+    auto * ctx_data = new llama_sampler_gpu_dist_ctx {
+        /* .seed     =*/ seed,
+        /* .seed_cur =*/ seed_cur,
+        /* .rng      =*/ std::mt19937(seed_cur),
+        /* .uniform  =*/ nullptr,
+    };
+
+    auto * sampler = new llama_sampler {
+        /* .iface =*/ &iface,
+        /* .ctx   =*/ ctx_data,
+    };
+
+    return sampler;
+}