diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py
index 41ce53af5..aaa9f7505 100644
--- a/sam2/sam2_image_predictor.py
+++ b/sam2/sam2_image_predictor.py
@@ -17,6 +17,15 @@
 from sam2.utils.transforms import SAM2Transforms
 
 
+class SAM2ImagePredictorState:
+    # Predictor state
+    _is_image_set = False
+    _features = None
+    _orig_hw = None
+    # Whether the predictor is set for single image or a batch of images
+    _is_batch = False
+
+
 class SAM2ImagePredictor:
     def __init__(
         self,
@@ -48,12 +57,8 @@ def __init__(
             max_sprinkle_area=max_sprinkle_area,
         )
 
-        # Predictor state
-        self._is_image_set = False
-        self._features = None
-        self._orig_hw = None
-        # Whether the predictor is set for single image or a batch of images
-        self._is_batch = False
+        # Predictor state, kept for backward compatibility with the stateful interface
+        self._state = SAM2ImagePredictorState()
 
         # Predictor config
         self.mask_threshold = mask_threshold
@@ -86,7 +91,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
     def set_image(
         self,
         image: Union[np.ndarray, Image],
-    ) -> None:
+    ) -> SAM2ImagePredictorState:
         """
         Calculates the image embeddings for the provided image, allowing
         masks to be predicted with the 'predict' method.
@@ -95,15 +100,18 @@ def set_image(
           image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
             with pixel values in [0, 255].
           image_format (str): The color format of the image, in ['RGB', 'BGR'].
+
+        Returns:
+          (SAM2ImagePredictorState): The precomputed state.
         """
-        self.reset_predictor()
+        state = SAM2ImagePredictorState()
         # Transform the image to the form expected by the model
         if isinstance(image, np.ndarray):
             logging.info("For numpy array image, we assume (HxWxC) format")
-            self._orig_hw = [image.shape[:2]]
+            state._orig_hw = [image.shape[:2]]
         elif isinstance(image, Image):
             w, h = image.size
-            self._orig_hw = [(h, w)]
+            state._orig_hw = [(h, w)]
         else:
             raise NotImplementedError("Image format not supported")
 
@@ -124,15 +132,17 @@ def set_image(
             feat.permute(1, 2, 0).view(1, -1, *feat_size)
             for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
         ][::-1]
-        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
-        self._is_image_set = True
+        state._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        state._is_image_set = True
+        self._state = state
         logging.info("Image embeddings computed.")
+        return state
 
     @torch.no_grad()
     def set_image_batch(
         self,
         image_list: List[Union[np.ndarray]],
-    ) -> None:
+    ) -> SAM2ImagePredictorState:
         """
         Calculates the image embeddings for the provided image batch, allowing
         masks to be predicted with the 'predict_batch' method.
@@ -140,15 +150,18 @@ def set_image_batch(
         Arguments:
           image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
             with pixel values in [0, 255].
+
+        Returns:
+          (SAM2ImagePredictorState): The precomputed state.
         """
-        self.reset_predictor()
+        state = SAM2ImagePredictorState()
         assert isinstance(image_list, list)
-        self._orig_hw = []
+        state._orig_hw = []
         for image in image_list:
             assert isinstance(
                 image, np.ndarray
             ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
-            self._orig_hw.append(image.shape[:2])
+            state._orig_hw.append(image.shape[:2])
         # Transform the image to the form expected by the model
         img_batch = self._transforms.forward_batch(image_list)
         img_batch = img_batch.to(self.device)
@@ -167,10 +180,12 @@ def set_image_batch(
             feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
             for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
         ][::-1]
-        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
-        self._is_image_set = True
-        self._is_batch = True
+        state._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        state._is_image_set = True
+        state._is_batch = True
+        self._state = state
         logging.info("Image embeddings computed.")
+        return state
 
     def predict_batch(
         self,
@@ -181,16 +196,23 @@ def predict_batch(
         multimask_output: bool = True,
         return_logits: bool = False,
         normalize_coords=True,
+        state: Optional[SAM2ImagePredictorState] = None,
     ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
         """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
         It returns a tuple of lists of masks, ious, and low_res_masks_logits.
         """
-        assert self._is_batch, "This function should only be used when in batched mode"
-        if not self._is_image_set:
+        if state is None:
+            logging.warning(
+                "Using the non-thread-safe stateful interface. Pass the state returned by set_image_batch()."
+            )
+            state = self._state  # compatibility for stateful interface
+
+        assert state._is_batch, "This function should only be used when in batched mode"
+        if not state._is_image_set:
             raise RuntimeError(
                 "An image must be set with .set_image_batch(...) before mask prediction."
             )
-        num_images = len(self._features["image_embed"])
+        num_images = len(state._features["image_embed"])
         all_masks = []
         all_ious = []
         all_low_res_masks = []
@@ -207,6 +229,7 @@ def predict_batch(
                 mask_input_batch[img_idx] if mask_input_batch is not None else None
             )
             mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+                state,
                 point_coords,
                 point_labels,
                 box,
@@ -222,6 +245,7 @@ def predict_batch(
                 multimask_output,
                 return_logits=return_logits,
                 img_idx=img_idx,
+                state=state,
             )
             masks_np = masks.squeeze(0).float().detach().cpu().numpy()
             iou_predictions_np = (
@@ -243,6 +267,7 @@ def predict(
         multimask_output: bool = True,
         return_logits: bool = False,
         normalize_coords=True,
+        state: Optional[SAM2ImagePredictorState] = None,
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -267,6 +292,7 @@ def predict(
           return_logits (bool): If true, returns un-thresholded masks logits
             instead of a binary mask.
           normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
+          state (SAM2ImagePredictorState): Pass the state returned by set_image() to use the stateless interface, allowing thread-safe execution.
 
         Returns:
           (np.ndarray): The output masks in CxHxW format, where C is the
@@ -277,7 +303,13 @@ def predict(
             of masks and H=W=256. These low resolution logits can be passed to
             a subsequent iteration as mask input.
         """
-        if not self._is_image_set:
+        if state is None:
+            logging.warning(
+                "Using the non-thread-safe stateful interface. Pass the state returned by set_image()."
+            )
+            state = self._state  # compatibility for stateful interface
+
+        if not state._is_image_set:
             raise RuntimeError(
                 "An image must be set with .set_image(...) before mask prediction."
             )
@@ -285,7 +317,7 @@ def predict(
 
         # Transform input prompts
         mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
-            point_coords, point_labels, box, mask_input, normalize_coords
+            state, point_coords, point_labels, box, mask_input, normalize_coords
         )
 
         masks, iou_predictions, low_res_masks = self._predict(
@@ -295,6 +327,7 @@ def predict(
             mask_input,
             multimask_output,
             return_logits=return_logits,
+            state=state,
         )
 
         masks_np = masks.squeeze(0).float().detach().cpu().numpy()
@@ -303,7 +336,14 @@ def predict(
         return masks_np, iou_predictions_np, low_res_masks_np
 
     def _prep_prompts(
-        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+        self,
+        state,
+        point_coords,
+        point_labels,
+        box,
+        mask_logits,
+        normalize_coords,
+        img_idx=-1,
     ):
         unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
         if point_coords is not None:
@@ -315,7 +355,9 @@ def _prep_prompts(
                 point_coords, dtype=torch.float, device=self.device
             )
             unnorm_coords = self._transforms.transform_coords(
-                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+                point_coords,
+                normalize=normalize_coords,
+                orig_hw=state._orig_hw[img_idx],
             )
             labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
             if len(unnorm_coords.shape) == 2:
@@ -323,7 +365,7 @@ def _prep_prompts(
         if box is not None:
             box = torch.as_tensor(box, dtype=torch.float, device=self.device)
             unnorm_box = self._transforms.transform_boxes(
-                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+                box, normalize=normalize_coords, orig_hw=state._orig_hw[img_idx]
             )  # Bx2x2
         if mask_logits is not None:
             mask_input = torch.as_tensor(
@@ -343,6 +385,7 @@ def _predict(
         multimask_output: bool = True,
         return_logits: bool = False,
         img_idx: int = -1,
+        state: Optional[SAM2ImagePredictorState] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Predict masks for the given input prompts, using the currently set image.
@@ -369,6 +412,7 @@ def _predict(
             input prompts, multimask_output=False can give better results.
           return_logits (bool): If true, returns un-thresholded masks logits
             instead of a binary mask.
+          state (SAM2ImagePredictorState): Pass the state returned by set_image() to use the stateless interface, allowing thread-safe execution.
 
         Returns:
           (torch.Tensor): The output masks in BxCxHxW format, where C is the
@@ -379,7 +423,9 @@ def _predict(
             of masks and H=W=256. These low res logits can be passed to
             a subsequent iteration as mask input.
         """
-        if not self._is_image_set:
+        if state is None:
+            state = self._state  # compatibility for stateful interface
+        if not state._is_image_set:
             raise RuntimeError(
                 "An image must be set with .set_image(...) before mask prediction."
             )
@@ -415,10 +461,10 @@ def _predict(
         )  # multi object prediction
         high_res_features = [
             feat_level[img_idx].unsqueeze(0)
-            for feat_level in self._features["high_res_feats"]
+            for feat_level in state._features["high_res_feats"]
         ]
         low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
-            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+            image_embeddings=state._features["image_embed"][img_idx].unsqueeze(0),
             image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
             dense_prompt_embeddings=dense_embeddings,
@@ -429,7 +475,7 @@ def _predict(
 
         # Upscale the masks to the original image resolution
         masks = self._transforms.postprocess_masks(
-            low_res_masks, self._orig_hw[img_idx]
+            low_res_masks, state._orig_hw[img_idx]
         )
         low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
         if not return_logits:
@@ -437,20 +483,26 @@ def _predict(
 
         return masks, iou_predictions, low_res_masks
 
-    def get_image_embedding(self) -> torch.Tensor:
+    def get_image_embedding(
+        self,
+        state: Optional[SAM2ImagePredictorState] = None,
+    ) -> torch.Tensor:
         """
         Returns the image embeddings for the currently set image, with
         shape 1xCxHxW, where C is the embedding dimension and (H,W) are
         the embedding spatial dimension of SAM (typically C=256, H=W=64).
         """
-        if not self._is_image_set:
+        if state is None:
+            state = self._state  # compatibility for stateful interface
+
+        if not state._is_image_set:
             raise RuntimeError(
                 "An image must be set with .set_image(...) to generate an embedding."
             )
         assert (
-            self._features is not None
+            state._features is not None
         ), "Features must exist if an image has been set."
-        return self._features["image_embed"]
+        return state._features["image_embed"]
 
     @property
     def device(self) -> torch.device:
@@ -460,7 +512,7 @@ def reset_predictor(self) -> None:
         """
         Resets the image embeddings and other state variables.
         """
-        self._is_image_set = False
-        self._features = None
-        self._orig_hw = None
-        self._is_batch = False
+        logging.warning(
+            "Using the non-thread-safe stateful interface. See the return value of set_image()."
+        )
+        self._state = SAM2ImagePredictorState()
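
Usage example (illustrative, not part of the patch): a minimal sketch of the stateless interface this diff introduces. The checkpoint id and the thread-pool setup are assumptions for demonstration only; each call stays consistent because every thread holds its own SAM2ImagePredictorState rather than relying on the shared self._state, assuming the underlying model forward pass itself is safe to invoke concurrently.

    from concurrent.futures import ThreadPoolExecutor

    import numpy as np

    from sam2.sam2_image_predictor import SAM2ImagePredictor

    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

    def segment(image: np.ndarray, point: np.ndarray) -> np.ndarray:
        # set_image() now returns the precomputed state in addition to storing
        # it; passing it back into predict() bypasses the shared self._state.
        state = predictor.set_image(image)
        masks, ious, low_res_logits = predictor.predict(
            point_coords=point[None, :],   # one (X, Y) point prompt
            point_labels=np.array([1]),    # 1 = foreground
            state=state,
        )
        return masks

    # Dummy inputs, just to exercise the API sketch.
    images = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]
    points = [np.array([320, 240]) for _ in images]

    with ThreadPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(segment, images, points))

Calling predict() or predict_batch() without the state argument still works: it falls back to the predictor's internal state and logs a warning, so existing stateful callers are unaffected.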