common_pruning.py

# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import torch
import torch.nn.functional as F
from torch import nn


def rows_are_subset(subset_tensor, superset_tensor) -> bool:
    """
    Checks to see if all rows in the subset tensor are present, in order,
    in the superset tensor.
    """
    i = 0
    for row in subset_tensor:
        while i < len(superset_tensor):
            if not torch.equal(row, superset_tensor[i]):
                i += 1
            else:
                break
        else:
            return False
    return True
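
# Illustrative note: because the index ``i`` above is never reset between rows,
# the subset rows must appear in the same relative order as in the superset.
# For example (sketch, not part of the test helpers themselves):
#   a = torch.eye(3)
#   rows_are_subset(a[[0, 2]], a)   # True  -- rows found in order
#   rows_are_subset(a[[2, 0]], a)   # False -- order is reversed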


class SimpleLinear(nn.Module):
    r"""Model with only Linear layers without biases, some wrapped in a Sequential,
    some following the Sequential. Used to test basic pruned Linear-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=False),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 4, bias=False),
        )
        self.linear1 = nn.Linear(4, 3, bias=False)
        self.linear2 = nn.Linear(3, 10, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class LinearBias(nn.Module):
    r"""Model with only Linear layers, alternating layers with biases,
    wrapped in a Sequential. Used to test pruned Linear-Bias-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.Linear(5, 6, bias=False),
            nn.Linear(6, 3, bias=True),
            nn.Linear(3, 3, bias=True),
            nn.Linear(3, 10, bias=False),
        )

    def forward(self, x):
        x = self.seq(x)
        return x


class LinearActivation(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules between each Linear in the Sequential, and after each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.Tanh(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.act1 = nn.ReLU()
        self.linear2 = nn.Linear(3, 10, bias=False)
        self.act2 = nn.Tanh()

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x


class LinearActivationFunctional(nn.Module):
    r"""Model with only Linear layers, some with bias, some in a Sequential and some following.
    Activation function modules between each Linear in the Sequential, and functional
    activations called between each outside layer.
    Used to test pruned Linear(Bias)-Activation-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 6, bias=False),
            nn.ReLU(),
            nn.Linear(6, 4, bias=True),
        )
        self.linear1 = nn.Linear(4, 3, bias=True)
        self.linear2 = nn.Linear(3, 8, bias=False)
        self.linear3 = nn.Linear(8, 10, bias=False)
        self.act1 = nn.ReLU()

    def forward(self, x):
        x = self.seq(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        return x


class SimpleConv2d(nn.Module):
    r"""Model with only Conv2d layers, all without bias, some in a Sequential and some following.
    Used to test pruned Conv2d-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dBias(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some outside.
    Used to test pruned Conv2d-Bias-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.Conv2d(32, 32, 3, 1, bias=True),
            nn.Conv2d(32, 64, 3, 1, bias=False),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=True)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.conv2d2(x)
        return x


class Conv2dActivation(nn.Module):
    r"""Model with only Conv2d layers, some with bias, some in a Sequential and some following.
    Activation function modules in between each Sequential layer, functional activations called
    in between each outside layer.
    Used to test pruned Conv2d-Bias-Activation-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
            nn.Conv2d(64, 64, 3, 1, bias=False),
            nn.ReLU(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, bias=False)
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.relu(x)
        x = self.conv2d2(x)
        x = F.hardtanh(x)
        return x


class Conv2dPadBias(nn.Module):
    r"""Model with only Conv2d layers, most with bias and some with padding > 0,
    some in a Sequential and some following. Activation function modules in between each layer.
    Used to test that bias is propagated correctly in the special case of
    pruned Conv2d-Bias-(Activation-)Conv2d fusion, when the second Conv2d layer has padding > 0."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1, bias=True),
            nn.Tanh(),
        )
        self.conv2d1 = nn.Conv2d(64, 48, 3, 1, padding=1, bias=True)
        self.act1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, 3, 1, padding=1, bias=True)
        self.act2 = nn.Tanh()

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.act1(x)
        x = self.conv2d2(x)
        x = self.act2(x)
        return x


class Conv2dPool(nn.Module):
    r"""Model with only Conv2d layers, all with bias, some in a Sequential and some following.
    Activation function modules and Pool2d modules in between each layer.
    Used to test pruned Conv2d-Pool2d-Conv2d fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(64, 48, kernel_size=3, padding=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(48, 52, kernel_size=3, padding=1, bias=True)
        self.conv2d3 = nn.Conv2d(52, 52, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = self.maxpool(x)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = F.relu(x)
        x = self.conv2d3(x)
        return x


class Conv2dPoolFlattenFunctional(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following,
    and then a Pool2d and a functional flatten followed by a Linear layer.
    Activation functions and Pool2d modules in between each layer as well.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(11, 13, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # test functional flatten
        x = self.fc(x)
        return x


class Conv2dPoolFlatten(nn.Module):
    r"""Model with Conv2d layers, all with bias, some in a Sequential and some following,
    and then a Pool2d and a Flatten module followed by a Linear layer.
    Activation functions and Pool2d modules in between each layer as well.
    Used to test pruned Conv2d-Pool2d-Flatten-Linear fusion."""

    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=3, padding=1, bias=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(3, 5, kernel_size=3, padding=1, bias=True),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
        )
        self.conv2d1 = nn.Conv2d(5, 7, kernel_size=3, padding=1, bias=True)
        self.af1 = nn.ReLU()
        self.conv2d2 = nn.Conv2d(7, 11, kernel_size=3, padding=1, bias=True)
        self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(44, 13, bias=True)  # 11 channels * 2 * 2 pooled output = 44 features

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d1(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2, padding=1)
        x = self.af1(x)
        x = self.conv2d2(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class LSTMLinearModel(nn.Module):
    """Container module with an LSTM followed by a Linear layer."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, input):
        output, hidden = self.lstm(input)
        decoded = self.linear(output)
        return decoded, output


class LSTMLayerNormLinearModel(nn.Module):
    """Container module with an LSTM, a LayerNorm, and a linear."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
    ):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers)
        self.norm = nn.LayerNorm(hidden_dim)
        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x, state = self.lstm(x)
        x = self.norm(x)
        x = self.linear(x)
        return x, state
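

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the original helpers).
# The input shapes below are inferred from the layer dimensions above: the
# Linear models consume 7 features, the Conv2d models take a single input
# channel, and the LSTM models use nn.LSTM's default (seq_len, batch, feature)
# layout. The specific sizes (28x28 image, seq_len=5, batch=2, hidden_dim=16)
# are arbitrary choices for this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Linear models operate on the last dimension: (batch, 7) -> (batch, 10).
    out = SimpleLinear()(torch.randn(2, 7))
    assert out.shape == (2, 10)

    # Conv/pool/flatten stack ends in a 13-way Linear head.
    out = Conv2dPoolFlatten()(torch.randn(1, 1, 28, 28))
    assert out.shape == (1, 13)

    # LSTM model returns both the decoded output and the raw LSTM output.
    decoded, lstm_out = LSTMLinearModel(
        input_dim=8, hidden_dim=16, output_dim=10, num_layers=1
    )(torch.randn(5, 2, 8))
    assert decoded.shape == (5, 2, 10)
    assert lstm_out.shape == (5, 2, 16)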