123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- import numpy as np
- import torch
- import torch.nn.functional as F
- import torchvision
- from ultralytics.data.augment import LetterBox
- from ultralytics.engine.predictor import BasePredictor
- from ultralytics.engine.results import Results
- from ultralytics.utils import DEFAULT_CFG, ops
- from ultralytics.utils.torch_utils import select_device
- from .amg import (batch_iterator, batched_mask_to_box, build_all_layer_point_grids, calculate_stability_score,
- generate_crop_boxes, is_box_near_crop_edge, remove_small_regions, uncrop_boxes_xyxy, uncrop_masks)
- from .build import build_sam
class Predictor(BasePredictor):
    """Predictor for the Segment Anything Model (SAM).

    Runs prompt-driven segmentation (`prompt_inference`) when boxes, points or masks are supplied and
    whole-image segmentation (`generate`) otherwise. An image can be encoded once with `set_image` and
    prompts can be queued with `set_prompts`.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialize the predictor with SAM-specific configuration.

        Args:
            cfg: Base configuration forwarded to `BasePredictor`.
            overrides (dict | None): Extra configuration values. `task`, `mode` and `imgsz` are always forced to
                'segment', 'predict' and 1024. The dict passed by the caller is not modified.
            _callbacks: Callbacks forwarded to `BasePredictor`.
        """
        # Merge into a fresh dict so the caller's `overrides` object is never mutated.
        overrides = {**(overrides or {}), 'task': 'segment', 'mode': 'predict', 'imgsz': 1024}
        super().__init__(cfg, overrides, _callbacks)
        # SAM needs retina_masks=True, or the results would be a mess.
        self.args.retina_masks = True
        # State filled by `set_image`: cached preprocessed image and its encoder features.
        self.im = None
        self.features = None
        # State filled by `set_prompts`.
        self.prompts = {}
        # True while a segment-everything pass is in progress; reset by `postprocess`.
        self.segment_all = False
- def preprocess(self, im):
- """Prepares input image before inference.
- Args:
- im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
- """
- if self.im is not None:
- return self.im
- not_tensor = not isinstance(im, torch.Tensor)
- if not_tensor:
- im = np.stack(self.pre_transform(im))
- im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
- im = np.ascontiguousarray(im) # contiguous
- im = torch.from_numpy(im)
- img = im.to(self.device)
- img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
- if not_tensor:
- img = (img - self.mean) / self.std
- return img
- def pre_transform(self, im):
- """
- Pre-transform input image before inference.
- Args:
- im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
- Returns:
- (list): A list of transformed images.
- """
- assert len(im) == 1, 'SAM model has not supported batch inference yet!'
- return [LetterBox(self.args.imgsz, auto=False, center=False)(image=x) for x in im]
- def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
- """
- Predict masks for the given input prompts, using the currently set image.
- Args:
- im (torch.Tensor): The preprocessed image, (N, C, H, W).
- bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
- points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
- labels (np.ndarray | List, None): (N, ), labels for the point prompts.
- 1 indicates a foreground point and 0 indicates a background point.
- masks (np.ndarray, None): A low resolution mask input to the model, typically
- coming from a previous prediction iteration. Has form (N, H, W), where
- for SAM, H=W=256.
- multimask_output (bool): If true, the model will return three masks.
- For ambiguous input prompts (such as a single click), this will often
- produce better masks than a single prediction. If only a single
- mask is needed, the model's predicted quality score can be used
- to select the best mask. For non-ambiguous prompts, such as multiple
- input prompts, multimask_output=False can give better results.
- Returns:
- (np.ndarray): The output masks in CxHxW format, where C is the
- number of masks, and (H, W) is the original image size.
- (np.ndarray): An array of length C containing the model's
- predictions for the quality of each mask.
- (np.ndarray): An array of shape CxHxW, where C is the number
- of masks and H=W=256. These low resolution logits can be passed to
- a subsequent iteration as mask input.
- """
- # Get prompts from self.prompts first
- bboxes = self.prompts.pop('bboxes', bboxes)
- points = self.prompts.pop('points', points)
- masks = self.prompts.pop('masks', masks)
- if all(i is None for i in [bboxes, points, masks]):
- return self.generate(im, *args, **kwargs)
- return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
- def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
- """
- Predict masks for the given input prompts, using the currently set image.
- Args:
- im (torch.Tensor): The preprocessed image, (N, C, H, W).
- bboxes (np.ndarray | List, None): (N, 4), in XYXY format.
- points (np.ndarray | List, None): (N, 2), Each point is in (X,Y) in pixels.
- labels (np.ndarray | List, None): (N, ), labels for the point prompts.
- 1 indicates a foreground point and 0 indicates a background point.
- masks (np.ndarray, None): A low resolution mask input to the model, typically
- coming from a previous prediction iteration. Has form (N, H, W), where
- for SAM, H=W=256.
- multimask_output (bool): If true, the model will return three masks.
- For ambiguous input prompts (such as a single click), this will often
- produce better masks than a single prediction. If only a single
- mask is needed, the model's predicted quality score can be used
- to select the best mask. For non-ambiguous prompts, such as multiple
- input prompts, multimask_output=False can give better results.
- Returns:
- (np.ndarray): The output masks in CxHxW format, where C is the
- number of masks, and (H, W) is the original image size.
- (np.ndarray): An array of length C containing the model's
- predictions for the quality of each mask.
- (np.ndarray): An array of shape CxHxW, where C is the number
- of masks and H=W=256. These low resolution logits can be passed to
- a subsequent iteration as mask input.
- """
- features = self.model.image_encoder(im) if self.features is None else self.features
- src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
- r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
- # Transform input prompts
- if points is not None:
- points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
- points = points[None] if points.ndim == 1 else points
- # Assuming labels are all positive if users don't pass labels.
- if labels is None:
- labels = np.ones(points.shape[0])
- labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
- points *= r
- # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
- points, labels = points[:, None, :], labels[:, None]
- if bboxes is not None:
- bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
- bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
- bboxes *= r
- if masks is not None:
- masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device)
- masks = masks[:, None, :, :]
- points = (points, labels) if points is not None else None
- # Embed prompts
- sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
- points=points,
- boxes=bboxes,
- masks=masks,
- )
- # Predict masks
- pred_masks, pred_scores = self.model.mask_decoder(
- image_embeddings=features,
- image_pe=self.model.prompt_encoder.get_dense_pe(),
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- )
- # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
- # `d` could be 1 or 3 depends on `multimask_output`.
- return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
- def generate(self,
- im,
- crop_n_layers=0,
- crop_overlap_ratio=512 / 1500,
- crop_downscale_factor=1,
- point_grids=None,
- points_stride=32,
- points_batch_size=64,
- conf_thres=0.88,
- stability_score_thresh=0.95,
- stability_score_offset=0.95,
- crop_nms_thresh=0.7):
- """Segment the whole image.
- Args:
- im (torch.Tensor): The preprocessed image, (N, C, H, W).
- crop_n_layers (int): If >0, mask prediction will be run again on
- crops of the image. Sets the number of layers to run, where each
- layer has 2**i_layer number of image crops.
- crop_overlap_ratio (float): Sets the degree to which crops overlap.
- In the first crop layer, crops will overlap by this fraction of
- the image length. Later layers with more crops scale down this overlap.
- crop_downscale_factor (int): The number of points-per-side
- sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
- point_grids (list(np.ndarray), None): A list over explicit grids
- of points used for sampling, normalized to [0,1]. The nth grid in the
- list is used in the nth crop layer. Exclusive with points_per_side.
- points_stride (int, None): The number of points to be sampled
- along one side of the image. The total number of points is
- points_per_side**2. If None, 'point_grids' must provide explicit
- point sampling.
- points_batch_size (int): Sets the number of points run simultaneously
- by the model. Higher numbers may be faster but use more GPU memory.
- conf_thres (float): A filtering threshold in [0,1], using the
- model's predicted mask quality.
- stability_score_thresh (float): A filtering threshold in [0,1], using
- the stability of the mask under changes to the cutoff used to binarize
- the model's mask predictions.
- stability_score_offset (float): The amount to shift the cutoff when
- calculated the stability score.
- crop_nms_thresh (float): The box IoU cutoff used by non-maximal
- suppression to filter duplicate masks between different crops.
- """
- self.segment_all = True
- ih, iw = im.shape[2:]
- crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio)
- if point_grids is None:
- point_grids = build_all_layer_point_grids(
- points_stride,
- crop_n_layers,
- crop_downscale_factor,
- )
- pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], []
- for crop_region, layer_idx in zip(crop_regions, layer_idxs):
- x1, y1, x2, y2 = crop_region
- w, h = x2 - x1, y2 - y1
- area = torch.tensor(w * h, device=im.device)
- points_scale = np.array([[w, h]]) # w, h
- # Crop image and interpolate to input size
- crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode='bilinear', align_corners=False)
- # (num_points, 2)
- points_for_image = point_grids[layer_idx] * points_scale
- crop_masks, crop_scores, crop_bboxes = [], [], []
- for (points, ) in batch_iterator(points_batch_size, points_for_image):
- pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True)
- # Interpolate predicted masks to input size
- pred_mask = F.interpolate(pred_mask[None], (h, w), mode='bilinear', align_corners=False)[0]
- idx = pred_score > conf_thres
- pred_mask, pred_score = pred_mask[idx], pred_score[idx]
- stability_score = calculate_stability_score(pred_mask, self.model.mask_threshold,
- stability_score_offset)
- idx = stability_score > stability_score_thresh
- pred_mask, pred_score = pred_mask[idx], pred_score[idx]
- # Bool type is much more memory-efficient.
- pred_mask = pred_mask > self.model.mask_threshold
- # (N, 4)
- pred_bbox = batched_mask_to_box(pred_mask).float()
- keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih])
- if not torch.all(keep_mask):
- pred_bbox = pred_bbox[keep_mask]
- pred_mask = pred_mask[keep_mask]
- pred_score = pred_score[keep_mask]
- crop_masks.append(pred_mask)
- crop_bboxes.append(pred_bbox)
- crop_scores.append(pred_score)
- # Do nms within this crop
- crop_masks = torch.cat(crop_masks)
- crop_bboxes = torch.cat(crop_bboxes)
- crop_scores = torch.cat(crop_scores)
- keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS
- crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region)
- crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw)
- crop_scores = crop_scores[keep]
- pred_masks.append(crop_masks)
- pred_bboxes.append(crop_bboxes)
- pred_scores.append(crop_scores)
- region_areas.append(area.expand(len(crop_masks)))
- pred_masks = torch.cat(pred_masks)
- pred_bboxes = torch.cat(pred_bboxes)
- pred_scores = torch.cat(pred_scores)
- region_areas = torch.cat(region_areas)
- # Remove duplicate masks between crops
- if len(crop_regions) > 1:
- scores = 1 / region_areas
- keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh)
- pred_masks = pred_masks[keep]
- pred_bboxes = pred_bboxes[keep]
- pred_scores = pred_scores[keep]
- return pred_masks, pred_scores, pred_bboxes
- def setup_model(self, model, verbose=True):
- """Set up YOLO model with specified thresholds and device."""
- device = select_device(self.args.device, verbose=verbose)
- if model is None:
- model = build_sam(self.args.model)
- model.eval()
- self.model = model.to(device)
- self.device = device
- self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device)
- self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device)
- # TODO: Temporary settings for compatibility
- self.model.pt = False
- self.model.triton = False
- self.model.stride = 32
- self.model.fp16 = False
- self.done_warmup = True
- def postprocess(self, preds, img, orig_imgs):
- """Post-processes inference output predictions to create detection masks for objects."""
- # (N, 1, H, W), (N, 1)
- pred_masks, pred_scores = preds[:2]
- pred_bboxes = preds[2] if self.segment_all else None
- names = dict(enumerate(str(i) for i in range(len(pred_masks))))
- results = []
- is_list = isinstance(orig_imgs, list) # input images are a list, not a torch.Tensor
- for i, masks in enumerate([pred_masks]):
- orig_img = orig_imgs[i] if is_list else orig_imgs
- if pred_bboxes is not None:
- pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
- cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
- pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
- masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
- masks = masks > self.model.mask_threshold # to bool
- img_path = self.batch[0][i]
- results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
- # Reset segment-all mode.
- self.segment_all = False
- return results
- def setup_source(self, source):
- """Sets up source and inference mode."""
- if source is not None:
- super().setup_source(source)
- def set_image(self, image):
- """Set image in advance.
- Args:
- image (str | np.ndarray): image file path or np.ndarray image by cv2.
- """
- if self.model is None:
- model = build_sam(self.args.model)
- self.setup_model(model)
- self.setup_source(image)
- assert len(self.dataset) == 1, '`set_image` only supports setting one image!'
- for batch in self.dataset:
- im = self.preprocess(batch[1])
- self.features = self.model.image_encoder(im)
- self.im = im
- break
- def set_prompts(self, prompts):
- """Set prompts in advance."""
- self.prompts = prompts
- def reset_image(self):
- self.im = None
- self.features = None
- @staticmethod
- def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
- """
- Removes small disconnected regions and holes in masks, then reruns
- box NMS to remove any new duplicates. Requires open-cv as a dependency.
- Args:
- masks (torch.Tensor): Masks, (N, H, W).
- min_area (int): Minimum area threshold.
- nms_thresh (float): NMS threshold.
- """
- if len(masks) == 0:
- return masks
- # Filter small disconnected regions and holes
- new_masks = []
- scores = []
- for mask in masks:
- mask = mask.cpu().numpy()
- mask, changed = remove_small_regions(mask, min_area, mode='holes')
- unchanged = not changed
- mask, changed = remove_small_regions(mask, min_area, mode='islands')
- unchanged = unchanged and not changed
- new_masks.append(torch.as_tensor(mask).unsqueeze(0))
- # Give score=0 to changed masks and score=1 to unchanged masks
- # so NMS will prefer ones that didn't need postprocessing
- scores.append(float(unchanged))
- # Recalculate boxes and remove any new duplicates
- new_masks = torch.cat(new_masks, dim=0)
- boxes = batched_mask_to_box(new_masks)
- keep = torchvision.ops.nms(
- boxes.float(),
- torch.as_tensor(scores),
- nms_thresh,
- )
- # Only recalculate masks for masks that have changed
- for i in keep:
- if scores[i] == 0.0:
- masks[i] = new_masks[i]
- return masks[keep]
|