@@ -1196,27 +1196,26 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };
 
-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
     };
   }
 
   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
-                                  size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&runtime_config](float* logits,
+                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
-                   runtime_config.temperature, runtime_config.accept_token);
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
+        runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
 }
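For context, a minimal standalone sketch of top-k sampling over an already-softmaxed distribution. This is an illustration only, not gemma.cpp's SampleTopK, which (per the call above) also receives temperature and an accept_token filter; the name and signature here are invented:

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Draws one token from the k highest-probability entries of probs.
// Assumes probs already sums to 1 (i.e. Softmax was applied) and k >= 1.
int SampleTopKSketch(const float* probs, size_t k, size_t vocab_size,
                     std::mt19937& gen) {
  k = std::min(k, vocab_size);
  // Partially sort indices so the first k point at the largest probabilities.
  std::vector<int> idx(vocab_size);
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  // Draw among the top-k; discrete_distribution renormalizes the weights.
  std::vector<float> topk(k);
  for (size_t i = 0; i < k; ++i) topk[i] = probs[idx[i]];
  std::discrete_distribution<int> dist(topk.begin(), topk.end());
  return idx[dist(gen)];
}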
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
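With this change, top_k is read from RuntimeConfig rather than from the model's weights config, so callers can vary it per request. A hypothetical call-site sketch; only the fields used in the diff above are taken from the source, and RuntimeConfig's construction details are assumptions:

std::mt19937 gen(/*seed=*/42);
RuntimeConfig runtime_config;       // assumed default-constructible
runtime_config.top_k = 1;           // selects the Top1OfSoftmax fast path
runtime_config.temperature = 1.0f;  // used only by the general path
runtime_config.gen = &gen;          // assumed to be a std::mt19937*
const SampleFunc sample_token = ChooseSampleFunc(runtime_config);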