conv.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Convolution modules."""

import math

import numpy as np
import torch
import torch.nn as nn

__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
           'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
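
# Quick check (added sketch, not part of the upstream file): autopad returns the
# padding that keeps the 'same' output shape at stride 1, accounting for dilation.
#   >>> autopad(3), autopad(3, d=2), autopad((3, 5))
#   (1, 2, [1, 2])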


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Apply convolution and activation without batch normalization (used after conv/BN fusion)."""
        return self.act(self.conv(x))
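
# Usage sketch (added for illustration): a stride-2 3x3 Conv halves the spatial size.
#   >>> m = Conv(3, 16, k=3, s=2)
#   >>> m(torch.randn(1, 3, 64, 64)).shape
#   torch.Size([1, 16, 32, 32])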


class Conv2(Conv):
    """Simplified RepConv module with Conv fusing."""

    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
        self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False)  # add 1x1 conv

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x) + self.cv2(x)))

    def forward_fuse(self, x):
        """Apply fused convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def fuse_convs(self):
        """Fuse parallel convolutions."""
        w = torch.zeros_like(self.conv.weight.data)
        i = [x // 2 for x in w.shape[2:]]
        w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
        self.conv.weight.data += w
        self.__delattr__('cv2')
        self.forward = self.forward_fuse
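
# Fusion sketch (added; not in the upstream file): after fuse_convs() the single kxk
# branch should reproduce the two-branch output in eval mode, since the 1x1 kernel is
# simply added into the centre of the kxk kernel before the shared BatchNorm.
#   >>> m = Conv2(8, 8, k=3).eval()
#   >>> x = torch.randn(1, 8, 32, 32)
#   >>> y = m(x)
#   >>> m.fuse_convs()
#   >>> torch.allclose(y, m(x), atol=1e-6)
#   True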


class LightConv(nn.Module):
    """
    Light convolution with args(ch_in, ch_out, kernel).

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, c2, k=1, act=nn.ReLU()):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv1 = Conv(c1, c2, 1, act=False)
        self.conv2 = DWConv(c2, c2, k, act=act)

    def forward(self, x):
        """Apply 2 convolutions to input tensor."""
        return self.conv2(self.conv1(x))


class DWConv(Conv):
    """Depth-wise convolution."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        """Initialize depth-wise convolution with the given parameters."""
        super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)


class DWConvTranspose2d(nn.ConvTranspose2d):
    """Depth-wise transpose convolution."""

    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):  # ch_in, ch_out, kernel, stride, padding, padding_out
        """Initialize depth-wise transpose convolution with the given parameters."""
        super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))


class ConvTranspose(nn.Module):
    """Convolution transpose 2d layer."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
        """Initialize ConvTranspose2d layer with batch normalization and activation function."""
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
        self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Applies transposed convolution, batch normalization and activation to input."""
        return self.act(self.bn(self.conv_transpose(x)))

    def forward_fuse(self, x):
        """Applies activation and convolution transpose operation to input."""
        return self.act(self.conv_transpose(x))
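
# Shape sketch (added): with the defaults k=2, s=2, p=0 the layer doubles spatial size.
#   >>> up = ConvTranspose(16, 8)
#   >>> up(torch.randn(1, 16, 20, 20)).shape
#   torch.Size([1, 8, 40, 40])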


class Focus(nn.Module):
    """Focus wh information into c-space."""

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        """Initialize Focus module: the convolution sees 4x the input channels."""
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
        # self.contract = Contract(gain=2)

    def forward(self, x):  # x(b,c,w,h) -> y(b,c2,w/2,h/2)
        """Concatenate the four pixel-decimated slices, then apply convolution."""
        return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
        # return self.conv(self.contract(x))
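
# Shape sketch (added): Focus trades spatial resolution for channels before the conv,
# so (b, c, h, w) becomes (b, 4c, h/2, w/2) internally and c2 channels after self.conv.
#   >>> f = Focus(3, 32, k=3)
#   >>> f(torch.randn(1, 3, 64, 64)).shape
#   torch.Size([1, 32, 32, 32])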


class GhostConv(nn.Module):
    """Ghost Convolution https://github.com/huawei-noah/ghostnet."""

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):  # ch_in, ch_out, kernel, stride, groups
        """Initialize Ghost Convolution with primary and cheap (depth-wise) operations."""
        super().__init__()
        c_ = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)

    def forward(self, x):
        """Forward propagation through the Ghost Convolution layer."""
        y = self.cv1(x)
        return torch.cat((y, self.cv2(y)), 1)
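
# Sketch (added): half of the output channels come from a cheap 5x5 depth-wise conv over
# the primary features, so the layer is cheaper than a full Conv of the same output width.
#   >>> g = GhostConv(16, 32)
#   >>> g(torch.randn(1, 16, 40, 40)).shape
#   torch.Size([1, 32, 40, 40])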


class RepConv(nn.Module):
    """
    RepConv is a basic rep-style block, including training and deploy status. This module is used in RT-DETR.

    Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    """

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
        """Initialize RepConv with 3x3 and 1x1 branches plus an optional identity BatchNorm branch."""
        super().__init__()
        assert k == 3 and p == 1
        self.g = g
        self.c1 = c1
        self.c2 = c2
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

        self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
        self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
        self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)

    def forward_fuse(self, x):
        """Forward process of the fused (deploy) convolution."""
        return self.act(self.conv(x))

    def forward(self, x):
        """Forward process through the parallel branches."""
        id_out = 0 if self.bn is None else self.bn(x)
        return self.act(self.conv1(x) + self.conv2(x) + id_out)

    def get_equivalent_kernel_bias(self):
        """Return the single 3x3 kernel and bias equivalent to the fused branches."""
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
        kernelid, biasid = self._fuse_bn_tensor(self.bn)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        """Pad a 1x1 kernel to 3x3 by placing it at the centre."""
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        """Fold BatchNorm statistics into the kernel and bias of a branch."""
        if branch is None:
            return 0, 0
        if isinstance(branch, Conv):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        elif isinstance(branch, nn.BatchNorm2d):
            if not hasattr(self, 'id_tensor'):
                input_dim = self.c1 // self.g
                kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.c1):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def fuse_convs(self):
        """Replace the parallel branches with a single equivalent nn.Conv2d for deployment."""
        if hasattr(self, 'conv'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
                              out_channels=self.conv1.conv.out_channels,
                              kernel_size=self.conv1.conv.kernel_size,
                              stride=self.conv1.conv.stride,
                              padding=self.conv1.conv.padding,
                              dilation=self.conv1.conv.dilation,
                              groups=self.conv1.conv.groups,
                              bias=True).requires_grad_(False)
        self.conv.weight.data = kernel
        self.conv.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('conv1')
        self.__delattr__('conv2')
        if hasattr(self, 'nm'):
            self.__delattr__('nm')
        if hasattr(self, 'bn'):
            self.__delattr__('bn')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
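
# Re-parameterization sketch (added): in eval mode the fused single conv should match
# the multi-branch forward up to floating-point tolerance.
#   >>> m = RepConv(16, 16, bn=True).eval()
#   >>> x = torch.randn(1, 16, 32, 32)
#   >>> y = m(x)
#   >>> m.fuse_convs()
#   >>> torch.allclose(y, m.forward_fuse(x), atol=1e-5)
#   True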


class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        """Initialize channel attention with a global average pool followed by a 1x1 conv."""
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the input by per-channel attention weights."""
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply spatial attention to the input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        """Initialize CBAM with channel and spatial attention sub-modules."""
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply channel attention followed by spatial attention to the input tensor."""
        return self.spatial_attention(self.channel_attention(x))
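
# Sketch (added): CBAM only re-weights features, so the output shape equals the input shape.
#   >>> cbam = CBAM(64)
#   >>> cbam(torch.randn(2, 64, 20, 20)).shape
#   torch.Size([2, 64, 20, 20])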


class Concat(nn.Module):
    """Concatenate a list of tensors along dimension."""

    def __init__(self, dimension=1):
        """Concatenates a list of tensors along a specified dimension."""
        super().__init__()
        self.d = dimension

    def forward(self, x):
        """Concatenate the list of input tensors along the stored dimension."""
        return torch.cat(x, self.d)
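

if __name__ == '__main__':
    # Minimal smoke test (added sketch, not part of the upstream module): build a few of
    # the blocks above and confirm their output shapes on a random input.
    x = torch.randn(1, 3, 64, 64)
    for m, expected in (
            (Conv(3, 16, 3, 2), (1, 16, 32, 32)),
            (Focus(3, 16), (1, 16, 32, 32)),
            (GhostConv(3, 16), (1, 16, 64, 64)),
    ):
        y = m(x)
        assert y.shape == torch.Size(expected), f'{type(m).__name__}: unexpected shape {y.shape}'
    print('conv.py smoke test passed')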