conv.py

import torch
import torch.nn as nn
from torch.nn.modules.utils import _single, _pair, _triple
from torch.ao.nn.intrinsic import _FusedModule
from typing import Tuple, TypeVar, Union
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t

__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d"
]

MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)

class _ConvNd(nn.modules.conv._ConvNd):

    _FLOAT_MODULE = MOD

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Tuple[int, ...],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool,
                 padding_mode: str,
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, bias, padding_mode, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @staticmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        )
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        assert mod.qconfig, 'Input float module must have a valid qconfig'
        if issubclass(type(mod), _FusedModule):
            mod = mod[0]  # type: ignore[index]
        qconfig = mod.qconfig
        qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                       stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
                       groups=mod.groups, bias=mod.bias is not None,
                       padding_mode=mod.padding_mode, qconfig=qconfig)
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert the qat module to a floating point module.

        This works for both a single qat conv and a fused qat conv-relu module.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined, operator]
            self.in_channels,
            self.out_channels,
            self.kernel_size,  # type: ignore[arg-type]
            self.stride,  # type: ignore[arg-type]
            self.padding,  # type: ignore[arg-type]
            self.dilation,  # type: ignore[arg-type]
            self.groups,
            self.bias is not None,
            self.padding_mode)
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        # conv relu
        if issubclass(cls, _FusedModule):
            modules = [conv]
            assert hasattr(cls, "_FLOAT_RELU_MODULE")
            relu = cls._FLOAT_RELU_MODULE()  # type: ignore[attr-defined]
            modules.append(relu)
            fused = cls._FLOAT_MODULE(*modules)  # type: ignore[arg-type, attr-defined, operator]
            fused.train(self.training)
            return fused
        else:
            return conv
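
# Illustrative sketch, not part of the upstream module: these qat classes are
# normally produced indirectly by the eager-mode QAT prepare step, which swaps
# float convs for their qat counterparts through from_float. The toy model and
# the "fbgemm" backend string below are assumptions made for the example only.
def _example_prepare_qat_workflow():
    from torch.ao.quantization import get_default_qat_qconfig, prepare_qat

    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
    model.train()  # prepare_qat expects a model in training mode
    model.qconfig = get_default_qat_qconfig("fbgemm")
    # nn.Conv2d is replaced by the qat Conv2d defined below (via from_float).
    return prepare_qat(model)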

class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv1d
    _FLOAT_CONV_MODULE = nn.Conv1d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: Union[str, _size_1_t] = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
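
# Illustrative sketch, not part of the upstream module: direct conversion of a
# float nn.Conv1d into the qat Conv1d via from_float. The qconfig choice and
# tensor shapes are assumptions made for the example only.
def _example_conv1d_from_float():
    from torch.ao.quantization import get_default_qat_qconfig

    float_conv = nn.Conv1d(4, 8, kernel_size=3)
    float_conv.qconfig = get_default_qat_qconfig("fbgemm")
    qat_conv = Conv1d.from_float(float_conv)
    # The weight is fake-quantized on every forward pass during training.
    return qat_conv(torch.randn(2, 4, 16))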

class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv2d
    _FLOAT_CONV_MODULE = nn.Conv2d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
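
# Illustrative sketch, not part of the upstream module: a float -> qat -> float
# round trip for Conv2d, exercising both from_float and to_float. The qconfig
# choice and tensor shapes are assumptions made for the example only.
def _example_conv2d_round_trip():
    from torch.ao.quantization import get_default_qat_qconfig

    float_conv = nn.Conv2d(3, 16, kernel_size=3)
    float_conv.qconfig = get_default_qat_qconfig("fbgemm")
    qat_conv = Conv2d.from_float(float_conv)
    out = qat_conv(torch.randn(1, 3, 8, 8))  # convolution with fake-quantized weight
    restored = qat_conv.to_float()           # back to a plain nn.Conv2d with detached params
    return out, restored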

class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Conv3d
    _FLOAT_CONV_MODULE = nn.Conv3d

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_3_t,
                 stride: _size_3_t = 1,
                 padding: Union[str, _size_3_t] = 0,
                 dilation: _size_3_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype)

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        return super().from_float(cls, mod)
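
# Illustrative sketch, not part of the upstream module: the qat Conv3d can also
# be constructed directly when a qconfig is at hand, e.g. when building a model
# for QAT from scratch. The qconfig choice and shapes are assumptions made for
# the example only.
def _example_conv3d_direct():
    from torch.ao.quantization import get_default_qat_qconfig

    qconfig = get_default_qat_qconfig("fbgemm")
    qat_conv = Conv3d(2, 4, kernel_size=3, padding=1, qconfig=qconfig)
    return qat_conv(torch.randn(1, 2, 8, 8, 8))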