Commit 488614e

'granite-3.3-2b-instruct' for smoketest; smaller smoke dataset
Replaces `instructlab/granite-7b-lab` with `ibm-granite/granite-3.3-2b-instruct` for the smoke test so we can use a smaller runner and the smoke test will run faster. Also adds dataset subsampling logic so we can make the test run arbitrarily quicker.

Signed-off-by: James Kunstle <[email protected]>
1 parent 425f5ec commit 488614e
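
For context, here is a minimal, standalone sketch of the subsampling pattern this commit adds. The file path and constants below are illustrative only; the actual logic lives in the `cached_training_data` fixture shown in the diff that follows.

# Minimal sketch of the subsampling approach (illustrative values; see the
# real fixture in tests/smoke/test_train.py below).
from datasets import load_dataset

NUM_SAMPLES_TO_KEEP = 5000       # cap on examples kept for the smoke test
data_path = "data.jsonl"         # hypothetical path to the processed dataset

# Load the processed JSONL dataset.
dataset = load_dataset("json", data_files=data_path, split="train")

# Shuffle with a fixed seed for reproducibility, then keep at most
# NUM_SAMPLES_TO_KEEP examples; min() guards against smaller datasets.
sampled = dataset.shuffle(seed=42).select(range(min(NUM_SAMPLES_TO_KEEP, len(dataset))))

# Overwrite the original file so downstream training code needs no changes.
sampled.to_json(data_path)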

File tree

1 file changed (+22, −3 lines)

tests/smoke/test_train.py

Lines changed: 22 additions & 3 deletions
@@ -7,6 +7,7 @@
 import tempfile
 
 # Third Party
+from datasets import load_dataset
 from transformers import AutoModelForCausalLM
 import huggingface_hub
 import pytest
@@ -46,9 +47,12 @@
     "rdzv_endpoint": "127.0.0.1:12345",
 }
 
-REFERENCE_TEST_MODEL = "instructlab/granite-7b-lab"
+REFERENCE_TEST_MODEL = "ibm-granite/granite-3.3-2b-instruct"
 RUNNER_CPUS_EXPECTED = 4
 
+# Number of samples to randomly sample from the processed dataset for faster training
+NUM_SAMPLES_TO_KEEP = 5000
+
 
 @pytest.fixture(scope="module")
 def custom_tmp_dir() -> Generator[pathlib.Path, None, None]:
@@ -190,7 +194,10 @@ def chat_template_in_repo_path() -> pathlib.Path:
 def cached_training_data(
     prepared_data_dir: pathlib.Path, cached_test_model: pathlib.Path
 ) -> pathlib.Path:
-    """Renders test data in model template, tokenizes, and saves to fs"""
+    """
+    Renders test data in model template, tokenizes, and saves to filesystem.
+    Subsamples NUM_SAMPLES_TO_KEEP examples to speed up tests.
+    """
 
     data_in_repo = data_in_repo_path()
     chat_template = chat_template_in_repo_path()
@@ -206,7 +213,19 @@ def cached_training_data(
 
     data_process.main(data_process_args)
 
-    return prepared_data_dir / "data.jsonl"
+    # Load the processed data and sample a subset
+    output_path = prepared_data_dir / "data.jsonl"
+    dataset = load_dataset("json", data_files=str(output_path), split="train")
+
+    # Randomly sample NUM_SAMPLES_TO_KEEP examples
+    sampled_dataset = dataset.shuffle(seed=42).select(
+        range(min(NUM_SAMPLES_TO_KEEP, len(dataset)))
+    )
+
+    # Write the sampled data back to the same file
+    sampled_dataset.to_json(str(output_path), num_proc=RUNNER_CPUS_EXPECTED)
+
+    return output_path
 
 
 @pytest.mark.slow
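
A few notes on the design of this change, as read from the diff: the fixed seed (42) keeps the sampled subset deterministic across test runs; wrapping the sample size in min(NUM_SAMPLES_TO_KEEP, len(dataset)) avoids an out-of-range selection when the processed dataset has fewer than 5000 examples; and writing the sampled data back to the same data.jsonl path means the training steps that consume this fixture do not need to change.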
