conv.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # coding=utf-8
  2. r"""Dynamically quantized convolution modules."""
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from torch import Tensor
  7. from torch._ops import ops
  8. from torch.nn.common_types import _size_1_t
  9. from torch.nn.modules.utils import _single, _pair, _triple
  10. from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
  11. import torch.ao.nn.quantized as nnq
  12. import warnings
  13. __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
  14. class Conv1d(nnq.Conv1d):
  15. r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
  16. For details on input arguments, parameters, and implementation see
  17. :class:`~torch.nn.Conv1d` and :class:`~torch.ao.nn.quantized.dynamic.Conv1d` and
  18. Attributes:
  19. weight (Tensor): packed tensor derived from the learnable weight
  20. parameter.
  21. scale (Tensor): scalar for the output scale
  22. zero_point (Tensor): scalar for the output zero point
  23. See :class:`~torch.nn.Conv1d` for other attributes.
  24. Examples::
  25. >>> # xdoctest: +SKIP
  26. >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
  27. >>> input = torch.randn(20, 16, 100)
  28. >>> output = m(input)
  29. """
  30. _FLOAT_MODULE = nn.Conv1d
  31. _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
  32. _NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
  33. def __init__(self,
  34. in_channels: int,
  35. out_channels: int,
  36. kernel_size: _size_1_t,
  37. stride: _size_1_t = 1,
  38. padding: _size_1_t = 0,
  39. dilation: _size_1_t = 1,
  40. groups: int = 1,
  41. bias: bool = True,
  42. padding_mode: str = 'zeros',
  43. device=None,
  44. dtype=None,
  45. reduce_range=True):
  46. warnings.warn(
  47. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  48. self._get_name()
  49. )
  50. )
  51. factory_kwargs = {'device': device, 'dtype': dtype}
  52. kernel_size = _single(kernel_size)
  53. stride = _single(stride)
  54. padding = padding if isinstance(padding, str) else _single(padding)
  55. dilation = _single(dilation)
  56. super().__init__(
  57. in_channels, out_channels, kernel_size, stride, padding, dilation,
  58. groups, bias, padding_mode, **factory_kwargs)
  59. def _get_name(self):
  60. return 'DynamicQuantizedConv1d'
  61. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  62. # Temporarily using len(shape) instead of ndim due to JIT issue
  63. # https://github.com/pytorch/pytorch/issues/23890
  64. if len(input.shape) != 3:
  65. raise ValueError("Input shape must be `(N, C, L)`!")
  66. if self.padding_mode != 'zeros':
  67. # Padding in Conv1d is stored as (p, p), need to get (p,)
  68. _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
  69. input = F.pad(input, _reversed_padding_repeated_twice,
  70. mode=self.padding_mode)
  71. return ops.quantized.conv1d_dynamic(input, self._packed_params, reduce_range)
  72. class Conv2d(nnq.Conv2d):
  73. r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
  74. For details on input arguments, parameters, and implementation see
  75. :class:`~torch.nn.Conv2d` and :class:`~torch.ao.nn.quantized.dynamic.Conv2d` and
  76. Attributes:
  77. weight (Tensor): packed tensor derived from the learnable weight
  78. parameter.
  79. scale (Tensor): scalar for the output scale
  80. zero_point (Tensor): scalar for the output zero point
  81. See :class:`~torch.nn.Conv2d` for other attributes.
  82. Examples::
  83. >>> # xdoctest: +SKIP
  84. >>> # With square kernels and equal stride
  85. >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
  86. >>> # non-square kernels and unequal stride and with padding
  87. >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
  88. >>> # non-square kernels and unequal stride and with padding and dilation
  89. >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
  90. >>> input = torch.randn(20, 16, 50, 100)
  91. >>> output = m(input)
  92. """
  93. _FLOAT_MODULE = nn.Conv2d
  94. _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
  95. _NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
  96. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  97. padding=0, dilation=1, groups=1, bias=True,
  98. padding_mode='zeros', device=None, dtype=None):
  99. warnings.warn(
  100. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  101. self._get_name()
  102. )
  103. )
  104. factory_kwargs = {'device': device, 'dtype': dtype}
  105. kernel_size = _pair(kernel_size)
  106. stride = _pair(stride)
  107. padding = _pair(padding)
  108. dilation = _pair(dilation)
  109. super().__init__(
  110. in_channels, out_channels, kernel_size, stride, padding, dilation,
  111. groups, bias, padding_mode, **factory_kwargs)
  112. def _get_name(self):
  113. return 'DynamicQuantizedConv2d'
  114. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  115. # Temporarily using len(shape) instead of ndim due to JIT issue
  116. # https://github.com/pytorch/pytorch/issues/23890
  117. if len(input.shape) != 4:
  118. raise ValueError("Input shape must be `(N, C, H, W)`!")
  119. if self.padding_mode != 'zeros':
  120. _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
  121. input = F.pad(input, _reversed_padding_repeated_twice,
  122. mode=self.padding_mode)
  123. return ops.quantized.conv2d_dynamic(
  124. input, self._packed_params, reduce_range)
  125. class Conv3d(nnq.Conv3d):
  126. r"""A dynamically quantized conv module with floating point tensors as inputs and outputs.
  127. For details on input arguments, parameters, and implementation see
  128. :class:`~torch.nn.Conv3d` and :class:`~torch.ao.nn.quantized.dynamic.Conv3d` and
  129. Attributes:
  130. weight (Tensor): packed tensor derived from the learnable weight
  131. parameter.
  132. scale (Tensor): scalar for the output scale
  133. zero_point (Tensor): scalar for the output zero point
  134. See :class:`~torch.nn.Conv3d` for other attributes.
  135. Examples::
  136. >>> # xdoctest: +SKIP
  137. >>> # With square kernels and equal stride
  138. >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
  139. >>> # non-square kernels and unequal stride and with padding
  140. >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
  141. >>> # non-square kernels and unequal stride and with padding and dilation
  142. >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
  143. >>> input = torch.randn(20, 16, 56, 56, 56)
  144. >>> output = m(input)
  145. """
  146. _FLOAT_MODULE = nn.Conv3d
  147. _NNIQAT_CONV_BN_MODULE = None # type: ignore[assignment]
  148. _NNI_CONV_RELU_MODULE = None # type: ignore[assignment]
  149. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  150. padding=0, dilation=1, groups=1, bias=True,
  151. padding_mode='zeros', device=None, dtype=None):
  152. warnings.warn(
  153. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  154. self._get_name()
  155. )
  156. )
  157. assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
  158. factory_kwargs = {'device': device, 'dtype': dtype}
  159. kernel_size = _triple(kernel_size)
  160. stride = _triple(stride)
  161. padding = _triple(padding)
  162. dilation = _triple(dilation)
  163. super()._init(
  164. in_channels, out_channels, kernel_size, stride, padding, dilation,
  165. False, _triple(0), groups, bias, padding_mode, **factory_kwargs)
  166. def _get_name(self):
  167. return 'DynamicQuantizedConv3d'
  168. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  169. # Temporarily using len(shape) instead of ndim due to JIT issue
  170. # https://github.com/pytorch/pytorch/issues/23890
  171. if len(input.shape) != 5:
  172. raise ValueError("Input shape must be `(N, C, D, H, W)`!")
  173. if self.padding_mode != 'zeros':
  174. _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
  175. input = F.pad(input, _reversed_padding_repeated_twice,
  176. mode=self.padding_mode)
  177. return ops.quantized.conv3d_dynamic(
  178. input, self._packed_params, reduce_range)
  179. class ConvTranspose1d(nnq.ConvTranspose1d):
  180. r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
  181. For details on input arguments, parameters, and implementation see
  182. :class:`~torch.nn.ConvTranspose1d`.
  183. For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv1d`
  184. Attributes:
  185. weight (Tensor): packed tensor derived from the learnable weight
  186. parameter.
  187. scale (Tensor): scalar for the output scale
  188. zero_point (Tensor): scalar for the output zero point
  189. See :class:`~torch.nn.ConvTranspose1d` for other attributes.
  190. Examples::
  191. >>> # xdoctest: +SKIP
  192. >>> # With square kernels and equal stride
  193. >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
  194. >>> # non-square kernels and unequal stride and with padding
  195. >>> m = nndq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
  196. >>> output = m(input)
  197. >>> # exact output size can be also specified as an argument
  198. >>> downsample = nndq.Conv1d(16, 16, 3, stride=2, padding=1)
  199. >>> upsample = nndq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
  200. >>> h = downsample(input)
  201. >>> h.size()
  202. torch.Size([1, 16, 6])
  203. >>> output = upsample(h, output_size=input.size())
  204. >>> output.size()
  205. torch.Size([1, 16, 12])
  206. """
  207. _FLOAT_MODULE = nn.ConvTranspose1d
  208. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  209. padding=0, output_padding=0, groups=1, bias=True,
  210. dilation=1, padding_mode='zeros', device=None, dtype=None):
  211. warnings.warn(
  212. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  213. self._get_name()
  214. )
  215. )
  216. factory_kwargs = {'device': device, 'dtype': dtype}
  217. super().__init__(
  218. in_channels, out_channels, kernel_size, stride, padding, output_padding,
  219. groups, bias, dilation, padding_mode, **factory_kwargs)
  220. def _get_name(self):
  221. return 'DynamicQuantizedConvTranpose1d'
  222. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  223. # Temporarily using len(shape) instead of ndim due to JIT issue
  224. # https://github.com/pytorch/pytorch/issues/23890
  225. if len(input.shape) != 3:
  226. raise ValueError("Input shape must be `(N, C, L)`!")
  227. return torch.ops.quantized.conv_transpose1d_dynamic(
  228. input, self._packed_params, reduce_range)
  229. class ConvTranspose2d(nnq.ConvTranspose2d):
  230. r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
  231. For details on input arguments, parameters, and implementation see
  232. :class:`~torch.nn.ConvTranspose2d`.
  233. For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv2d`
  234. Attributes:
  235. weight (Tensor): packed tensor derived from the learnable weight
  236. parameter.
  237. scale (Tensor): scalar for the output scale
  238. zero_point (Tensor): scalar for the output zero point
  239. See :class:`~torch.nn.ConvTranspose2d` for other attributes.
  240. Examples::
  241. >>> # xdoctest: +SKIP
  242. >>> # With square kernels and equal stride
  243. >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
  244. >>> # non-square kernels and unequal stride and with padding
  245. >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
  246. >>> output = m(input)
  247. >>> # exact output size can be also specified as an argument
  248. >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
  249. >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
  250. >>> h = downsample(input)
  251. >>> h.size()
  252. torch.Size([1, 16, 6, 6])
  253. >>> output = upsample(h, output_size=input.size())
  254. >>> output.size()
  255. torch.Size([1, 16, 12, 12])
  256. """
  257. _FLOAT_MODULE = nn.ConvTranspose2d
  258. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  259. padding=0, output_padding=0, groups=1, bias=True,
  260. dilation=1, padding_mode='zeros', device=None, dtype=None):
  261. warnings.warn(
  262. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  263. self._get_name()
  264. )
  265. )
  266. factory_kwargs = {'device': device, 'dtype': dtype}
  267. super().__init__(
  268. in_channels, out_channels, kernel_size, stride, padding, output_padding,
  269. groups, bias, dilation, padding_mode, **factory_kwargs)
  270. def _get_name(self):
  271. return 'DynamicQuantizedConvTranpose2d'
  272. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  273. # Temporarily using len(shape) instead of ndim due to JIT issue
  274. # https://github.com/pytorch/pytorch/issues/23890
  275. if len(input.shape) != 4:
  276. raise ValueError("Input shape must be `(N, C, H, W)`!")
  277. return ops.quantized.conv_transpose2d_dynamic(
  278. input, self._packed_params, reduce_range)
  279. class ConvTranspose3d(nnq.ConvTranspose3d):
  280. r"""A dynamically quantized transposed convolution module with floating point tensors as inputs and outputs.
  281. For details on input arguments, parameters, and implementation see
  282. :class:`~torch.nn.ConvTranspose3d`.
  283. For special notes, please, see :class:`~torch.ao.nn.quantized.dynamic.Conv3d`
  284. Attributes:
  285. weight (Tensor): packed tensor derived from the learnable weight
  286. parameter.
  287. scale (Tensor): scalar for the output scale
  288. zero_point (Tensor): scalar for the output zero point
  289. See :class:`~torch.nn.ConvTranspose3d` for other attributes.
  290. Examples::
  291. >>> # xdoctest: +SKIP
  292. >>> # With cubic kernels and equal stride
  293. >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
  294. >>> # non-cubic kernels and unequal stride and with padding
  295. >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
  296. >>> output = m(input)
  297. >>> # exact output size can be also specified as an argument
  298. >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
  299. >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
  300. >>> h = downsample(input)
  301. >>> h.size()
  302. torch.Size([1, 16, 6, 6, 6])
  303. >>> output = upsample(h, output_size=input.size())
  304. >>> output.size()
  305. torch.Size([1, 16, 12, 12, 12])
  306. """
  307. _FLOAT_MODULE = nn.ConvTranspose3d
  308. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  309. padding=0, output_padding=0, groups=1, bias=True,
  310. dilation=1, padding_mode='zeros', device=None, dtype=None):
  311. warnings.warn(
  312. "The current implementation of the {} module has poor numerical accuracy and its use is not recommended".format(
  313. self._get_name()
  314. )
  315. )
  316. factory_kwargs = {'device': device, 'dtype': dtype}
  317. super().__init__(
  318. in_channels, out_channels, kernel_size, stride, padding, output_padding,
  319. groups, bias, dilation, padding_mode, **factory_kwargs)
  320. def _get_name(self):
  321. return 'DynamicQuantizedConvTranpose3d'
  322. def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
  323. # Temporarily using len(shape) instead of ndim due to JIT issue
  324. # https://github.com/pytorch/pytorch/issues/23890
  325. if len(input.shape) != 5:
  326. raise ValueError("Input shape must be `(N, C, T, H, W)`!")
  327. return ops.quantized.conv_transpose3d_dynamic(
  328. input, self._packed_params, reduce_range)