@@ -77,10 +77,12 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
77
77
num_minibatches = len (batch )
78
78
79
79
# initialize accumulation variables
80
- batch_total_samples = 0.0
81
- batch_total_length = 0.0
82
- accumulated_loss = 0.0
83
- accumulated_aux_loss = 0.0 if self .model .is_gpt_oss else None
80
+ batch_total_samples = 0
81
+ batch_total_length = 0
82
+ accumulated_loss = torch .tensor ([0.0 ], dtype = torch .float32 )
83
+ accumulated_aux_loss = (
84
+ torch .tensor ([0.0 ], dtype = torch .float32 ) if self .model .is_gpt_oss else None
85
+ )
84
86
grad_accum_steps = 0
85
87
86
88
# process each minibatch
@@ -134,22 +136,25 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
134
136
def _prepare_model_inputs(self, mb: CollatedItem) -> ModelInputs:
    """Assemble the model inputs for one minibatch and move them to the GPU.

    Required fields (``input_ids``, ``labels``) are always transferred;
    optional fields are forwarded only when the collator produced them.
    """
    device = self.torch_device

    # Required fields — every collated minibatch carries these.
    model_inputs = ModelInputs(
        input_ids=mb["input_ids"].to(device=device),
        labels=mb["labels"].to(device=device),
    )

    # Optional fields — present only for some collation modes, so copy
    # them onto the device lazily, key by key.
    for key in ("attention_mask", "position_ids"):
        if key in mb:
            model_inputs[key] = mb[key].to(device=device)

    return model_inputs
149
154
150
155
def _reduce_metrics (
151
- self , batch_total_samples : float , batch_total_length : float
152
- ) -> tuple [float , float ]:
156
+ self , batch_total_samples : int , batch_total_length : int
157
+ ) -> tuple [int , int ]:
153
158
"""Reduce rank-specific metrics across devices."""
154
159
inputs_to_reduce = torch .tensor (
155
160
[batch_total_samples , batch_total_length ],
0 commit comments