import torch
from torch.nn import Conv1d, Conv2d, Conv3d, ReLU, Linear, BatchNorm1d, BatchNorm2d, BatchNorm3d
from torch.nn.utils.parametrize import type_before_parametrizations

__all__ = ['ConvReLU1d', 'ConvReLU2d', 'ConvReLU3d', 'LinearReLU', 'ConvBn1d', 'ConvBn2d',
           'ConvBnReLU1d', 'ConvBnReLU2d', 'ConvBn3d', 'ConvBnReLU3d', 'BNReLU2d', 'BNReLU3d',
           'LinearBn1d', 'LinearLeakyReLU', 'LinearTanh', 'ConvAdd2d', 'ConvAddReLU2d']


# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    pass

class ConvReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(relu))
        super().__init__(conv, relu)

class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(relu))
        super().__init__(conv, relu)
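
# Illustrative usage (sizes are arbitrary, chosen for the example): before
# quantization swaps in a fused kernel, the container simply chains the two
# modules, so the result equals relu(conv(x)).
#
#     conv = Conv2d(3, 16, kernel_size=3)
#     relu = ReLU()
#     mod = ConvReLU2d(conv, relu)
#     out = mod(torch.randn(1, 3, 32, 32))  # same as relu(conv(input))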

class ConvReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, relu):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(relu))
        super().__init__(conv, relu)

class LinearReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, relu):
        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(linear), type_before_parametrizations(relu))
        super().__init__(linear, relu)
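
# These containers are normally produced by torch.ao.quantization.fuse_modules
# rather than built by hand; a minimal hand-built sketch (arbitrary sizes):
#
#     mod = LinearReLU(Linear(20, 10), ReLU())
#     out = mod(torch.randn(8, 20))  # same as relu(linear(input))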

class ConvBn1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn))
        super().__init__(conv, bn)

class ConvBn2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and BatchNorm2d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn))
        super().__init__(conv, bn)
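
# Sketch (arbitrary sizes): the BatchNorm2d width must match the Conv2d output
# channels; fusing the pair lets quantization fold the BN statistics into the
# convolution weights.
#
#     mod = ConvBn2d(Conv2d(3, 16, 3), BatchNorm2d(16))
#     out = mod(torch.randn(4, 3, 32, 32))  # same as bn(conv(input))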

class ConvBnReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d, BatchNorm1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv1d and type_before_parametrizations(bn) == BatchNorm1d and \
            type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn),
                type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)

class ConvBnReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d, BatchNorm2d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv2d and type_before_parametrizations(bn) == BatchNorm2d and \
            type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn),
                type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)

class ConvBn3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and BatchNorm3d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn))
        super().__init__(conv, bn)

class ConvBnReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d, BatchNorm3d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, bn, relu):
        assert type_before_parametrizations(conv) == Conv3d and type_before_parametrizations(bn) == BatchNorm3d and \
            type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {} {}'.format(
                type_before_parametrizations(conv), type_before_parametrizations(bn),
                type_before_parametrizations(relu))
        super().__init__(conv, bn, relu)

class BNReLU2d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, batch_norm, relu):
        assert type_before_parametrizations(batch_norm) == BatchNorm2d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
        super().__init__(batch_norm, relu)
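
# Sketch (arbitrary sizes): unlike the Conv*/Linear* containers above, this
# pairs a normalization directly with its activation.
#
#     mod = BNReLU2d(BatchNorm2d(8), ReLU())
#     out = mod(torch.randn(2, 8, 16, 16))  # same as relu(bn(input))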

class BNReLU3d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, batch_norm, relu):
        assert type_before_parametrizations(batch_norm) == BatchNorm3d and type_before_parametrizations(relu) == ReLU, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(batch_norm), type_before_parametrizations(relu))
        super().__init__(batch_norm, relu)

class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, bn):
        assert type_before_parametrizations(linear) == Linear and type_before_parametrizations(bn) == BatchNorm1d, \
            'Incorrect types for input modules {} {}'.format(
                type_before_parametrizations(linear), type_before_parametrizations(bn))
        super().__init__(linear, bn)
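
# Sketch (arbitrary sizes): the BatchNorm1d width must match the Linear output
# features so the statistics can later be folded into the weights.
#
#     mod = LinearBn1d(Linear(32, 64), BatchNorm1d(64))
#     out = mod(torch.randn(8, 32))  # same as bn(linear(input))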

class LinearLeakyReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, leaky_relu):
        assert type(linear) == Linear and type(leaky_relu) == torch.nn.LeakyReLU, \
            'Incorrect types for input modules {} {}'.format(
                type(linear), type(leaky_relu))
        super().__init__(linear, leaky_relu)
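
# Sketch (arbitrary sizes; 0.01 is LeakyReLU's default negative_slope):
#
#     mod = LinearLeakyReLU(Linear(16, 8), torch.nn.LeakyReLU(0.01))
#     out = mod(torch.randn(4, 16))  # same as leaky_relu(linear(input))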

class LinearTanh(_FusedModule):
    r"""This is a sequential container which calls the Linear and Tanh modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, linear, tanh):
        assert type(linear) == Linear and type(tanh) == torch.nn.Tanh, \
            'Incorrect types for input modules {} {}'.format(
                type(linear), type(tanh))
        super().__init__(linear, tanh)

class ConvAdd2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d module with an extra add.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, add):
        super().__init__(conv)
        self.add = add

    def forward(self, x1, x2):
        return self.add(self[0](x1), x2)
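
# Unlike the purely sequential containers, ConvAdd2d takes two inputs: the
# convolution is applied to the first and the second is added to its output.
# Sketch (padding=1 keeps conv(x1) the same spatial size as x2; torch.add
# stands in for whatever add op was captured during fusion):
#
#     mod = ConvAdd2d(Conv2d(3, 3, 3, padding=1), torch.add)
#     out = mod(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8))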

class ConvAddReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d, add, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module."""
    def __init__(self, conv, add, relu):
        super().__init__(conv)
        self.add = add
        self.relu = relu

    def forward(self, x1, x2):
        return self.relu(self.add(self[0](x1), x2))
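
# Minimal smoke test (illustrative only, not part of the public API): exercises
# the two-input container directly. Shapes are arbitrary; padding=1 keeps
# conv(x1) the same spatial size as x2 so the add is well defined.
if __name__ == "__main__":
    mod = ConvAddReLU2d(Conv2d(3, 3, 3, padding=1), torch.add, ReLU())
    x1 = torch.randn(1, 3, 8, 8)
    x2 = torch.randn(1, 3, 8, 8)
    out = mod(x1, x2)
    ref = torch.relu(mod[0](x1) + x2)  # equivalent eager computation
    assert torch.allclose(out, ref)
    print("ConvAddReLU2d smoke test passed:", tuple(out.shape))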