# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from ultralytics.utils import LOGGER

class FastSAMPrompt:
    """Segmentation prompting for FastSAM results: box, point, text, and everything prompts."""

    def __init__(self, img_path, results, device='cuda') -> None:
        self.device = device
        self.results = results
        self.img_path = str(img_path)
        self.ori_img = cv2.imread(self.img_path)

        # Import and assign clip (used by text prompts); install it on demand if missing
        try:
            import clip
        except ImportError:
            from ultralytics.utils.checks import check_requirements
            check_requirements('git+https://github.com/openai/CLIP.git')
            import clip
        self.clip = clip

    @staticmethod
    def _segment_image(image, bbox):
        image_array = np.array(image)
        segmented_image_array = np.zeros_like(image_array)
        x1, y1, x2, y2 = bbox
        segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
        segmented_image = Image.fromarray(segmented_image_array)
        black_image = Image.new('RGB', image.size, (255, 255, 255))
        transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
        transparency_mask[y1:y2, x1:x2] = 255
        transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
        black_image.paste(segmented_image, mask=transparency_mask_image)
        return black_image

    @staticmethod
    def _format_results(result, filter=0):
        annotations = []
        n = len(result.masks.data)
        for i in range(n):
            mask = result.masks.data[i] == 1.0
            if torch.sum(mask) < filter:
                continue
            annotation = {
                'id': i,
                'segmentation': mask.cpu().numpy(),
                'bbox': result.boxes.data[i],
                'score': result.boxes.conf[i]}
            annotation['area'] = annotation['segmentation'].sum()
            annotations.append(annotation)
        return annotations

    @staticmethod
    def filter_masks(annotations):  # filter out masks that mostly overlap a larger mask
        annotations.sort(key=lambda x: x['area'], reverse=True)
        to_remove = set()
        for i in range(len(annotations)):
            a = annotations[i]
            for j in range(i + 1, len(annotations)):
                b = annotations[j]
                if i != j and j not in to_remove and b['area'] < a['area'] and \
                        (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
                    to_remove.add(j)
        return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove

    @staticmethod
    def _get_bbox_from_mask(mask):
        mask = mask.astype(np.uint8)
        contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x1, y1, w, h = cv2.boundingRect(contours[0])
        x2, y2 = x1 + w, y1 + h
        if len(contours) > 1:
            for b in contours:
                x_t, y_t, w_t, h_t = cv2.boundingRect(b)
                # Merge multiple bboxes into one that covers all contours
                x1 = min(x1, x_t)
                y1 = min(y1, y_t)
                x2 = max(x2, x_t + w_t)
                y2 = max(y2, y_t + h_t)
        return [x1, y1, x2, y2]

    def plot(self,
             annotations,
             output,
             bbox=None,
             points=None,
             point_label=None,
             mask_random_color=True,
             better_quality=True,
             retina=False,
             with_contours=True):
        """Plot masks (with optional box/point prompts and contours) over the original image and save the result."""
        if isinstance(annotations[0], dict):
            annotations = [annotation['segmentation'] for annotation in annotations]
        if isinstance(annotations, torch.Tensor):
            annotations = annotations.cpu().numpy()
        result_name = os.path.basename(self.img_path)
        image = self.ori_img
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_h = image.shape[0]
        original_w = image.shape[1]
        # For macOS only
        # plt.switch_backend('TkAgg')
        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
        # Add subplot with no margin
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(image)
        if better_quality:
            for i, mask in enumerate(annotations):
                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
        self.fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            pointlabel=point_label,
            retinamask=retina,
            target_height=original_h,
            target_width=original_w,
        )
        if with_contours:
            contour_all = []
            temp = np.zeros((original_h, original_w, 1))
            for i, mask in enumerate(annotations):
                if isinstance(mask, dict):
                    mask = mask['segmentation']
                annotation = mask.astype(np.uint8)
                if not retina:
                    annotation = cv2.resize(
                        annotation,
                        (original_w, original_h),
                        interpolation=cv2.INTER_NEAREST,
                    )
                contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                contour_all.extend(iter(contours))
            cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
            color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
            contour_mask = temp / 255 * color.reshape(1, 1, -1)
            plt.imshow(contour_mask)

        save_path = Path(output) / result_name
        save_path.parent.mkdir(exist_ok=True, parents=True)
        plt.axis('off')
        fig.savefig(save_path)
        LOGGER.info(f'Saved to {save_path.absolute()}')

    @staticmethod
    def fast_show_mask(
            annotation,
            ax,
            random_color=False,
            bbox=None,
            points=None,
            pointlabel=None,
            retinamask=True,
            target_height=960,
            target_width=960,
    ):
        """Overlay masks (and optional bbox/points) on the given matplotlib axis. CPU post-processing only."""
        n, h, w = annotation.shape  # batch, height, width
        areas = np.sum(annotation, axis=(1, 2))
        annotation = annotation[np.argsort(areas)]

        index = (annotation != 0).argmax(axis=0)
        if random_color:
            color = np.random.random((n, 1, 1, 3))
        else:
            color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
        transparency = np.ones((n, 1, 1, 1)) * 0.6
        visual = np.concatenate([color, transparency], axis=-1)
        mask_image = np.expand_dims(annotation, -1) * visual

        show = np.zeros((h, w, 4))
        h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
        # Use vectorized indexing to update the values of 'show'
        show[h_indices, w_indices, :] = mask_image[indices]
        if bbox is not None:
            x1, y1, x2, y2 = bbox
            ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
        # Draw points: yellow for foreground (label 1), magenta for background (label 0)
        if points is not None:
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                s=20,
                c='y',
            )
            plt.scatter(
                [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                s=20,
                c='m',
            )
        if not retinamask:
            show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
        ax.imshow(show)

    @torch.no_grad()
    def retrieve(self, model, preprocess, elements, search_text: str, device) -> torch.Tensor:
        """Score each cropped image element against the search text with CLIP and return softmax probabilities."""
        preprocessed_images = [preprocess(image).to(device) for image in elements]
        tokenized_text = self.clip.tokenize([search_text]).to(device)
        stacked_images = torch.stack(preprocessed_images)
        image_features = model.encode_image(stacked_images)
        text_features = model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        probs = 100.0 * image_features @ text_features.T
        return probs[:, 0].softmax(dim=0)

    def _crop_image(self, format_results):
        """Crop the image to the bbox of each mask annotation and return the crops plus bookkeeping lists."""
        image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
        ori_w, ori_h = image.size
        annotations = format_results
        mask_h, mask_w = annotations[0]['segmentation'].shape
        if ori_w != mask_w or ori_h != mask_h:
            image = image.resize((mask_w, mask_h))
        cropped_boxes = []
        cropped_images = []
        not_crop = []
        filter_id = []
        for idx, mask in enumerate(annotations):
            if np.sum(mask['segmentation']) <= 100:  # skip tiny masks
                filter_id.append(idx)
                continue
            bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
            cropped_boxes.append(self._segment_image(image, bbox))  # save the cropped image
            cropped_images.append(bbox)  # save the bbox of the cropped image
        return cropped_boxes, cropped_images, not_crop, filter_id, annotations

    def box_prompt(self, bbox):
        """Return the mask with the highest IoU against the given [x1, y1, x2, y2] box in original-image coordinates."""
        assert bbox[2] != 0 and bbox[3] != 0
        masks = self.results[0].masks.data
        target_height = self.ori_img.shape[0]
        target_width = self.ori_img.shape[1]
        h = masks.shape[1]
        w = masks.shape[2]
        if h != target_height or w != target_width:
            # Rescale the box from original-image coordinates to mask coordinates
            bbox = [
                int(bbox[0] * w / target_width),
                int(bbox[1] * h / target_height),
                int(bbox[2] * w / target_width),
                int(bbox[3] * h / target_height), ]
        bbox[0] = max(round(bbox[0]), 0)
        bbox[1] = max(round(bbox[1]), 0)
        bbox[2] = min(round(bbox[2]), w)
        bbox[3] = min(round(bbox[3]), h)

        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
        masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
        orig_masks_area = torch.sum(masks, dim=(1, 2))

        union = bbox_area + orig_masks_area - masks_area
        IoUs = masks_area / union
        max_iou_index = torch.argmax(IoUs)
        return np.array([masks[max_iou_index].cpu().numpy()])

    def point_prompt(self, points, pointlabel):  # numpy post-processing
        """Add masks hit by foreground points (label 1) and subtract masks hit by background points (label 0)."""
        masks = self._format_results(self.results[0], 0)
        target_height = self.ori_img.shape[0]
        target_width = self.ori_img.shape[1]
        h = masks[0]['segmentation'].shape[0]
        w = masks[0]['segmentation'].shape[1]
        if h != target_height or w != target_width:
            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
        onemask = np.zeros((h, w))
        for annotation in masks:
            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
            for i, point in enumerate(points):
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
                    onemask += mask
                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                    onemask -= mask
        onemask = onemask >= 1
        return np.array([onemask])

    def text_prompt(self, text):
        """Return the mask whose cropped region best matches the text prompt according to CLIP."""
        format_results = self._format_results(self.results[0], 0)
        cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
        clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
        scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
        max_idx = scores.argsort()[-1]
        max_idx += sum(np.array(filter_id) <= int(max_idx))  # map back to the index before tiny masks were filtered
        return np.array([annotations[max_idx]['segmentation']])

    def everything_prompt(self):
        return self.results[0].masks.data
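

if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of this module's API). It assumes a
    # FastSAM checkpoint 'FastSAM-s.pt' and an image at 'path/to/image.jpg' exist locally,
    # and that the FastSAM model class is importable from the ultralytics package.
    from ultralytics import FastSAM

    model = FastSAM('FastSAM-s.pt')
    everything_results = model('path/to/image.jpg', device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    prompt_process = FastSAMPrompt('path/to/image.jpg', everything_results, device='cpu')

    # Everything prompt: all masks from the segmentation result
    ann = prompt_process.everything_prompt()

    # Alternative prompts (pick one):
    # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])        # mask with highest IoU against the box
    # ann = prompt_process.text_prompt(text='a photo of a dog')         # mask best matching the text via CLIP
    # ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])  # masks hit by foreground points

    prompt_process.plot(annotations=ann, output='./output')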