Skip to content

Commit c68cb35

Browse files
committed
Add tests for model weights existence coverage
1 parent 8b6a211 commit c68cb35

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/vec_inf/client/test_slurm_script_generator.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params):
176176
"module load " in setup
177177
) # Remove module name since it's inconsistent between clusters
178178

179+
def test_generate_server_setup_singularity_no_weights(
180+
self, singularity_params, monkeypatch
181+
):
182+
"""Test server setup when model weights don't exist."""
183+
monkeypatch.setattr(
184+
"vec_inf.client._slurm_script_generator.Path.exists",
185+
lambda self: False,
186+
)
187+
188+
generator = SlurmScriptGenerator(singularity_params)
189+
setup = generator._generate_server_setup()
190+
191+
assert "ray stop" in setup
192+
assert "/path/to/model_weights/test-model" not in setup
193+
179194
def test_generate_launch_cmd_venv(self, basic_params):
180195
"""Test launch command generation with virtual environment."""
181196
generator = SlurmScriptGenerator(basic_params)
@@ -415,6 +430,24 @@ def test_generate_model_launch_script_singularity(
415430
mock_touch.assert_called_once()
416431
mock_write_text.assert_called_once()
417432

433+
@patch("pathlib.Path.touch")
434+
@patch("pathlib.Path.write_text")
435+
def test_generate_model_launch_script_singularity_no_weights(
436+
self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch
437+
):
438+
"""Test batch model launch script when model weights don't exist."""
439+
monkeypatch.setattr(
440+
"vec_inf.client._slurm_script_generator.Path.exists",
441+
lambda self: False,
442+
)
443+
444+
generator = BatchSlurmScriptGenerator(batch_singularity_params)
445+
script_path = generator._generate_model_launch_script("model1")
446+
447+
assert script_path.name == "launch_model1.sh"
448+
call_args = mock_write_text.call_args[0][0]
449+
assert "/path/to/model_weights/model1" not in call_args
450+
418451
@patch("vec_inf.client._slurm_script_generator.datetime")
419452
@patch("pathlib.Path.touch")
420453
@patch("pathlib.Path.write_text")

0 commit comments

Comments
 (0)