29 changes: 29 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,29 @@
{
Collaborator: Remove? Or add to .gitignore?

// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Pytest: Prompt for test target",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"console": "integratedTerminal",
"justMyCode": true,
"args": [
"${input:pytestTarget}"
]
}
],

// This 'inputs' array defines the prompt
"inputs": [
{
"id": "pytestTarget",
"type": "promptString",
"description": "Enter the pytest target (file path or file::function)",
"default": ""
}
]
}
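
For reference, a minimal sketch of what this configuration does: VS Code substitutes the prompted value into `args` and launches pytest under debugpy. The equivalent direct invocation, minus the debugger, uses pytest's real `pytest.main()` API; the test target shown is a hypothetical example, not a file in this repository.

```python
# Sketch: run pytest on a single prompted target, like the launch config.
# pytest.main() takes the same CLI-style arguments as `python -m pytest`.
import sys

import pytest

if __name__ == "__main__":
    # e.g. "tests/test_convert.py::test_preset_map" (hypothetical target)
    target = sys.argv[1] if len(sys.argv) > 1 else "tests/"
    sys.exit(pytest.main([target]))
```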
18 changes: 13 additions & 5 deletions tools/checkpoint_conversion/convert_gemma3_checkpoints.py
@@ -10,8 +10,8 @@
Usage:
```shell
cd tools/checkpoint_conversion
-python convert_gemma_checkpoints.py --preset gemma3_instruct_1b
-python convert_gemma_checkpoints.py --preset gemma3_instruct_4b
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
```
"""

@@ -43,6 +43,15 @@

PRESET_MAP = {
    # === Text ===
+    # 270M
+    "gemma3_instruct_270m": {
+        "model": gm.nn.Gemma3_270M,
+        "params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT,
+    },
+    "gemma3_270m": {
+        "model": gm.nn.Gemma3_270M,
+        "params": gm.ckpts.CheckpointPath.GEMMA3_270M_PT,
+    },
    # 1B
    "gemma3_1b": {
        "model": gm.nn.Gemma3_1B,
@@ -493,7 +502,6 @@ def validate_output(
        params=flax_params,
        multi_turn=False,
        cache_length=256 if length <= 256 else 512,
-        # max_out_length=length,
    )
    flax_output = flax_sampler.chat(input_str, images=image)
    print("🔶 Flax output:", flax_output)
@@ -508,11 +516,11 @@ def main(_):
    assert preset in presets, (
        f"Invalid preset {preset}. Must be one of {','.join(presets)}"
    )
-    text_only = "text" in preset or "1b" in preset
+    text_only = "text" in preset or "1b" in preset or "270m" in preset

    print("🏃 Loading Flax model and tokeniser")
    flax_kwargs = {}
-    if text_only and "1b" not in preset:
+    if text_only and "1b" not in preset and "270m" not in preset:
        flax_kwargs["text_only"] = True
Comment on lines +519 to 524


Severity: medium

The logic for determining if a model is text_only and if it needs the text_only kwarg for Flax is based on string matching in the preset name. This can be fragile and hard to maintain as new model sizes are added. To improve reusability and ensure all presets are handled robustly, consider making this information explicit in the PRESET_MAP.[1]

For example, you could add is_text_only and needs_flax_text_only_kwarg flags to each preset dictionary:

PRESET_MAP = {
    # ...
    "gemma3_instruct_270m": {
        "model": gm.nn.Gemma3_270M,
        "params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT,
        "is_text_only": True,
        "needs_flax_text_only_kwarg": False,
    },
    "gemma3_4b_text": {
        "model": gm.nn.Gemma3_4B,
        "params": gm.ckpts.CheckpointPath.GEMMA3_4B_PT,
        "is_text_only": True,
        "needs_flax_text_only_kwarg": True,
    },
    # ...
}

Then, the logic in main() would be much cleaner and less error-prone when adding new presets.

Suggested change
-    text_only = "text" in preset or "1b" in preset or "270m" in preset
-
-    print("🏃 Loading Flax model and tokeniser")
-    flax_kwargs = {}
-    if text_only and "1b" not in preset and "270m" not in preset:
-        flax_kwargs["text_only"] = True
+    preset_info = PRESET_MAP[preset]
+    text_only = preset_info.get("is_text_only", False)
+
+    print("🏃 Loading Flax model and tokeniser")
+    flax_kwargs = {}
+    if preset_info.get("needs_flax_text_only_kwarg", False):
+        flax_kwargs["text_only"] = True

Style Guide References

Footnotes

  1. Checkpoint conversion scripts should be reusable and able to handle all presets for a model. Relying on string matching in preset names can make the script less robust and harder to maintain when new presets are added. (link)
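
To make the proposal concrete, here is a self-contained sketch of the flag-driven lookup. Placeholder strings stand in for the gm.nn and gm.ckpts objects (which require the gemma package), and resolve_flax_kwargs is a hypothetical helper, not part of the script:

```python
# Sketch of the reviewer's proposal: explicit flags in PRESET_MAP
# replace string matching on the preset name.
PRESET_MAP = {
    "gemma3_instruct_270m": {
        "model": "Gemma3_270M",      # placeholder for gm.nn.Gemma3_270M
        "params": "GEMMA3_270M_IT",  # placeholder for gm.ckpts.CheckpointPath
        "is_text_only": True,
        "needs_flax_text_only_kwarg": False,
    },
    "gemma3_4b_text": {
        "model": "Gemma3_4B",
        "params": "GEMMA3_4B_PT",
        "is_text_only": True,
        "needs_flax_text_only_kwarg": True,
    },
}

def resolve_flax_kwargs(preset: str) -> tuple[bool, dict]:
    """Read the explicit flags instead of string-matching the preset name."""
    info = PRESET_MAP[preset]
    flax_kwargs = {}
    if info.get("needs_flax_text_only_kwarg", False):
        flax_kwargs["text_only"] = True
    return info.get("is_text_only", False), flax_kwargs

# Adding a new preset now only requires setting its flags correctly:
assert resolve_flax_kwargs("gemma3_instruct_270m") == (True, {})
assert resolve_flax_kwargs("gemma3_4b_text") == (True, {"text_only": True})
```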

    flax_model = PRESET_MAP[preset]["model"](**flax_kwargs)
    flax_config = flax_model.config