diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..8659be2f04 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,29 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Pytest: Prompt for test target", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "console": "integratedTerminal", + "justMyCode": true, + "args": [ + "${input:pytestTarget}" + ] + } + ], + + // 2. This 'inputs' array defines the prompt + "inputs": [ + { + "id": "pytestTarget", + "type": "promptString", + "description": "Enter the pytest target (file path or file::function)", + "default": "" + } + ] +} \ No newline at end of file diff --git a/tools/checkpoint_conversion/convert_gemma3_checkpoints.py b/tools/checkpoint_conversion/convert_gemma3_checkpoints.py index 19945cb6ee..2105ae4ff1 100644 --- a/tools/checkpoint_conversion/convert_gemma3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_gemma3_checkpoints.py @@ -10,8 +10,8 @@ Usage: ```shell cd tools/checkpoint_conversion -python convert_gemma_checkpoints.py --preset gemma3_instruct_1b -python convert_gemma_checkpoints.py --preset gemma3_instruct_4b +python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b +python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b ``` """ @@ -43,6 +43,15 @@ PRESET_MAP = { # === Text === + # 270M + "gemma3_instruct_270m": { + "model": gm.nn.Gemma3_270M, + "params": gm.ckpts.CheckpointPath.GEMMA3_270M_IT, + }, + "gemma3_270m": { + "model": gm.nn.Gemma3_270M, + "params": gm.ckpts.CheckpointPath.GEMMA3_270M_PT, + }, # 1B "gemma3_1b": { "model": gm.nn.Gemma3_1B, @@ -493,7 +502,6 @@ def validate_output( params=flax_params, multi_turn=False, cache_length=256 if length <= 256 else 512, - # max_out_length=length, ) flax_output = flax_sampler.chat(input_str, images=image) print("🔶 Flax output:", flax_output) @@ -508,11 +516,11 @@ def main(_): assert preset in presets, ( f"Invalid preset {preset}. Must be one of {','.join(presets)}" ) - text_only = "text" in preset or "1b" in preset + text_only = "text" in preset or "1b" in preset or "270m" in preset print("🏃 Loading Flax model and tokeniser") flax_kwargs = {} - if text_only and "1b" not in preset: + if text_only and "1b" not in preset and "270m" not in preset: flax_kwargs["text_only"] = True flax_model = PRESET_MAP[preset]["model"](**flax_kwargs) flax_config = flax_model.config