-
Notifications
You must be signed in to change notification settings - Fork 296
Adds support for gemma_270m to checkpoint converter #2380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
// Use IntelliSense to learn about possible attributes. | ||
// Hover to view descriptions of existing attributes. | ||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||
"version": "0.2.0", | ||
"configurations": [ | ||
{ | ||
"name": "Pytest: Prompt for test target", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"module": "pytest", | ||
"console": "integratedTerminal", | ||
"justMyCode": true, | ||
"args": [ | ||
"${input:pytestTarget}" | ||
] | ||
} | ||
], | ||
|
||
// 2. This 'inputs' array defines the prompt | ||
"inputs": [ | ||
{ | ||
"id": "pytestTarget", | ||
"type": "promptString", | ||
"description": "Enter the pytest target (file path or file::function)", | ||
"default": "" | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -10,8 +10,8 @@ | |||||||||||||||||||||||||||||
Usage: | ||||||||||||||||||||||||||||||
```shell | ||||||||||||||||||||||||||||||
cd tools/checkpoint_conversion | ||||||||||||||||||||||||||||||
python convert_gemma_checkpoints.py --preset gemma3_instruct_1b | ||||||||||||||||||||||||||||||
python convert_gemma_checkpoints.py --preset gemma3_instruct_4b | ||||||||||||||||||||||||||||||
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b | ||||||||||||||||||||||||||||||
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b | ||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -43,6 +43,15 @@ | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
PRESET_MAP = { | ||||||||||||||||||||||||||||||
# === Text === | ||||||||||||||||||||||||||||||
# 270M | ||||||||||||||||||||||||||||||
"gemma3_instruct_270m": { | ||||||||||||||||||||||||||||||
"model": gm.nn.Gemma3_270M, | ||||||||||||||||||||||||||||||
"params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT, | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
"gemma3_270m": { | ||||||||||||||||||||||||||||||
"model": gm.nn.Gemma3_270M, | ||||||||||||||||||||||||||||||
"params": gm.ckpts.CheckpointPath.GEMMA3_270M_PT, | ||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||
# 1B | ||||||||||||||||||||||||||||||
"gemma3_1b": { | ||||||||||||||||||||||||||||||
"model": gm.nn.Gemma3_1B, | ||||||||||||||||||||||||||||||
|
@@ -493,7 +502,6 @@ def validate_output( | |||||||||||||||||||||||||||||
params=flax_params, | ||||||||||||||||||||||||||||||
multi_turn=False, | ||||||||||||||||||||||||||||||
cache_length=256 if length <= 256 else 512, | ||||||||||||||||||||||||||||||
# max_out_length=length, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
flax_output = flax_sampler.chat(input_str, images=image) | ||||||||||||||||||||||||||||||
print("🔶 Flax output:", flax_output) | ||||||||||||||||||||||||||||||
|
@@ -508,11 +516,11 @@ def main(_): | |||||||||||||||||||||||||||||
assert preset in presets, ( | ||||||||||||||||||||||||||||||
f"Invalid preset {preset}. Must be one of {','.join(presets)}" | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
text_only = "text" in preset or "1b" in preset | ||||||||||||||||||||||||||||||
text_only = "text" in preset or "1b" in preset or "270m" in preset | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
print("🏃 Loading Flax model and tokeniser") | ||||||||||||||||||||||||||||||
flax_kwargs = {} | ||||||||||||||||||||||||||||||
if text_only and "1b" not in preset: | ||||||||||||||||||||||||||||||
if text_only and "1b" not in preset and "270m" not in preset: | ||||||||||||||||||||||||||||||
flax_kwargs["text_only"] = True | ||||||||||||||||||||||||||||||
Comment on lines
+519
to
524
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The logic for determining if a model is For example, you could add PRESET_MAP = {
# ...
"gemma3_instruct_270m": {
"model": gm.nn.Gemma3_270M,
"params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT,
"is_text_only": True,
"needs_flax_text_only_kwarg": False,
},
"gemma3_4b_text": {
"model": gm.nn.Gemma3_4B,
"params": gm.ckpts.CheckpointPath.GEMMA3_4B_PT,
"is_text_only": True,
"needs_flax_text_only_kwarg": True,
},
# ...
} Then, the logic in
Suggested change
Style Guide ReferencesFootnotes |
||||||||||||||||||||||||||||||
flax_model = PRESET_MAP[preset]["model"](**flax_kwargs) | ||||||||||||||||||||||||||||||
flax_config = flax_model.config | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove? Or add to .gitignore?