sam.py

# Ultralytics YOLO 🚀, AGPL-3.0 license

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = 'RGB'

    def __init__(self,
                 image_encoder: ImageEncoderViT,
                 prompt_encoder: PromptEncoder,
                 mask_decoder: MaskDecoder,
                 pixel_mean: List[float] = None,
                 pixel_std: List[float] = None) -> None:
  21. """
  22. SAM predicts object masks from an image and input prompts.
  23. Args:
  24. image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
  25. efficient mask prediction.
  26. prompt_encoder (PromptEncoder): Encodes various types of input prompts.
  27. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
  28. pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
  29. pixel_std (list(float)): Std values for normalizing pixels in the input image.
  30. """
        if pixel_mean is None:
            pixel_mean = [123.675, 116.28, 103.53]
        if pixel_std is None:
            pixel_std = [58.395, 57.12, 57.375]
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
  50. """
  51. Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
  52. SamPredictor is recommended over calling the model directly.
  53. Args:
  54. batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
  55. key can be excluded if it is not present.
  56. 'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
  57. 'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
  58. 'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
  59. transformed to the input frame of the model.
  60. 'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
  61. 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
  62. the model.
  63. 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
  64. multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
  65. mask.
  66. Returns:
  67. (list(dict)): A list over input images, where each element is as dictionary with the following keys.
  68. 'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
  69. input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
  70. 'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
  71. 'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
  72. as mask input to subsequent iterations of prediction.
  73. """
        input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)
        outputs = []
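        # Embed the prompts and decode masks for each image individually, since every image may carry its own prompts.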
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if 'point_coords' in image_record:
                points = (image_record['point_coords'], image_record['point_labels'])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get('boxes', None),
                masks=image_record.get('mask_inputs', None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record['image'].shape[-2:],
                original_size=image_record['original_size'],
            )
            masks = masks > self.mask_threshold
            outputs.append({
                'masks': masks,
                'iou_predictions': iou_predictions,
                'low_res_logits': low_res_masks, })
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
  111. """
  112. Remove padding and upscale masks to the original image size.
  113. Args:
  114. masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
  115. input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
  116. original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
  117. Returns:
  118. (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
  119. """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode='bilinear',
            align_corners=False,
        )
        masks = masks[..., :input_size[0], :input_size[1]]
        return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        return F.pad(x, (0, padw, 0, padh))
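
# Example usage (illustrative sketch, not part of the module): assumes a `sam` instance has already
# been built elsewhere with compatible encoder/decoder components, and an image tensor already
# resized to the model's input frame (longest side equal to `sam.image_encoder.img_size`).
# The shapes and the point prompt below are made-up placeholders.
#
#   image = torch.randint(0, 256, (3, 1024, 683), dtype=torch.float32)    # 3xHxW, 0-255 pixels
#   batched_input = [{
#       'image': image,                                     # preprocess() will normalize and pad it
#       'original_size': (1500, 1000),                      # (H, W) before any resizing
#       'point_coords': torch.tensor([[[512.0, 341.0]]]),   # BxNx2, in the input frame
#       'point_labels': torch.tensor([[1]]),                # BxN, 1 = foreground point
#   }]
#   outputs = sam(batched_input, multimask_output=True)
#   masks = outputs[0]['masks']                             # boolean masks at the original image size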