@@ -278,3 +278,153 @@ llama_tokens common_speculative_gen_draft(
 
     return result;
 }
+
+llama_tokens common_speculative_gen_draft_eagle(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last,
+        std::vector<uint8_t> & data) {
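+    // NOTE: `data` is assumed to hold the target model's hidden states, packed as one
+    //       n_embd-sized float vector per previously accepted draft token plus one extra
+    //       vector, and is forwarded to llama_decode_eagle() as the EAGLE conditioning input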
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    auto * mem = llama_get_memory(ctx);
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max<int>(1, (int) prompt_tgt.size() - n_ctx);
+
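+    // number of embedding-sized float vectors in `data`, minus one - i.e. (presumably)
+    // how many draft tokens from the previous round were accepted by the target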
+    int n_accepted_draft_tokens = data.size() / sizeof(float) / llama_model_n_embd(llama_get_model(ctx)) - 1;
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
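+        // shrink the matched prefix by the number of previously accepted draft tokens
+        // (but only if the result stays positive), presumably so that those positions are
+        // re-evaluated against the fresh hidden states in `data`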
+        cur = (cur - n_accepted_draft_tokens) > 0 ? (cur - n_accepted_draft_tokens) : cur;
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_memory_clear(mem, false);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
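+        // shift the reused part of the draft KV cache to position 0 and drop everything
+        // that no longer matches the new target prompt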
+        if (reuse_i > 0) {
+            llama_memory_seq_rm (mem, 0, 0, reuse_i);
+            llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_memory_seq_rm(mem, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    common_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        // LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, i == prompt_tgt.size() - 1);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        // LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
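+        // llama_decode_eagle() is assumed to behave like llama_decode(), but to additionally
+        // consume the raw float buffer carrying the target hidden states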
+        llama_decode_eagle(ctx, batch, data.data());
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    common_batch_clear(batch);
+    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    // LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode_eagle(ctx, batch, data.data());
+
+    common_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        common_batch_clear(batch);
+
+        common_sampler_sample(smpl, ctx, -1, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(1, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode_eagle(ctx, batch, data.data());
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
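
For context, a minimal caller sketch (an illustration under stated assumptions, not part of the patch): it presumes that `spec`, `params`, `prompt_tgt` and `id_last` are prepared exactly as for common_speculative_gen_draft(), that `n_embd` is the model's embedding size, and that `target_hidden` points to `n_embd` floats per accepted draft token plus one extra vector, taken from the target model; the buffer layout simply mirrors the size arithmetic inside the function above.

    // pack the target hidden states into the raw byte buffer the function expects
    std::vector<uint8_t> data(sizeof(float) * n_embd * (n_accepted + 1));
    memcpy(data.data(), target_hidden, data.size());

    // ask the EAGLE draft head for up to params.n_draft tokens
    llama_tokens draft = common_speculative_gen_draft_eagle(spec, params, prompt_tgt, id_last, data);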