# head.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Model head modules
"""

import math

import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_

from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors

from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_

__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'


class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""
    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):  # avoid TF FlexSplitV ops
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides

        if self.export and self.format in ('tflite', 'edgetpu'):
            # Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
            # https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
            # See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
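
# Decode-path note (added comment, not in the original file): in eval/export mode the box branch
# emits 4 * reg_max = 64 logits per anchor; DFL takes a softmax over each group of reg_max = 16
# bins and returns its expected value, giving 4 distances (left, top, right, bottom) in grid
# units. dist2bbox(..., xywh=True) then turns those distances around each anchor point into
# (cx, cy, w, h), which is finally multiplied by the per-anchor stride to land in image pixels.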


class Segment(Detect):
    """YOLOv8 Segment head for segmentation models."""

    def __init__(self, nc=80, nm=32, npr=256, ch=()):
        """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
        super().__init__(nc, ch)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

        c4 = max(ch[0] // 4, self.nm)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)

    def forward(self, x):
        """Return detections, mask coefficients and prototypes during training, otherwise return processed predictions."""
        p = self.proto(x[0])  # mask protos
        bs = p.shape[0]  # batch size

        mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
        x = self.detect(self, x)
        if self.training:
            return x, mc, p
        return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
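
# Illustrative output-shape sketch for Segment (added note, not part of the original file). The
# channel counts assume the defaults nm=32, npr=256 and a hypothetical (64, 128, 256) neck with
# a 640x640 input; adjust to your model.
#
#   seg = Segment(nc=80, nm=32, npr=256, ch=(64, 128, 256))
#   seg.stride = torch.tensor([8., 16., 32.])  # normally set during model build
#   seg.eval()
#   feats = [torch.randn(1, 64, 80, 80), torch.randn(1, 128, 40, 40), torch.randn(1, 256, 20, 20)]
#   preds, (raw, mc, p) = seg(feats)
#   # preds: (1, 84 + 32, 8400) -> boxes + class scores + 32 mask coefficients per anchor
#   # p:     (1, 32, 160, 160)  -> mask prototypes built from the highest-resolution feature map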


class Pose(Detect):
    """YOLOv8 Pose head for keypoints models."""

    def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
        """Initialize YOLO network with default parameters and convolution layers."""
        super().__init__(nc, ch)
        self.kpt_shape = kpt_shape  # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
        self.nk = kpt_shape[0] * kpt_shape[1]  # number of keypoints total
        self.detect = Detect.forward

        c4 = max(ch[0] // 4, self.nk)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)

    def forward(self, x):
        """Perform forward pass through YOLO model and return predictions."""
        bs = x[0].shape[0]  # batch size
        kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
        x = self.detect(self, x)
        if self.training:
            return x, kpt
        pred_kpt = self.kpts_decode(bs, kpt)
        return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

    def kpts_decode(self, bs, kpts):
        """Decodes keypoints."""
        ndim = self.kpt_shape[1]
        if self.export:  # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
            y = kpts.view(bs, *self.kpt_shape, -1)
            a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
            if ndim == 3:
                a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
            return a.view(bs, self.nk, -1)
        else:
            y = kpts.clone()
            if ndim == 3:
                y[:, 2::3].sigmoid_()  # inplace sigmoid
            y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
            y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
            return y
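
# Keypoint decoding sketch (added for clarity; shapes are assumptions based on the default
# kpt_shape=(17, 3) and the same hypothetical (64, 128, 256) neck as above). Each anchor predicts
# 17 keypoints as (x, y, visibility); kpts_decode() maps the raw x, y offsets onto the anchor grid
# and scales them by the stride, while the visibility channel only gets a sigmoid.
#
#   pose = Pose(nc=1, kpt_shape=(17, 3), ch=(64, 128, 256))
#   pose.stride = torch.tensor([8., 16., 32.])
#   pose.eval()
#   feats = [torch.randn(1, 64, 80, 80), torch.randn(1, 128, 40, 40), torch.randn(1, 256, 20, 20)]
#   preds, (raw, kpt) = pose(feats)
#   # preds: (1, 4 + 1 + 17 * 3, 8400) -> box (4) + class score (nc=1) + decoded keypoints (51) per anchor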


class Classify(nn.Module):
    """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        c_ = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, c_, k, s, p, g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,c_,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(c_, c2)  # to x(b,c2)

    def forward(self, x):
        """Performs a forward pass of the YOLO model on input image data."""
        if isinstance(x, list):
            x = torch.cat(x, 1)
        x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
        return x if self.training else x.softmax(1)
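
# Classify usage sketch (illustrative, not from the original source); c1=256, c2=1000 and the
# 20x20 input resolution are assumptions.
#
#   clf = Classify(c1=256, c2=1000)
#   clf.eval()
#   probs = clf(torch.randn(1, 256, 20, 20))  # (1, 1000) softmax probabilities; raw logits in training mode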


class RTDETRDecoder(nn.Module):
    export = False  # export mode

    def __init__(
            self,
            nc=80,
            ch=(512, 1024, 2048),
            hd=256,  # hidden dim
            nq=300,  # num queries
            ndp=4,  # num decoder points
            nh=8,  # num head
            ndl=6,  # num decoder layers
            d_ffn=1024,  # dim of feedforward
            dropout=0.,
            act=nn.ReLU(),
            eval_idx=-1,
            # training args
            nd=100,  # num denoising
            label_noise_ratio=0.5,
            box_noise_scale=1.0,
            learnt_init_query=False):
        super().__init__()
        self.hidden_dim = hd
        self.nhead = nh
        self.nl = len(ch)  # num level
        self.nc = nc
        self.num_queries = nq
        self.num_decoder_layers = ndl

        # backbone feature projection
        self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
        # NOTE: simplified version but it's not consistent with .pt weights.
        # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)

        # Transformer module
        decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
        self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)

        # denoising part
        self.denoising_class_embed = nn.Embedding(nc, hd)
        self.num_denoising = nd
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale

        # decoder embedding
        self.learnt_init_query = learnt_init_query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(nq, hd)
        self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)

        # encoder head
        self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
        self.enc_score_head = nn.Linear(hd, nc)
        self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)

        # decoder head
        self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
        self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])

        self._reset_parameters()

    def forward(self, x, batch=None):
        from ultralytics.models.utils.ops import get_cdn_group

        # input projection and embedding
        feats, shapes = self._get_encoder_input(x)

        # prepare denoising training
        dn_embed, dn_bbox, attn_mask, dn_meta = \
            get_cdn_group(batch,
                          self.nc,
                          self.num_queries,
                          self.denoising_class_embed.weight,
                          self.num_denoising,
                          self.label_noise_ratio,
                          self.box_noise_scale,
                          self.training)

        embed, refer_bbox, enc_bboxes, enc_scores = \
            self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)

        # decoder
        dec_bboxes, dec_scores = self.decoder(embed,
                                              refer_bbox,
                                              feats,
                                              shapes,
                                              self.dec_bbox_head,
                                              self.dec_score_head,
                                              self.query_pos_head,
                                              attn_mask=attn_mask)
        x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
        if self.training:
            return x
        # (bs, 300, 4+nc)
        y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
        return y if self.export else (y, x)

    def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
        anchors = []
        for i, (h, w) in enumerate(shapes):
            sy = torch.arange(end=h, dtype=dtype, device=device)
            sx = torch.arange(end=w, dtype=dtype, device=device)
            grid_y, grid_x = torch.meshgrid(sy, sx, indexing='ij') if TORCH_1_10 else torch.meshgrid(sy, sx)
            grid_xy = torch.stack([grid_x, grid_y], -1)  # (h, w, 2)

            valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH  # (1, h, w, 2)
            wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
            anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4))  # (1, h*w, 4)

        anchors = torch.cat(anchors, 1)  # (1, h*w*nl, 4)
        valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)  # 1, h*w*nl, 1
        anchors = torch.log(anchors / (1 - anchors))
        anchors = anchors.masked_fill(~valid_mask, float('inf'))
        return anchors, valid_mask
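
    # Anchor note (added comment): each level contributes a uniform grid of normalized (cx, cy)
    # centres with a fixed prior size grid_size * 2**i, and the set is mapped through the inverse
    # sigmoid logit(p) = log(p / (1 - p)) so that refer_bbox = enc_bbox_head(...) + anchors can be
    # squashed back into [0, 1] with a plain sigmoid. For example, a centre at 0.5 maps to logit
    # 0.0, while anchors outside (eps, 1 - eps) are masked to inf and treated as invalid.
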
    def _get_encoder_input(self, x):
        # get projection features
        x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
        # get encoder inputs
        feats = []
        shapes = []
        for feat in x:
            h, w = feat.shape[2:]
            # [b, c, h, w] -> [b, h*w, c]
            feats.append(feat.flatten(2).permute(0, 2, 1))
            # [nl, 2]
            shapes.append([h, w])

        # [b, h*w, c]
        feats = torch.cat(feats, 1)
        return feats, shapes

    def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
        bs = len(feats)
        # prepare input for decoder
        anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
        features = self.enc_output(valid_mask * feats)  # bs, h*w, 256

        enc_outputs_scores = self.enc_score_head(features)  # (bs, h*w, nc)
        # query selection
        # (bs, num_queries)
        topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
        # (bs, num_queries)
        batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)

        # (bs, num_queries, 256)
        top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
        # (bs, num_queries, 4)
        top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)

        # dynamic anchors + static content
        refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors

        enc_bboxes = refer_bbox.sigmoid()
        if dn_bbox is not None:
            refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
        enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)

        embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
        if self.training:
            refer_bbox = refer_bbox.detach()
            if not self.learnt_init_query:
                embeddings = embeddings.detach()
        if dn_embed is not None:
            embeddings = torch.cat([dn_embed, embeddings], 1)

        return embeddings, refer_bbox, enc_bboxes, enc_scores

    # TODO
    def _reset_parameters(self):
        # class and bbox head init
        bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
        # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
        # linear_init_(self.enc_score_head)
        constant_(self.enc_score_head.bias, bias_cls)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
            # linear_init_(cls_)
            constant_(cls_.bias, bias_cls)
            constant_(reg_.layers[-1].weight, 0.)
            constant_(reg_.layers[-1].bias, 0.)

        linear_init_(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)
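

if __name__ == '__main__':
    # Minimal smoke test added for illustration only; it is not part of the original module.
    # The (64, 128, 256) channels and 640x640-equivalent feature sizes are assumptions matching
    # a small YOLOv8-style neck; strides are normally assigned during model build.
    detect = Detect(nc=80, ch=(64, 128, 256))
    detect.stride = torch.tensor([8., 16., 32.])
    detect.eval()
    feats = [torch.randn(2, 64, 80, 80), torch.randn(2, 128, 40, 40), torch.randn(2, 256, 20, 20)]
    y, raw = detect(feats)  # note: forward() writes concatenated outputs back into `feats`
    print(y.shape)  # expected: torch.Size([2, 84, 8400]) -> 4 box coords + 80 class scores per anchor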