Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,29 @@ llama_tokens common_speculative_gen_draft(

common_sampler_reset(smpl);

// read in from an environment variable for now
// GGML_BATCH_COSTS is a comma-separated list of floats, e.g. "0.5,0.3,0.2";
// entry i is the probability threshold at which a draft of size i is +EV
const std::vector<float> batch_costs = []() {
    std::vector<float> costs;
    if (const char* env = std::getenv("GGML_BATCH_COSTS")) {
        for (const char* p = env; *p; ) {
            char* end = nullptr;
            const float v = std::strtof(p, &end);
            if (end == p) {
                // no characters consumed: unparsable input — stop instead of
                // looping forever (the original advanced p to end == p here)
                break;
            }
            costs.push_back(v);
            p = *end == ',' ? end + 1 : end;
        }
    }
    return costs;
}();
GGML_ASSERT(batch_costs.size() >= 2 && "GGML_BATCH_COSTS must have at least 2 values");

// read in from an environment variable for now (default = 0)
// GGML_MAX_LOOK_AHEAD is how many extra tokens past the last +EV size we keep
// drafting before giving up; parsed with strtol (atoi is UB on overflow) and
// clamped so a negative value cannot wrap around to a huge size_t
const size_t max_look_ahead = []() -> size_t {
    const char* env = std::getenv("GGML_MAX_LOOK_AHEAD");
    const long v = env ? std::strtol(env, nullptr, 10) : 0;
    return v > 0 ? (size_t) v : (size_t) 0;
}();

// the current sequence probability, as predicted by the draft
float sequence_p = 1.0;

// the longest draft size we have seen that is +EV
size_t best_size = 0;

// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
Expand All @@ -335,8 +358,11 @@ llama_tokens common_speculative_gen_draft(
break;
}

// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
// only collect +EV draft tokens
sequence_p *= cur_p->data[0].p;
if (sequence_p > batch_costs[std::min(result.size(), batch_costs.size() - 1)]) {
best_size = result.size();
} else if (sequence_p <= batch_costs[std::min(result.size() + max_look_ahead, batch_costs.size() - 1)]) {
break;
}

Expand All @@ -348,6 +374,9 @@ llama_tokens common_speculative_gen_draft(
prompt_dft.push_back(id);
}

// truncate to the best we saw that was +EV
result.resize(best_size);

if (!spec->vocab_dft_compatible) {
std::string detokenized = common_detokenize(ctx_dft, result, true);
detokenized = replace_to_tgt(spec, detokenized);
Expand Down
Loading