132 changes: 93 additions & 39 deletions sam2/sam2_image_predictor.py
@@ -17,6 +17,15 @@
from sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictorState:
# Predictor state
_is_image_set = False
_features = None
_orig_hw = None
# Whether the predictor is set for single image or a batch of images
_is_batch = False
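
The state class is a plain attribute holder: a fresh instance is "empty" until set_image() or set_image_batch() populates it. A minimal sketch of that contract, assuming only this diff's class definition:

```python
# A new state carries no image until one of the set_image* methods fills it in.
s = SAM2ImagePredictorState()
assert s._is_image_set is False
assert s._features is None and s._orig_hw is None
assert s._is_batch is False
```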


class SAM2ImagePredictor:
def __init__(
self,
@@ -48,12 +57,8 @@ def __init__(
max_sprinkle_area=max_sprinkle_area,
)

# Predictor state
self._is_image_set = False
self._features = None
self._orig_hw = None
# Whether the predictor is set for single image or a batch of images
self._is_batch = False
# Predictor state, kept for backward compatibility with the stateful interface
self._state = SAM2ImagePredictorState()

# Predictor config
self.mask_threshold = mask_threshold
Expand Down Expand Up @@ -86,7 +91,7 @@ def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
def set_image(
self,
image: Union[np.ndarray, Image],
) -> None:
) -> SAM2ImagePredictorState:
"""
Calculates the image embeddings for the provided image, allowing
masks to be predicted with the 'predict' method.
@@ -95,15 +100,18 @@ def set_image(
image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
with pixel values in [0, 255].

Returns:
(SAM2ImagePredictorState): The precomputed state.
"""
self.reset_predictor()
state = SAM2ImagePredictorState()
# Transform the image to the form expected by the model
if isinstance(image, np.ndarray):
logging.info("For numpy array image, we assume (HxWxC) format")
self._orig_hw = [image.shape[:2]]
state._orig_hw = [image.shape[:2]]
elif isinstance(image, Image):
w, h = image.size
self._orig_hw = [(h, w)]
state._orig_hw = [(h, w)]
else:
raise NotImplementedError("Image format not supported")

@@ -124,31 +132,36 @@ def set_image_batch(
feat.permute(1, 2, 0).view(1, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
self._is_image_set = True
state._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
state._is_image_set = True
self._state = state
logging.info("Image embeddings computed.")
return state
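
A usage sketch of the new return value; the predictor and image variables below are placeholders, not part of this diff. Because each call returns its own state, an earlier state stays valid even after another image is set:

```python
import numpy as np

state_a = predictor.set_image(image_a)  # HWC np.ndarray or PIL Image, RGB
state_b = predictor.set_image(image_b)  # replaces predictor._state; state_a is untouched

# Predict against the first image even though the second was set last.
masks, ious, low_res = predictor.predict(
    point_coords=np.array([[512, 384]]),
    point_labels=np.array([1]),
    state=state_a,
)
```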

@torch.no_grad()
def set_image_batch(
self,
image_list: List[np.ndarray],
) -> None:
) -> SAM2ImagePredictorState:
"""
Calculates the image embeddings for the provided image batch, allowing
masks to be predicted with the 'predict_batch' method.

Arguments:
image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
with pixel values in [0, 255].

Returns:
(SAM2ImagePredictorState): The precomputed state.
"""
self.reset_predictor()
state = SAM2ImagePredictorState()
assert isinstance(image_list, list)
self._orig_hw = []
state._orig_hw = []
for image in image_list:
assert isinstance(
image, np.ndarray
), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
self._orig_hw.append(image.shape[:2])
state._orig_hw.append(image.shape[:2])
# Transform the image to the form expected by the model
img_batch = self._transforms.forward_batch(image_list)
img_batch = img_batch.to(self.device)
@@ -167,10 +180,12 @@ def set_image_batch(
feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
self._is_image_set = True
self._is_batch = True
state._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
state._is_image_set = True
state._is_batch = True
self._state = state
logging.info("Image embeddings computed.")
return state
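
The batched variant follows the same pattern; a sketch with placeholder images and prompts:

```python
batch_state = predictor.set_image_batch([image_a, image_b])  # list of HWC np.ndarrays
masks_list, ious_list, low_res_list = predictor.predict_batch(
    point_coords_batch=[np.array([[100, 200]]), np.array([[50, 60]])],
    point_labels_batch=[np.array([1]), np.array([1])],
    state=batch_state,
)
```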

def predict_batch(
self,
Expand All @@ -181,16 +196,23 @@ def predict_batch(
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords=True,
state: Optional[SAM2ImagePredictorState] = None,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
It returns a tuple of lists of masks, ious, and low_res_masks_logits.
"""
assert self._is_batch, "This function should only be used when in batched mode"
if not self._is_image_set:
if state is None:
logging.warning(
"Using non thread-safe stateful interface. Pass state, which is retunred by set_image()"
)
state = self._state # compatibility for stateful interface

assert state._is_batch, "This function should only be used when in batched mode"
if not state._is_image_set:
raise RuntimeError(
"An image must be set with .set_image_batch(...) before mask prediction."
)
num_images = len(self._features["image_embed"])
num_images = len(state._features["image_embed"])
all_masks = []
all_ious = []
all_low_res_masks = []
@@ -207,6 +229,7 @@ def predict_batch(
mask_input_batch[img_idx] if mask_input_batch is not None else None
)
mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
state,
point_coords,
point_labels,
box,
@@ -222,6 +245,7 @@
multimask_output,
return_logits=return_logits,
img_idx=img_idx,
state=state,
)
masks_np = masks.squeeze(0).float().detach().cpu().numpy()
iou_predictions_np = (
@@ -243,6 +267,7 @@ def predict(
multimask_output: bool = True,
return_logits: bool = False,
normalize_coords=True,
state: Optional[SAM2ImagePredictorState] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Predict masks for the given input prompts, using the provided state or the currently set image.
@@ -267,6 +292,7 @@ def predict(
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
state (SAM2ImagePredictorState): Pass the state returned by set_image() to use the stateless interface, allowing thread-safe execution.

Returns:
(np.ndarray): The output masks in CxHxW format, where C is the
@@ -277,15 +303,21 @@
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
if not self._is_image_set:
if state is None:
logging.warning(
"Using non thread-safe stateful interface. Pass state, which is retunred by set_image()"
)
state = self._state # compatibility for stateful interface

if not state._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)

# Transform input prompts

mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
point_coords, point_labels, box, mask_input, normalize_coords
state, point_coords, point_labels, box, mask_input, normalize_coords
)

masks, iou_predictions, low_res_masks = self._predict(
@@ -295,6 +327,7 @@
mask_input,
multimask_output,
return_logits=return_logits,
state=state,
)

masks_np = masks.squeeze(0).float().detach().cpu().numpy()
@@ -303,7 +336,14 @@
return masks_np, iou_predictions_np, low_res_masks_np
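
Since predict() now reads everything image-specific from the passed-in state, one predictor can serve several threads as long as each thread keeps its own state. A sketch of that pattern; the task list and executor sizing are illustrative only, and the benign overwrite of predictor._state is ignored by the stateless path:

```python
from concurrent.futures import ThreadPoolExecutor

def segment(image, points, labels):
    # Each task computes and carries its own state; no shared mutable fields are read.
    state = predictor.set_image(image)
    return predictor.predict(point_coords=points, point_labels=labels, state=state)

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(segment, *task) for task in tasks]
    results = [f.result() for f in futures]
```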

def _prep_prompts(
self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
self,
state,
point_coords,
point_labels,
box,
mask_logits,
normalize_coords,
img_idx=-1,
):

unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
@@ -315,15 +355,17 @@ def _prep_prompts(
point_coords, dtype=torch.float, device=self.device
)
unnorm_coords = self._transforms.transform_coords(
point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
point_coords,
normalize=normalize_coords,
orig_hw=state._orig_hw[img_idx],
)
labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
if len(unnorm_coords.shape) == 2:
unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
if box is not None:
box = torch.as_tensor(box, dtype=torch.float, device=self.device)
unnorm_box = self._transforms.transform_boxes(
box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
box, normalize=normalize_coords, orig_hw=state._orig_hw[img_idx]
) # Bx2x2
if mask_logits is not None:
mask_input = torch.as_tensor(
@@ -343,13 +385,16 @@ def _predict(
multimask_output: bool = True,
return_logits: bool = False,
img_idx: int = -1,
state: Optional[SAM2ImagePredictorState] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the provided state or the currently set image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using SAM2Transforms.

Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
@@ -369,6 +414,7 @@
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
state (SAM2ImagePredictorState): Pass the state returned by set_image() to use the stateless interface, allowing thread-safe execution.

Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
@@ -379,7 +425,9 @@
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
if not self._is_image_set:
if state is None:
state = self._state # compatibility for stateful interface
if not state._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) before mask prediction."
)
@@ -415,10 +463,10 @@ def _predict(
) # multi object prediction
high_res_features = [
feat_level[img_idx].unsqueeze(0)
for feat_level in self._features["high_res_feats"]
for feat_level in state._features["high_res_feats"]
]
low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
image_embeddings=state._features["image_embed"][img_idx].unsqueeze(0),
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
@@ -429,28 +477,34 @@

# Upscale the masks to the original image resolution
masks = self._transforms.postprocess_masks(
low_res_masks, self._orig_hw[img_idx]
low_res_masks, state._orig_hw[img_idx]
)
low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
if not return_logits:
masks = masks > self.mask_threshold

return masks, iou_predictions, low_res_masks

def get_image_embedding(self) -> torch.Tensor:
def get_image_embedding(
self,
state: Optional[SAM2ImagePredictorState] = None,
) -> torch.Tensor:
"""
Returns the image embeddings for the provided state or the currently set image, with
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
the embedding spatial dimension of SAM (typically C=256, H=W=64).
"""
if not self._is_image_set:
if state is None:
state = self._state # compatibility for stateful interface

if not state._is_image_set:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
assert (
self._features is not None
state._features is not None
), "Features must exist if an image has been set."
return self._features["image_embed"]
return state._features["image_embed"]
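
The accessor keeps the same optional-state convention; a short sketch with placeholder inputs:

```python
state = predictor.set_image(image)
embedding = predictor.get_image_embedding(state=state)  # torch.Tensor of shape 1xCxHxW, e.g. 1x256x64x64
```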

@property
def device(self) -> torch.device:
@@ -460,7 +514,7 @@ def reset_predictor(self) -> None:
"""
Resets the image embeddings and other state variables.
"""
self._is_image_set = False
self._features = None
self._orig_hw = None
self._is_batch = False
logging.warning(
"Using non thread-safe stateful interface. See return value of set_image()"
)
self._state = SAM2ImagePredictorState()