
Commit b2418c9

fix tests x2
1 parent c1bb315 commit b2418c9

4 files changed: +27 −20 lines changed


src/instructlab/training/batch_loss_manager.py

Lines changed: 19 additions & 14 deletions
@@ -77,10 +77,12 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
         num_minibatches = len(batch)

         # initialize accumulation variables
-        batch_total_samples = 0.0
-        batch_total_length = 0.0
-        accumulated_loss = 0.0
-        accumulated_aux_loss = 0.0 if self.model.is_gpt_oss else None
+        batch_total_samples = 0
+        batch_total_length = 0
+        accumulated_loss = torch.tensor([0.0], dtype=torch.float32)
+        accumulated_aux_loss = (
+            torch.tensor([0.0], dtype=torch.float32) if self.model.is_gpt_oss else None
+        )
         grad_accum_steps = 0

         # process each minibatch
@@ -134,22 +136,25 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
     def _prepare_model_inputs(self, mb: CollatedItem) -> ModelInputs:
         """Prepare and move model inputs to GPU."""
         model_inputs = ModelInputs(
-            input_ids=mb["input_ids"],
-            labels=mb["labels"],
-            position_ids=mb["position_ids"],
+            input_ids=mb["input_ids"].to(device=self.torch_device),
+            labels=mb["labels"].to(device=self.torch_device),
         )
-        if "attention_mask" in mb:
-            model_inputs["attention_mask"] = mb["attention_mask"]

-        # send tensors to gpu
-        for k in model_inputs.keys():
-            model_inputs[k] = model_inputs[k].to(device=self.torch_device)
+        # add optional fields onto `model_inputs` object
+        if "attention_mask" in mb:
+            model_inputs["attention_mask"] = mb["attention_mask"].to(
+                device=self.torch_device
+            )
+        if "position_ids" in mb:
+            model_inputs["position_ids"] = mb["position_ids"].to(
+                device=self.torch_device
+            )

         return model_inputs

     def _reduce_metrics(
-        self, batch_total_samples: float, batch_total_length: float
-    ) -> tuple[float, float]:
+        self, batch_total_samples: int, batch_total_length: int
+    ) -> tuple[int, int]:
         """Reduce rank-specific metrics across devices."""
         inputs_to_reduce = torch.tensor(
             [batch_total_samples, batch_total_length],
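For context, a minimal sketch (an assumed helper, not code from this commit) of the accumulation pattern the first hunk moves toward: plain integers for the sample/length counters, and a float32 tensor for the running loss so the total stays a tensor for the later all-reduce step.

import torch

def accumulate_losses(minibatch_losses: list[torch.Tensor]) -> torch.Tensor:
    # running total kept as a float32 tensor rather than a Python float
    accumulated_loss = torch.tensor([0.0], dtype=torch.float32)
    for loss in minibatch_losses:
        # detach so the running total does not keep each minibatch graph alive
        accumulated_loss += loss.detach().float().cpu()
    return accumulated_loss

# example: three fake minibatch losses
print(accumulate_losses([torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.125)]))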

src/instructlab/training/sampler.py

Lines changed: 4 additions & 2 deletions
@@ -249,8 +249,10 @@ def __init__(
         else:
             self.batch_packer = batch_lengths_to_minibatches_padded
             # Create a wrapper for padded collate that includes pad_token_id
-            self.collate_fn = lambda mb, tokens: padded_mb_collate_fn(
-                mb, tokens, pad_token_id
+            self.collate_fn = (
+                lambda minibatch, batch_num_loss_counted_tokens: padded_mb_collate_fn(
+                    minibatch, batch_num_loss_counted_tokens, pad_token_id
+                )
             )

     def __call__(self, batch: list[dict]):
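The renamed lambda is simply a closure that binds pad_token_id ahead of time and forwards the other two arguments later. A sketch of the same idea using functools.partial, with a hypothetical stand-in for padded_mb_collate_fn (the real signature is assumed, not verified):

from functools import partial

def padded_collate(minibatch, batch_num_loss_counted_tokens, pad_token_id):
    # stand-in for padded_mb_collate_fn
    return {
        "minibatch": minibatch,
        "num_loss_counted_tokens": batch_num_loss_counted_tokens,
        "pad_token_id": pad_token_id,
    }

collate_fn = partial(padded_collate, pad_token_id=0)
# later, the caller only supplies the first two arguments:
print(collate_fn([{"input_ids": [1, 2]}], 2))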

src/instructlab/training/type_definitions.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ class ModelInputs(t.TypedDict):

     input_ids: Required[torch.Tensor]
     labels: Required[torch.Tensor]
-    position_ids: Required[torch.Tensor]
+    position_ids: NotRequired[torch.Tensor]
     attention_mask: NotRequired[torch.Tensor]  # used when not training in padding free

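With position_ids now NotRequired, consumers check for the key before touching it, which is what the batch_loss_manager hunk above does. A small sketch of that usage; the typing_extensions import is an assumption about where Required/NotRequired come from, not taken from the file:

import typing as t

import torch
from typing_extensions import NotRequired, Required

class ModelInputs(t.TypedDict):
    input_ids: Required[torch.Tensor]
    labels: Required[torch.Tensor]
    position_ids: NotRequired[torch.Tensor]
    attention_mask: NotRequired[torch.Tensor]

# position_ids may be absent, so test for the key before reading it
inputs = ModelInputs(
    input_ids=torch.zeros(4, dtype=torch.long),
    labels=torch.zeros(4, dtype=torch.long),
)
if "position_ids" in inputs:
    inputs["position_ids"] = inputs["position_ids"].to("cpu")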

tests/unit/test_data_process.py

Lines changed: 3 additions & 3 deletions
@@ -890,9 +890,9 @@ def test_with_qwen_tokenizer(self):
         self.assertNotIn(unmask_begin_id, result["input_ids"])
         self.assertNotIn(unmask_end_id, result["input_ids"])

-    def test_with_mistral_tokenizer(self):
-        """Test reasoning_content functionality with Mistral tokenizer."""
-        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+    def test_with_phi_tokenizer(self):
+        """Test reasoning_content functionality with Phi-4 tokenizer."""
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")

         # Add the unmask tokens to the tokenizer
         tokenizer.add_special_tokens(
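For readers unfamiliar with the pattern this test relies on, a hedged sketch of loading the swapped-in tokenizer and registering unmask tokens; the token strings below are illustrative placeholders, not the test's actual values:

from transformers import AutoTokenizer

# placeholder token strings, not the test's real unmask tokens
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|unmask_begin|>", "<|unmask_end|>"]}
)
unmask_begin_id = tokenizer.convert_tokens_to_ids("<|unmask_begin|>")
unmask_end_id = tokenizer.convert_tokens_to_ids("<|unmask_end|>")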
