# Ultralytics YOLO 🚀, AGPL-3.0 license

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = 'RGB'

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = None,
        pixel_std: List[float] = None,
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Args:
            image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow
                for efficient mask prediction.
            prompt_encoder (PromptEncoder): Encodes various types of input prompts.
            mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
            pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
            pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        if pixel_mean is None:
            pixel_mean = [123.675, 116.28, 103.53]
        if pixel_std is None:
            pixel_std = [58.395, 57.12, 57.375]
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
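        # Note: the default pixel_mean/pixel_std above are the standard ImageNet RGB
        # statistics scaled to the 0-255 pixel range expected by preprocess().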

    @property
    def device(self) -> Any:
        """Return the device on which the model's buffers and parameters reside."""
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
        SamPredictor is recommended over calling the model directly.

        Args:
            batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
                key can be excluded if it is not present.
                'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
                'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
                'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
                    transformed to the input frame of the model.
                'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
                'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
                    the model.
                'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
            multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a
                single mask.

        Returns:
            (list(dict)): A list over input images, where each element is a dictionary with the following keys.
                'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
                    input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
                'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
                'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be
                    passed as mask input to subsequent iterations of prediction.
        """
        # Preprocess all images and encode them in a single batched pass through the image encoder
        input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if 'point_coords' in image_record:
                points = (image_record['point_coords'], image_record['point_labels'])
            else:
                points = None
            # Encode the point, box, and mask prompts for this image
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get('boxes', None),
                masks=image_record.get('mask_inputs', None),
            )
            # Predict low-resolution mask logits and their estimated IoU quality scores
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            # Upscale the logits to the original image size and threshold to binary masks
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record['image'].shape[-2:],
                original_size=image_record['original_size'],
            )
            masks = masks > self.mask_threshold
            outputs.append({
                'masks': masks,
                'iou_predictions': iou_predictions,
                'low_res_logits': low_res_masks,
            })
        return outputs
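
    # Illustrative sketch of a `batched_input` entry for forward() (placeholder names, not part of this module):
    #   batched_input = [{
    #       'image': image_tensor,                             # 3xHxW float tensor, already transformed to the input frame
    #       'original_size': (orig_h, orig_w),                 # image size before that transformation
    #       'point_coords': torch.tensor([[[500.0, 375.0]]]),  # one prompt containing one point, shape 1x1x2
    #       'point_labels': torch.tensor([[1]]),               # 1 = foreground point, 0 = background point
    #   }]
    #   outputs = sam_model(batched_input, multimask_output=True)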

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Args:
            masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
            input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
            original_size (tuple(int, int)): The original image size before resizing for input to the model, in
                (H, W) format.

        Returns:
            (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
        """
        # Upscale to the square padded model resolution, crop away the padding, then resize to the original image size
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode='bilinear',
            align_corners=False,
        )
        masks = masks[..., :input_size[0], :input_size[1]]
        return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
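
    # Shape walk-through for postprocess_masks (illustrative sizes, assuming image_encoder.img_size == 1024):
    #   (B, C, 256, 256) low-res logits -> (B, C, 1024, 1024) after the first interpolation,
    #   cropped to input_size=(768, 1024) to drop padding, then resized to original_size=(1536, 2048).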

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad to a square of side image_encoder.img_size (padding is added on the bottom and right only)
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        return F.pad(x, (0, padw, 0, padh))
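

# Usage sketch (illustrative only; constructing the image encoder, prompt encoder, and mask decoder is
# handled by the package's model-construction helpers, not shown here):
#   sam = Sam(image_encoder=image_encoder, prompt_encoder=prompt_encoder, mask_decoder=mask_decoder)
#   with torch.no_grad():
#       outputs = sam(batched_input, multimask_output=True)
#   binary_masks = outputs[0]['masks']        # bool tensor at the original image resolution
#   quality = outputs[0]['iou_predictions']   # predicted mask quality per output mask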