
Commit 719699f

danielkeysers authored and copybara-github committed

Make top_k a runtime argument (instead of a model argument).

PiperOrigin-RevId: 696170691

1 parent b94295b · commit 719699f

File tree

13 files changed: +31 −26 lines

backprop/optimize_test.cc

Lines changed: 1 addition & 1 deletion

@@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) {
   RuntimeConfig runtime = {
       .max_generated_tokens = 16,
       .temperature = 1.0f,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .eos_id = ReverseSequenceSampler::kEndToken,
   };

evals/benchmark_helper.cc

Lines changed: 1 addition & 1 deletion

@@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   runtime_config_ = {
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
-      .verbosity = app.verbosity,
       .gen = &gen_,
+      .verbosity = app.verbosity,
   };
 }

evals/cross_entropy.cc

Lines changed: 1 addition & 1 deletion

@@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
   RuntimeConfig runtime = {
       .max_generated_tokens = max_generated_tokens - 1,
       .temperature = 0.0f,
-      .verbosity = verbosity,
       .gen = nullptr,
+      .verbosity = verbosity,
       .stream_token = stream_token,
       .sample_func = sample_token,
   };

evals/gemma_test.cc

Lines changed: 1 addition & 1 deletion

@@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) {
   RuntimeConfig runtime_config{
       .max_generated_tokens = 64,
       .temperature = 0.0f,
-      .verbosity = 2,
       .gen = &s_env->MutableGen(),
+      .verbosity = 2,
       .stream_token = stream_token,
   };
   TimingInfo timing_info{.verbosity = 0};

evals/run_mmlu.cc

Lines changed: 1 addition & 1 deletion

@@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 30,
       .temperature = 0.0f,
-      .verbosity = env.Verbosity(),
       .gen = &env.MutableGen(),
+      .verbosity = env.Verbosity(),
       .stream_token = stream_token,
   };
   env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,

examples/hello_world/run.cc

Lines changed: 1 addition & 1 deletion

@@ -89,8 +89,8 @@ int main(int argc, char** argv) {
   gcpp::RuntimeConfig runtime_config = {
       .max_generated_tokens = 1024,
       .temperature = 1.0,
-      .verbosity = 0,
       .gen = &gen,
+      .verbosity = 0,
       .stream_token = stream_token,
       .accept_token =
           [&](int token, float /* prob */) {

gemma/configs.h

Lines changed: 0 additions & 1 deletion

@@ -178,7 +178,6 @@ struct ModelConfig {
   size_t vit_seq_len = 0;
   size_t num_tensor_scales = 0;
   size_t num_vit_scales = 0;
-  size_t top_k = kTopK;
   float att_cap = 0.0f;
   float final_cap = 0.0f;
   bool absolute_pe = false;

gemma/configs_test.cc

Lines changed: 1 addition & 1 deletion

@@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) {
   }
   ASSERT_EQ(TConfig::kVocabSize, config.vocab_size);
   ASSERT_EQ(TConfig::kSeqLen, config.seq_len);
-  ASSERT_EQ(TConfig::kTopK, config.top_k);
+  // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value.
   ASSERT_EQ(TConfig::kAttCap, config.att_cap);
   ASSERT_EQ(TConfig::kFinalCap, config.final_cap);
   ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe);

gemma/gemma-inl.h

Lines changed: 8 additions & 10 deletions

@@ -1196,27 +1196,26 @@ class TokenStreamer {
   hwy::BitSet4096<> is_eos_;
 };
 
-HWY_INLINE SampleFunc ChooseSampleFunc(int top_k,
-                                       const RuntimeConfig& runtime_config) {
+HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   // If user provided a sample_func, use it.
   if (runtime_config.sample_func) return runtime_config.sample_func;
 
   // Fast path for top-1 with no accept_token.
-  if (top_k == 1 && !runtime_config.accept_token) {
+  if (runtime_config.top_k == 1 && !runtime_config.accept_token) {
     return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
       PROFILER_ZONE("Gen.Sample Top1");
       return Top1OfSoftmax(logits, vocab_size);
     };
   }
 
   // General case: Softmax with top-k sampling.
-  return [top_k, &runtime_config](float* logits,
-                                  size_t vocab_size) HWY_ATTR -> TokenAndProb {
+  return [&runtime_config](float* logits,
+                           size_t vocab_size) HWY_ATTR -> TokenAndProb {
     PROFILER_ZONE("Gen.Sample general");
     Softmax(logits, vocab_size);
-    const int token =
-        SampleTopK(logits, top_k, vocab_size, *runtime_config.gen,
-                   runtime_config.temperature, runtime_config.accept_token);
+    const int token = SampleTopK(
+        logits, runtime_config.top_k, vocab_size, *runtime_config.gen,
+        runtime_config.temperature, runtime_config.accept_token);
     return TokenAndProb{.token = token, .prob = logits[token]};
   };
 }

@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   size_t max_prompt_size = MaxQueryLength(queries_prompt);
   size_t max_generated_tokens = runtime_config.max_generated_tokens;
   RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size);
-  const SampleFunc sample_token =
-      ChooseSampleFunc(weights.weights_config.top_k, runtime_config);
+  const SampleFunc sample_token = ChooseSampleFunc(runtime_config);
 
   // Prefill stops before min_prompt_size - 1 because the last prompt
   // token is the first input token for generation.
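
For context on what the general sampling path computes, here is a minimal, self-contained sketch of top-k sampling over already-softmaxed probabilities: keep the top_k highest-probability tokens and draw among them in proportion to their probability. This is only an illustration of the technique; SampleTopKSketch and its signature are invented for this example, and the real SampleTopK called above additionally receives the RNG, temperature, and optional accept_token filter.

// Illustrative sketch only (hypothetical helper, not part of gemma.cpp):
// sample one token from the top_k most probable entries of an
// already-softmaxed probability vector.
#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

int SampleTopKSketch(const std::vector<float>& probs, size_t top_k,
                     std::mt19937& gen) {
  top_k = std::min(top_k, probs.size());
  // Partially sort token indices so the top_k most probable come first.
  std::vector<int> idx(probs.size());
  for (size_t i = 0; i < probs.size(); ++i) idx[i] = static_cast<int>(i);
  std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                    [&](int a, int b) { return probs[a] > probs[b]; });
  // Draw among the kept tokens in proportion to their probabilities;
  // discrete_distribution renormalizes the weights internally.
  std::vector<float> kept;
  for (size_t i = 0; i < top_k; ++i) kept.push_back(probs[idx[i]]);
  std::discrete_distribution<int> dist(kept.begin(), kept.end());
  return idx[dist(gen)];
}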

gemma/gemma.h

Lines changed: 6 additions & 2 deletions

@@ -25,6 +25,7 @@
 #include "compression/io.h"  // Path
 #include "gemma/activations.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/kv_cache.h"
 #include "gemma/tokenizer.h"
 #include "gemma/weights.h"

@@ -102,9 +103,12 @@ struct RuntimeConfig {
   // Max queries per batch (one token from each) during decode.
   size_t decode_qbatch_size = 16;
 
-  float temperature;  // Temperature for sampling.
+  // Sampling-related parameters.
+  float temperature;     // Temperature for sampling.
+  size_t top_k = kTopK;  // Top-k for sampling.
+  std::mt19937* gen;     // Random number generator used for sampling.
+
   int verbosity;  // Controls verbosity of printed messages.
-  std::mt19937* gen;  // Random number generator used for sampling.
 
   // Functions operating on the generated tokens.
   StreamFunc stream_token;
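
With top_k now a RuntimeConfig field, callers choose it per generation call rather than per model. A minimal usage sketch following the designated-initializer pattern from the call sites above; the concrete values (128 tokens, top_k of 5, seed 42) are arbitrary examples, and the remaining fields are filled in exactly as before the change:

  // Sketch of a call site after this commit (values are examples only).
  std::mt19937 gen(42);
  gcpp::RuntimeConfig runtime_config = {
      .max_generated_tokens = 128,
      .temperature = 1.0f,
      .top_k = 5,  // New: chosen at call time; defaults to kTopK if omitted.
      .gen = &gen,
      .verbosity = 0,
      // stream_token, accept_token, etc. unchanged from the examples above.
  };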
