block.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. Block modules
  4. """
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
  9. from .transformer import TransformerBlock
  10. __all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
  11. 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')
  12. class DFL(nn.Module):
  13. """
  14. Integral module of Distribution Focal Loss (DFL).
  15. Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
  16. """
  17. def __init__(self, c1=16):
  18. """Initialize a convolutional layer with a given number of input channels."""
  19. super().__init__()
  20. self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
  21. x = torch.arange(c1, dtype=torch.float)
  22. self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
  23. self.c1 = c1
  24. def forward(self, x):
  25. """Applies a transformer layer on input tensor 'x' and returns a tensor."""
  26. b, c, a = x.shape # batch, channels, anchors
  27. return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
  28. # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
  29. class Proto(nn.Module):
  30. """YOLOv8 mask Proto module for segmentation models."""
  31. def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
  32. super().__init__()
  33. self.cv1 = Conv(c1, c_, k=3)
  34. self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
  35. self.cv2 = Conv(c_, c_, k=3)
  36. self.cv3 = Conv(c_, c2)
  37. def forward(self, x):
  38. """Performs a forward pass through layers using an upsampled input image."""
  39. return self.cv3(self.cv2(self.upsample(self.cv1(x))))
  40. class HGStem(nn.Module):
  41. """StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
  42. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  43. """
  44. def __init__(self, c1, cm, c2):
  45. super().__init__()
  46. self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
  47. self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
  48. self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
  49. self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
  50. self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
  51. self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
  52. def forward(self, x):
  53. """Forward pass of a PPHGNetV2 backbone layer."""
  54. x = self.stem1(x)
  55. x = F.pad(x, [0, 1, 0, 1])
  56. x2 = self.stem2a(x)
  57. x2 = F.pad(x2, [0, 1, 0, 1])
  58. x2 = self.stem2b(x2)
  59. x1 = self.pool(x)
  60. x = torch.cat([x1, x2], dim=1)
  61. x = self.stem3(x)
  62. x = self.stem4(x)
  63. return x
  64. class HGBlock(nn.Module):
  65. """HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
  66. https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
  67. """
  68. def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
  69. super().__init__()
  70. block = LightConv if lightconv else Conv
  71. self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
  72. self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
  73. self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
  74. self.add = shortcut and c1 == c2
  75. def forward(self, x):
  76. """Forward pass of a PPHGNetV2 backbone layer."""
  77. y = [x]
  78. y.extend(m(y[-1]) for m in self.m)
  79. y = self.ec(self.sc(torch.cat(y, 1)))
  80. return y + x if self.add else y
  81. class SPP(nn.Module):
  82. """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
  83. def __init__(self, c1, c2, k=(5, 9, 13)):
  84. """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
  85. super().__init__()
  86. c_ = c1 // 2 # hidden channels
  87. self.cv1 = Conv(c1, c_, 1, 1)
  88. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  89. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  90. def forward(self, x):
  91. """Forward pass of the SPP layer, performing spatial pyramid pooling."""
  92. x = self.cv1(x)
  93. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  94. class SPPF(nn.Module):
  95. """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
  96. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  97. super().__init__()
  98. c_ = c1 // 2 # hidden channels
  99. self.cv1 = Conv(c1, c_, 1, 1)
  100. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  101. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  102. def forward(self, x):
  103. """Forward pass through Ghost Convolution block."""
  104. x = self.cv1(x)
  105. y1 = self.m(x)
  106. y2 = self.m(y1)
  107. return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
  108. class C1(nn.Module):
  109. """CSP Bottleneck with 1 convolution."""
  110. def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
  111. super().__init__()
  112. self.cv1 = Conv(c1, c2, 1, 1)
  113. self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
  114. def forward(self, x):
  115. """Applies cross-convolutions to input in the C3 module."""
  116. y = self.cv1(x)
  117. return self.m(y) + y
  118. class C2(nn.Module):
  119. """CSP Bottleneck with 2 convolutions."""
  120. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  121. super().__init__()
  122. self.c = int(c2 * e) # hidden channels
  123. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  124. self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
  125. # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
  126. self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
  127. def forward(self, x):
  128. """Forward pass through the CSP bottleneck with 2 convolutions."""
  129. a, b = self.cv1(x).chunk(2, 1)
  130. return self.cv2(torch.cat((self.m(a), b), 1))
  131. class C2f(nn.Module):
  132. """Faster Implementation of CSP Bottleneck with 2 convolutions."""
  133. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  134. super().__init__()
  135. self.c = int(c2 * e) # hidden channels
  136. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  137. self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  138. self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
  139. def forward(self, x):
  140. """Forward pass through C2f layer."""
  141. y = list(self.cv1(x).chunk(2, 1))
  142. y.extend(m(y[-1]) for m in self.m)
  143. return self.cv2(torch.cat(y, 1))
  144. def forward_split(self, x):
  145. """Forward pass using split() instead of chunk()."""
  146. y = list(self.cv1(x).split((self.c, self.c), 1))
  147. y.extend(m(y[-1]) for m in self.m)
  148. return self.cv2(torch.cat(y, 1))
  149. class C3(nn.Module):
  150. """CSP Bottleneck with 3 convolutions."""
  151. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  152. super().__init__()
  153. c_ = int(c2 * e) # hidden channels
  154. self.cv1 = Conv(c1, c_, 1, 1)
  155. self.cv2 = Conv(c1, c_, 1, 1)
  156. self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
  157. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
  158. def forward(self, x):
  159. """Forward pass through the CSP bottleneck with 2 convolutions."""
  160. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  161. class C3x(C3):
  162. """C3 module with cross-convolutions."""
  163. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  164. """Initialize C3TR instance and set default parameters."""
  165. super().__init__(c1, c2, n, shortcut, g, e)
  166. self.c_ = int(c2 * e)
  167. self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
  168. class RepC3(nn.Module):
  169. """Rep C3."""
  170. def __init__(self, c1, c2, n=3, e=1.0):
  171. super().__init__()
  172. c_ = int(c2 * e) # hidden channels
  173. self.cv1 = Conv(c1, c2, 1, 1)
  174. self.cv2 = Conv(c1, c2, 1, 1)
  175. self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
  176. self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
  177. def forward(self, x):
  178. """Forward pass of RT-DETR neck layer."""
  179. return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
  180. class C3TR(C3):
  181. """C3 module with TransformerBlock()."""
  182. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  183. """Initialize C3Ghost module with GhostBottleneck()."""
  184. super().__init__(c1, c2, n, shortcut, g, e)
  185. c_ = int(c2 * e)
  186. self.m = TransformerBlock(c_, c_, 4, n)
  187. class C3Ghost(C3):
  188. """C3 module with GhostBottleneck()."""
  189. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  190. """Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
  191. super().__init__(c1, c2, n, shortcut, g, e)
  192. c_ = int(c2 * e) # hidden channels
  193. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  194. class GhostBottleneck(nn.Module):
  195. """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
  196. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  197. super().__init__()
  198. c_ = c2 // 2
  199. self.conv = nn.Sequential(
  200. GhostConv(c1, c_, 1, 1), # pw
  201. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  202. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  203. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
  204. act=False)) if s == 2 else nn.Identity()
  205. def forward(self, x):
  206. """Applies skip connection and concatenation to input tensor."""
  207. return self.conv(x) + self.shortcut(x)
  208. class Bottleneck(nn.Module):
  209. """Standard bottleneck."""
  210. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
  211. super().__init__()
  212. c_ = int(c2 * e) # hidden channels
  213. self.cv1 = Conv(c1, c_, k[0], 1)
  214. self.cv2 = Conv(c_, c2, k[1], 1, g=g)
  215. self.add = shortcut and c1 == c2
  216. def forward(self, x):
  217. """'forward()' applies the YOLOv5 FPN to input data."""
  218. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  219. class BottleneckCSP(nn.Module):
  220. """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
  221. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  222. super().__init__()
  223. c_ = int(c2 * e) # hidden channels
  224. self.cv1 = Conv(c1, c_, 1, 1)
  225. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  226. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  227. self.cv4 = Conv(2 * c_, c2, 1, 1)
  228. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  229. self.act = nn.SiLU()
  230. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  231. def forward(self, x):
  232. """Applies a CSP bottleneck with 3 convolutions."""
  233. y1 = self.cv3(self.m(self.cv1(x)))
  234. y2 = self.cv2(x)
  235. return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))