
Commit 4dfbeed

[EAGLE] Initial support for EAGLE model (#1)
* [EAGLE] Initial support for EAGLE model

  Added the definition and `decode` function for the EAGLE model. Also added a sequence-based speculative decoding example using the EAGLE draft model under `examples/speculative-simple-eagle`.
1 parent 10bb545 commit 4dfbeed
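The example added under `examples/speculative-simple-eagle` is described as sequence-based greedy speculative decoding. As a rough sketch of the acceptance step such a flow typically performs (generic illustration, not the example code from this commit; `ctx_tgt` is assumed to be the target model's context, `smpl` its sampler, and the target is assumed to have already decoded the last accepted token followed by the draft in one batch with logits at every position):

```cpp
#include "common.h"
#include "sampling.h"

// Greedy acceptance rule for sequence-based speculative decoding (illustrative sketch):
// sample the target model at each position and keep drafted tokens only while they
// match the target's own choice.
static llama_tokens accept_greedy(llama_context * ctx_tgt, common_sampler * smpl,
                                  const llama_tokens & draft) {
    llama_tokens accepted;

    for (size_t i = 0; i <= draft.size(); ++i) {
        // logits index i corresponds to the last accepted token (i == 0) or draft[i - 1]
        const llama_token id = common_sampler_sample(smpl, ctx_tgt, (int) i);
        common_sampler_accept(smpl, id, true);

        accepted.push_back(id);

        // stop as soon as the target disagrees with the draft (or the draft is exhausted)
        if (i == draft.size() || id != draft[i]) {
            break;
        }
    }

    return accepted;
}
```

In the EAGLE variant, the target's hidden states at the accepted positions would presumably also be packed into the `data` buffer that is passed to the next drafting call.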

16 files changed (+1129, -0 lines)


common/speculative.cpp

Lines changed: 150 additions & 0 deletions
@@ -278,3 +278,153 @@ llama_tokens common_speculative_gen_draft(
```cpp
    return result;
}

llama_tokens common_speculative_gen_draft_eagle(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt_tgt,
        llama_token id_last,
        std::vector<uint8_t> & data) {
    auto & batch  = spec->batch;
    auto & ctx    = spec->ctx;
    auto & smpl   = spec->smpl;
    auto & prompt = spec->prompt;

    auto * mem = llama_get_memory(ctx);

    int reuse_i = 0;
    int reuse_n = 0;

    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;

    const int i_start = std::max<int>(1, (int) prompt_tgt.size() - n_ctx);

    int n_accepted_draft_tokens = data.size() / sizeof(float) / llama_model_n_embd(llama_get_model(ctx)) - 1;

    // reuse as much as possible from the old draft context
    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
    for (int i = 0; i < (int) prompt.size(); ++i) {
        int cur = 0;
        while (i_start + cur < (int) prompt_tgt.size() &&
               i       + cur < (int) prompt.size() &&
               prompt_tgt[i_start + cur] == prompt[i + cur]) {
            cur++;
        }

        cur = (cur - n_accepted_draft_tokens) > 0 ? (cur - n_accepted_draft_tokens) : cur;

        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
            reuse_i = i;
            reuse_n = cur;
        }
    }

    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());

    llama_tokens result;
    result.reserve(params.n_draft);

    if (reuse_n == 0) {
        llama_memory_clear(mem, false);

        prompt.clear();
    } else {
        // this happens when a previous draft has been discarded (for example, due to being too small), but the
        // target model agreed with it. in this case, we simply pass back the previous results to save compute
        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                result.push_back(prompt[i]);

                if (params.n_draft <= (int) result.size()) {
                    break;
                }
            }

            return result;
        }

        if (reuse_i > 0) {
            llama_memory_seq_rm (mem, 0, 0, reuse_i);
            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);

            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
        }

        if (reuse_n < (int) prompt.size()) {
            llama_memory_seq_rm (mem, 0, reuse_n, -1);

            prompt.erase(prompt.begin() + reuse_n, prompt.end());
        }
    }

    // prepare a batch to evaluate any new tokens in the prompt
    common_batch_clear(batch);

    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, (i < prompt_tgt.size() - 1) ? false : true);

        prompt.push_back(prompt_tgt[i]);
    }

    // we should rarely end-up here during normal decoding
    if (batch.n_tokens > 0) {
        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

        llama_decode_eagle(ctx, batch, data.data());
    }

    const llama_pos n_past = prompt.size();

    LOG_DBG("%s: n_past = %d\n", __func__, n_past);

    common_batch_clear(batch);
    common_batch_add (batch, id_last, n_past, { 0 }, true);

    prompt.push_back(id_last);

    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());

    llama_decode_eagle(ctx, batch, data.data());

    common_sampler_reset(smpl);

    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

        common_sampler_sample(smpl, ctx, -1, true);

        const auto * cur_p = common_sampler_get_candidates(smpl);

        for (int k = 0; k < std::min(1, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;

        common_sampler_accept(smpl, id, true);

        result.push_back(id);

        if (params.n_draft <= (int) result.size()) {
            break;
        }

        // only collect very high-confidence draft tokens
        if (cur_p->data[0].p < params.p_min) {
            break;
        }

        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

        // evaluate the drafted tokens on the draft model
        llama_decode_eagle(ctx, batch, data.data());

        prompt.push_back(id);
    }

    return result;
}
```
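The main difference from `common_speculative_gen_draft` is the extra `data` buffer, whose size is also used to infer how many drafted tokens the target accepted in the previous round: `data.size() / sizeof(float) / n_embd - 1`. That arithmetic suggests the buffer holds one row of `n_embd` floats per accepted token plus one, presumably the target-model hidden states that the EAGLE draft head conditions on. A self-contained sketch of that sizing under this assumption (the concrete `n_embd` and token count are illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd     = 4096; // illustrative embedding size of the model
    const int n_accepted = 3;    // drafted tokens the target accepted in the last round

    // assumed layout of `data`: (n_accepted + 1) rows of n_embd floats, stored as raw bytes
    std::vector<uint8_t> data((size_t) (n_accepted + 1) * n_embd * sizeof(float));

    // mirrors the computation in common_speculative_gen_draft_eagle
    const int n_accepted_draft_tokens = (int) (data.size() / sizeof(float) / n_embd) - 1;

    printf("n_accepted_draft_tokens = %d\n", n_accepted_draft_tokens); // prints 3
    return 0;
}
```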

common/speculative.h

Lines changed: 7 additions & 0 deletions
@@ -26,3 +26,10 @@ llama_tokens common_speculative_gen_draft(
```cpp
        struct common_speculative_params params,
        const llama_tokens & prompt,
        llama_token id_last);

llama_tokens common_speculative_gen_draft_eagle(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt,
        llama_token id_last,
        std::vector<uint8_t> & data);
```
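A minimal, hypothetical call site for the new declaration (how `spec` is initialized and how `data` is refilled between rounds are outside this diff and only assumed here; the parameter values mirror the README invocation further down):

```cpp
#include "speculative.h"

#include <cstdint>
#include <vector>

// Hypothetical helper, not part of this commit: one drafting round with the EAGLE model.
static llama_tokens draft_one_round_eagle(struct common_speculative * spec,
                                          const llama_tokens & prompt_tgt,
                                          llama_token id_last,
                                          std::vector<uint8_t> & data) {
    struct common_speculative_params params;
    params.n_draft = 16;   // maximum number of drafted tokens (--draft-max in the README example)
    params.p_min   = 0.9f; // stop drafting below this confidence (--draft-p-min)

    // `data` is assumed to carry the target model's hidden states for the tokens
    // accepted in the previous round; the draft function forwards it to llama_decode_eagle
    return common_speculative_gen_draft_eagle(spec, params, prompt_tgt, id_last, data);
}
```

The returned draft would then be verified by the target model, as sketched above for the greedy case.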

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ else()
```cmake
    add_subdirectory(simple-chat)
    add_subdirectory(speculative)
    add_subdirectory(speculative-simple)
    add_subdirectory(speculative-simple-eagle)
    add_subdirectory(gen-docs)
    add_subdirectory(training)
    if (NOT GGML_BACKEND_DL)
```

examples/speculative-simple-eagle/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
```cmake
set(TARGET llama-speculative-simple-eagle)
add_executable(${TARGET} speculative-simple-eagle.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
```

examples/speculative-simple-eagle/README.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
# llama.cpp/examples/speculative-simple-eagle

Demonstration of basic greedy speculative decoding for EAGLE

```bash
./bin/llama-speculative-simple-eagle \
    -m  ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
    -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
    -f test.txt -c 0 -ngl 99 --color \
    --sampling-seq k --top-k 1 -fa --temp 0.0 \
    -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
```
