# tiny_encoder.py
# Ultralytics YOLO 🚀, AGPL-3.0 license
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
#   LeViT: (https://github.com/facebookresearch/levit)
#   Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------

import itertools
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from ultralytics.utils.instance import to_2tuple


class Conv2d_BN(torch.nn.Sequential):
    """A 2d convolution followed by BatchNorm, fusable into a single Conv2d for inference."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
        bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)

    @torch.no_grad()
    def fuse(self):
        """Fold the BatchNorm statistics into the convolution weights and return a plain Conv2d."""
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups,
                            w.size(0),
                            w.shape[2:],
                            stride=self.c.stride,
                            padding=self.c.padding,
                            dilation=self.c.dilation,
                            groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
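

# NOTE: The following is a minimal usage sketch, not part of the model. It assumes a `Conv2d_BN`
# whose BatchNorm statistics have been populated; `fuse()` then returns a single Conv2d that can
# replace the conv + BN pair for slightly faster inference.
# conv_bn = Conv2d_BN(16, 32, ks=3, stride=1, pad=1).eval()
# fused = conv_bn.fuse()
# x = torch.randn(1, 16, 64, 64)
# assert torch.allclose(conv_bn(x), fused(x), atol=1e-5)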


# NOTE: This module and the timm package are needed only for training.
# from ultralytics.utils.checks import check_requirements
# check_requirements('timm')
# from timm.models.layers import DropPath as TimmDropPath
# from timm.models.layers import trunc_normal_
# class DropPath(TimmDropPath):
#
#     def __init__(self, drop_prob=None):
#         super().__init__(drop_prob=drop_prob)
#         self.drop_prob = drop_prob
#
#     def __repr__(self):
#         msg = super().__repr__()
#         msg += f'(drop_prob={self.drop_prob})'
#         return msg


class PatchEmbed(nn.Module):
    """Embed an image into patches with two strided convolutions, reducing resolution by 4x."""

    def __init__(self, in_chans, embed_dim, resolution, activation):
        super().__init__()
        img_size: Tuple[int, int] = to_2tuple(resolution)
        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        n = embed_dim
        self.seq = nn.Sequential(
            Conv2d_BN(in_chans, n // 2, 3, 2, 1),
            activation(),
            Conv2d_BN(n // 2, n, 3, 2, 1),
        )

    def forward(self, x):
        return self.seq(x)
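

# NOTE: Shape sketch (assumed values, for illustration only): with a 1024x1024 RGB input and
# embed_dim=64, the two stride-2 convolutions in `PatchEmbed` produce a (B, 64, 256, 256)
# feature map, i.e. the spatial resolution is reduced by a factor of 4.
# embed = PatchEmbed(in_chans=3, embed_dim=64, resolution=1024, activation=nn.GELU)
# out = embed(torch.randn(1, 3, 1024, 1024))  # -> torch.Size([1, 64, 256, 256])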


class MBConv(nn.Module):
    """Mobile inverted bottleneck (MBConv) block with a residual connection."""

    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
        super().__init__()
        self.in_chans = in_chans
        self.hidden_chans = int(in_chans * expand_ratio)
        self.out_chans = out_chans

        self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
        self.act1 = activation()

        self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
        self.act2 = activation()

        self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
        self.act3 = activation()

        # NOTE: `DropPath` is needed only for training.
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        x = self.drop_path(x)
        x += shortcut
        return self.act3(x)
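

# NOTE: Shape sketch (assumed values): `MBConv` is an inverted residual block, so the residual
# add requires in_chans == out_chans and input and output keep the same shape; with
# expand_ratio=4.0 the hidden width is 4x the input channels.
# mb = MBConv(in_chans=64, out_chans=64, expand_ratio=4.0, activation=nn.GELU, drop_path=0.0)
# out = mb(torch.randn(1, 64, 56, 56))  # -> torch.Size([1, 64, 56, 56])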


class PatchMerging(nn.Module):
    """Merge neighboring patches with pointwise and depthwise convolutions, optionally halving resolution."""

    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        stride_c = 1 if out_dim in [320, 448, 576] else 2
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            H, W = self.input_resolution
            B = len(x)
            # (B, C, H, W)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        return x.flatten(2).transpose(1, 2)
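

# NOTE: A minimal sketch (assumed sizes) of how `PatchMerging` is driven with token sequences.
# It accepts either (B, C, H, W) maps or (B, H*W, C) sequences; with stride 2 the 3136 tokens of
# a 56x56 grid are merged into the 784 tokens of a 28x28 grid.
# merge = PatchMerging(input_resolution=(56, 56), dim=64, out_dim=128, activation=nn.GELU)
# tokens = torch.randn(1, 56 * 56, 64)
# out = merge(tokens)  # -> torch.Size([1, 784, 128]) since out_dim=128 is not in {320, 448, 576}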


class ConvLayer(nn.Module):
    """A stage of MBConv blocks with an optional downsample (patch merging) layer at the end."""

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        activation,
        drop_path=0.,
        downsample=None,
        use_checkpoint=False,
        out_dim=None,
        conv_expand_ratio=4.,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            MBConv(
                dim,
                dim,
                conv_expand_ratio,
                activation,
                drop_path[i] if isinstance(drop_path, list) else drop_path,
            ) for i in range(depth)])

        # patch merging layer
        self.downsample = None if downsample is None else downsample(
            input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x if self.downsample is None else self.downsample(x)


class Mlp(nn.Module):
    """MLP with pre-LayerNorm, used inside the transformer blocks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return self.drop(x)


class Attention(torch.nn.Module):
    """Multi-head attention with learned biases over relative positions (LeViT-style)."""

    def __init__(
        self,
        dim,
        key_dim,
        num_heads=8,
        attn_ratio=4,
        resolution=(14, 14),
    ):
        super().__init__()
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        points = list(itertools.product(range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N), persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        """Cache the expanded attention biases for inference and drop the cache when training resumes."""
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B, N, C)
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        self.ab = self.ab.to(self.attention_biases.device)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        return self.proj(x)
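

# NOTE: Inference-time sketch (hypothetical sizes). Calling `eval()` caches the expanded
# attention biases in `self.ab`; `forward` expects token sequences shaped (B, N, C) with
# N equal to resolution[0] * resolution[1].
# attn = Attention(dim=128, key_dim=16, num_heads=4, attn_ratio=1, resolution=(7, 7)).eval()
# out = attn(torch.randn(2, 49, 128))  # -> torch.Size([2, 49, 128])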


class TinyViTBlock(nn.Module):
    """
    TinyViT Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int, int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
        activation (torch.nn): the activation function. Default: nn.GELU
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        local_conv_size=3,
        activation=nn.GELU,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # NOTE: `DropPath` is needed only for training.
        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()

        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        head_dim = dim // num_heads

        window_resolution = (window_size, window_size)
        self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)

        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_activation = activation
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop)

        pad = local_conv_size // 2
        self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            pad_b = (self.window_size - H % self.window_size) % self.window_size
            pad_r = (self.window_size - W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0
            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(B * nH * nW, self.window_size * self.window_size, C)
            x = self.attn(x)
            # window reverse
            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
            if padding:
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)

        x = res_x + self.drop_path(x)

        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        return x + self.drop_path(self.mlp(x))

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
               f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
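

# NOTE: A minimal sketch (assumed sizes) of the windowed attention path: a 14x14 token grid with
# window_size=7 is split into four 7x7 windows, attended independently, then stitched back
# together before the depthwise `local_conv` and the MLP.
# block = TinyViTBlock(dim=128, input_resolution=(14, 14), num_heads=4, window_size=7).eval()
# out = block(torch.randn(1, 14 * 14, 128))  # -> torch.Size([1, 196, 128])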


class BasicLayer(nn.Module):
    """
    A basic TinyViT layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        drop (float, optional): Dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
        activation (torch.nn): the activation function. Default: nn.GELU
        out_dim (int | None, optional): the output dimension of the layer. Default: None
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.,
        drop=0.,
        drop_path=0.,
        downsample=None,
        use_checkpoint=False,
        local_conv_size=3,
        activation=nn.GELU,
        out_dim=None,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            TinyViTBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                local_conv_size=local_conv_size,
                activation=activation,
            ) for i in range(depth)])

        # patch merging layer
        self.downsample = None if downsample is None else downsample(
            input_resolution, dim=dim, out_dim=out_dim, activation=activation)

    def forward(self, x):
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x if self.downsample is None else self.downsample(x)

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'


class LayerNorm2d(nn.Module):
    """LayerNorm applied over the channel dimension of a (B, C, H, W) tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]
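

# NOTE: Sketch only (assumed shapes): `LayerNorm2d` matches `F.layer_norm` applied over the
# channel axis of a channels-first tensor.
# ln = LayerNorm2d(256)
# x = torch.randn(1, 256, 64, 64)
# ref = F.layer_norm(x.permute(0, 2, 3, 1), (256,), ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
# assert torch.allclose(ln(x), ref, atol=1e-5)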


class TinyViT(nn.Module):
    """TinyViT image encoder: a convolutional first stage followed by windowed-attention stages and a neck that projects the final features to 256 channels."""

    def __init__(
        self,
        img_size=224,
        in_chans=3,
        num_classes=1000,
        embed_dims=[96, 192, 384, 768],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.1,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=1.0,
    ):
        super().__init__()
        self.img_size = img_size
        self.num_classes = num_classes
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            kwargs = dict(
                dim=embed_dims[i_layer],
                input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
                                  patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
                # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                #                   patches_resolution[1] // (2 ** i_layer)),
                depth=depths[i_layer],
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                out_dim=embed_dims[min(i_layer + 1,
                                       len(embed_dims) - 1)],
                activation=activation,
            )
            if i_layer == 0:
                layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
            else:
                layer = BasicLayer(num_heads=num_heads[i_layer],
                                   window_size=window_sizes[i_layer],
                                   mlp_ratio=self.mlp_ratio,
                                   drop=drop_rate,
                                   local_conv_size=local_conv_size,
                                   **kwargs)
            self.layers.append(layer)

        # Classifier head
        self.norm_head = nn.LayerNorm(embed_dims[-1])
        self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()

        # init weights
        self.apply(self._init_weights)
        self.set_layer_lr_decay(layer_lr_decay)
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dims[-1],
                256,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(256),
            nn.Conv2d(
                256,
                256,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(256),
        )

    def set_layer_lr_decay(self, layer_lr_decay):
        """Attach a per-parameter `lr_scale` attribute implementing layer-wise learning rate decay."""
        decay_rate = layer_lr_decay

        # layers -> blocks (depth)
        depth = sum(self.depths)
        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]

        def _set_lr_scale(m, scale):
            for p in m.parameters():
                p.lr_scale = scale

        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
        i = 0
        for layer in self.layers:
            for block in layer.blocks:
                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
                i += 1
            if layer.downsample is not None:
                layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
        assert i == depth
        for m in [self.norm_head, self.head]:
            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))

        for k, p in self.named_parameters():
            p.param_name = k

        def _check_lr_scale(m):
            for p in m.parameters():
                assert hasattr(p, 'lr_scale'), p.param_name

        self.apply(_check_lr_scale)

    def _init_weights(self, m):
        """Initialize Linear biases and LayerNorm parameters; the truncated-normal weight init is only needed for training."""
        if isinstance(m, nn.Linear):
            # NOTE: This initialization is needed only for training.
            # trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        # x: (N, C, H, W)
        x = self.patch_embed(x)

        x = self.layers[0](x)
        start_i = 1

        for i in range(start_i, len(self.layers)):
            layer = self.layers[i]
            x = layer(x)
        B, _, C = x.size()
        # NOTE: the hard-coded 64x64 grid assumes a 1024x1024 input (overall downscale factor of 16).
        x = x.view(B, 64, 64, C)
        x = x.permute(0, 3, 1, 2)
        return self.neck(x)

    def forward(self, x):
        return self.forward_features(x)
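

# NOTE: A minimal, self-contained smoke-test sketch, not part of the model. The configuration
# below mirrors TinyViT-5M settings of the kind used for the MobileSAM image encoder (assumed
# here, not defined in this file); `forward_features` expects 1024x1024 inputs because of the
# hard-coded 64x64 grid above.
if __name__ == '__main__':
    model = TinyViT(
        img_size=1024,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 160, 320],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
    ).eval()
    with torch.no_grad():
        features = model(torch.randn(1, 3, 1024, 1024))
    print(features.shape)  # expected: torch.Size([1, 256, 64, 64])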