@@ -124,7 +124,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
124124 rec_thresholds = rec_thresholds ,
125125 class_mean = None ,
126126 )
127- precision = torch .double if torch .device (device ) != torch . device ( "mps" ) else torch .float32
127+ precision = torch .double if torch .device (device ). type != "mps" else torch .float32
128128 self .rec_thresholds = self .rec_thresholds .to (device = device , dtype = precision )
129129
130130 @reinit__is_reduced
@@ -207,12 +207,12 @@ def _compute_recall_and_precision(
207207 indices = torch .argsort (scores , dim = - 1 , stable = True , descending = True )
208208 tp = TP [..., indices ]
209209 tp_summation = tp .cumsum (dim = - 1 )
210- if tp_summation .device != torch . device ( "mps" ) :
210+ if tp_summation .device . type != "mps" :
211211 tp_summation = tp_summation .double ()
212212
213213 fp = FP [..., indices ]
214214 fp_summation = fp .cumsum (dim = - 1 )
215- if fp_summation .device != torch . device ( "mps" ) :
215+ if fp_summation .device . type != "mps" :
216216 fp_summation = fp_summation .double ()
217217
218218 recall = tp_summation / y_true_count
@@ -342,7 +342,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
342342 )
343343
344344 scores = pred ["scores" ][max_best_detections_index ]
345- if self ._device == torch . device ( "mps" ) and scores .dtype == torch .double :
345+ if self ._device . type == "mps" and scores .dtype == torch .double :
346346 scores = scores .to (dtype = torch .float32 )
347347 self ._scores .append (scores .to (self ._device ))
348348 self ._y_pred_labels .append (pred_labels .to (dtype = torch .int , device = self ._device ))
@@ -352,7 +352,7 @@ def _compute(self) -> torch.Tensor:
352352 pred_labels = _cat_and_agg_tensors (self ._y_pred_labels , cast (Tuple [int ], ()), torch .int , self ._device )
353353 TP = _cat_and_agg_tensors (self ._tps , (len (self ._iou_thresholds ),), torch .uint8 , self ._device )
354354 FP = _cat_and_agg_tensors (self ._fps , (len (self ._iou_thresholds ),), torch .uint8 , self ._device )
355- fp_precision = torch .double if self ._device != torch . device ( "mps" ) else torch .float32
355+ fp_precision = torch .double if self ._device . type != "mps" else torch .float32
356356 scores = _cat_and_agg_tensors (self ._scores , cast (Tuple [int ], ()), fp_precision , self ._device )
357357
358358 average_precisions_recalls = - torch .ones (
0 commit comments