Skip to content

Commit d04847d

Browse files
committed
use ggml_div_inplace for normalization
This commit updates the llama_sampler_gpu_top_p_apply_ggml function to use ggml_div_inplace instead of ggml_div as this generated an error on webgpu backends: ```console /home/danbev/work/ai/llama.cpp-debug/ggml/src/ggml-webgpu/ggml-webgpu.cpp:2146: ggml_webgpu: Device error! Reason: 2, Message: Writable storage buffer binding aliasing found between [BindGroup "div_f32"] set at bind group index 0, binding index 1, and [BindGroup "div_f32"] set at bind group index 0, binding index 2, with overlapping ranges (offset: 0, size: 32) and (offset: 0, size: 32) in [Buffer "allocated_buffer"]. - While encoding [ComputePassEncoder (unlabeled)].DispatchWorkgroups(1, 1, 1). - While finishing [CommandEncoder (unlabeled)]. ``` It also sets ggml_data-filtered_ids as an output tensor as it might otherwise be reused before being read.
1 parent d53fa8d commit d04847d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/llama-gpu-sampling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,15 @@ static void llama_sampler_gpu_top_p_apply_ggml(
245245
struct ggml_tensor * top_k_ids = ggml_cont(ctx, ggml_top_k(ctx, softmax, ctx_data->k));
246246
ggml_set_name(top_k_ids, "top_k_ids");
247247
ggml_data->filtered_ids = top_k_ids;
248+
ggml_set_output(ggml_data->filtered_ids);
248249

249250
struct ggml_tensor * prob_rows = ggml_reshape_2d(ctx, softmax, 1, ggml_data->logits->ne[0]);
250251
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, prob_rows, top_k_ids);
251252
ggml_set_name(top_k_rows, "top_k_rows");
252253

253254
struct ggml_tensor * top_k = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
254255
struct ggml_tensor * total = ggml_sum(ctx, top_k);
255-
struct ggml_tensor * norm = ggml_div(ctx, top_k, ggml_repeat(ctx, total, top_k));
256+
struct ggml_tensor * norm = ggml_div_inplace(ctx, top_k, ggml_repeat(ctx, total, top_k));
256257
ggml_data->probs = norm;
257258
ggml_build_forward_expand(gf, ggml_data->probs);
258259
}

0 commit comments

Comments
 (0)