@@ -77,10 +77,10 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
77
77
num_minibatches = len (batch )
78
78
79
79
# initialize accumulation variables
80
- batch_total_samples = 0.0
81
- batch_total_length = 0.0
80
+ batch_total_samples = 0
81
+ batch_total_length = 0
82
82
accumulated_loss = 0.0
83
- accumulated_aux_loss = 0.0 if self . model . is_gpt_oss else None
83
+ accumulated_aux_loss = 0.0
84
84
grad_accum_steps = 0
85
85
86
86
# process each minibatch
@@ -134,22 +134,25 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
134
134
def _prepare_model_inputs (self , mb : CollatedItem ) -> ModelInputs :
135
135
"""Prepare and move model inputs to GPU."""
136
136
model_inputs = ModelInputs (
137
- input_ids = mb ["input_ids" ],
138
- labels = mb ["labels" ],
139
- position_ids = mb ["position_ids" ],
137
+ input_ids = mb ["input_ids" ].to (device = self .torch_device ),
138
+ labels = mb ["labels" ].to (device = self .torch_device ),
140
139
)
141
- if "attention_mask" in mb :
142
- model_inputs ["attention_mask" ] = mb ["attention_mask" ]
143
140
144
- # send tensors to gpu
145
- for k in model_inputs .keys ():
146
- model_inputs [k ] = model_inputs [k ].to (device = self .torch_device )
141
+ # add optional fields onto `model_inputs` object
142
+ if "attention_mask" in mb :
143
+ model_inputs ["attention_mask" ] = mb ["attention_mask" ].to (
144
+ device = self .torch_device
145
+ )
146
+ if "position_ids" in mb :
147
+ model_inputs ["position_ids" ] = mb ["position_ids" ].to (
148
+ device = self .torch_device
149
+ )
147
150
148
151
return model_inputs
149
152
150
153
def _reduce_metrics (
151
- self , batch_total_samples : float , batch_total_length : float
152
- ) -> tuple [float , float ]:
154
+ self , batch_total_samples : int , batch_total_length : int
155
+ ) -> tuple [int , int ]:
153
156
"""Reduce rank-specific metrics across devices."""
154
157
inputs_to_reduce = torch .tensor (
155
158
[batch_total_samples , batch_total_length ],
0 commit comments