# coding=utf-8
r"""Quantized convolution modules."""

from typing import Optional, List, TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat

from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils import fuse_conv_bn_weights

from .utils import _quantize_weight, WeightedQuantizedModule

__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']

_SUPPORTED_PADDING = {
    'zeros',
    'reflect'
}


def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    _reversed_padding_repeated_twice: List[int] = []
    N = len(padding)
    for idx in range(N):
        for _ in range(2):
            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
    return _reversed_padding_repeated_twice
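
# Hedged illustration (not part of the original file): _reverse_repeat_padding
# converts a per-dimension padding tuple into the flat, last-dimension-first
# layout that F.pad expects, repeating each entry twice.
#     _reverse_repeat_padding([1, 2])  # -> [2, 2, 1, 1]
#     _reverse_repeat_padding([3])     # -> [3, 3]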


class _ConvNd(WeightedQuantizedModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        # All subclasses have this signature - See PR #49702
        raise NotImplementedError
    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation,
              transposed, output_padding,
              groups, bias,
              padding_mode='zeros',
              device=None,
              dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if padding_mode not in _SUPPORTED_PADDING:
            raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode))
        self.padding_mode = padding_mode
        # Initialize as NCHW. set_weight will internally transpose to NHWC.
        if self.transposed:
            weight_shape = [in_channels, out_channels // self.groups]
        else:
            weight_shape = [out_channels, in_channels // self.groups]
        qweight = torch._empty_affine_quantized(
            weight_shape + list(kernel_size),
            scale=1, zero_point=0, dtype=torch.qint8,
            **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
        bias_float = (
            torch.zeros(out_channels, dtype=torch.float,
                        **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}) if bias else None)
        self.set_weight_bias(qweight, bias_float)
        self.scale = 1.0
        self.zero_point = 0
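
    # Hedged weight-shape example (values are illustrative, not from the original
    # file): a non-transposed Conv2d with in_channels=8, out_channels=16,
    # kernel_size=(3, 3), groups=2 gets a placeholder qweight of shape
    # [16, 8 // 2, 3, 3] == [16, 4, 3, 3]; a transposed variant would instead use
    # [in_channels, out_channels // groups, *kernel_size] == [8, 8, 3, 3].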
    def set_weight_bias(self, qweight, bias_float):
        raise NotImplementedError

    def bias(self):
        raise NotImplementedError

    def _weight_bias(self):
        raise NotImplementedError

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, scale={scale}, zero_point={zero_point}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias() is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # The special consideration here is that we have to unpack the weights into
    # their regular QTensor form for serialization. Packed weights should not
    # live outside the process in which they were created, rather they should be
    # derived from the QTensor weight.
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #
    # TODO: maybe change to this when https://github.com/pytorch/pytorch/pull/32958 is landed
    #   self
    #   |--- _packed_params : Conv2dPackedParamsBase or Conv3dPackedParamsBase
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + 'weight'] = w
        destination[prefix + 'bias'] = b
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)

    @torch.jit.export
    def __getstate__(self):
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training
        )
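
    # Hedged round-trip sketch (illustrative, not part of the original file):
    # because __getstate__ unpacks the weight into a plain QTensor, copies rebuild
    # their own packed params instead of aliasing the original packed object.
    #     import copy
    #     m2 = copy.deepcopy(m)  # __deepcopy__ below -> __getstate__ / __setstate__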
    # ===== Deserialization methods =====
    # Counterpart to the serialization methods, we must pack the serialized
    # QTensor weight into its packed format for use by the FBGEMM ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.set_weight_bias(
            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
        state_dict.pop(prefix + 'weight')
        state_dict.pop(prefix + 'bias')
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')
        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys,
            unexpected_keys, error_msgs)

    @torch.jit.export
    def __setstate__(self, state):
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        torch.nn.Module.__init__(new_instance)
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def __copy__(self):
        return self.__deepcopy__({})
    @classmethod
    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
        r"""Creates a qconv object and returns it.
        """
        if weight_post_process is None:
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if activation_post_process is None or activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_float(cls, mod):
        if hasattr(mod, "weight_fake_quant"):
            # assert type(mod) == cls.__QAT_MODULE, " nnq." + cls.__name__ + \
            #     ".from_float only works for " + cls.__QAT_MODULE.__name__
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, "activation_post_process"), \
                "Input QAT module must have observer attached"
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                " nnq." + cls.__name__ + ".from_float only works for " + \
                cls._FLOAT_MODULE.__name__ + " but got:" + str(type(mod))
            assert hasattr(mod, "qconfig"), \
                "Input float module must have qconfig defined."
            activation_post_process = None if not hasattr(
                mod, "activation_post_process") else mod.activation_post_process
            if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)
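
    # Hedged usage sketch (the model, qconfig name and tensor shapes below are
    # illustrative, not part of this file): in eager-mode post-training
    # quantization, torch.ao.quantization.convert ends up calling from_float on
    # each observed float conv module.
    #     float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))
    #     float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
    #     prepared = torch.ao.quantization.prepare(float_model)
    #     prepared(torch.randn(1, 3, 32, 32))                  # calibrate observers
    #     quantized = torch.ao.quantization.convert(prepared)  # swaps in quantized Conv2d via from_float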
    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module

        Args:
            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconv.in_channels,
            ref_qconv.out_channels,
            ref_qconv.kernel_size,  # type: ignore[arg-type]
            ref_qconv.stride,  # type: ignore[arg-type]
            ref_qconv.padding,  # type: ignore[arg-type]
            ref_qconv.dilation,  # type: ignore[arg-type]
            ref_qconv.groups,
            ref_qconv.bias is not None,  # type: ignore[arg-type]
            ref_qconv.padding_mode,
            device=ref_qconv.weight.device,
            dtype=ref_qconv.weight.dtype)
        qweight = ref_qconv.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconv.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv
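
# Hedged sketch (illustrative names, not part of this file): a "reference"
# quantized conv, e.g. one produced by
# torch.ao.quantization.quantize_fx.convert_to_reference_fx, stores a quantized
# weight but still computes in fp32; from_reference repacks that weight for the
# fbgemm/qnnpack kernels and attaches the requested output qparams:
#     backend_conv = Conv2d.from_reference(ref_conv, output_scale=0.1, output_zero_point=0)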


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    """
    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
    _NNI_CONV_ADD_MODULE = None
    _NNI_CONV_ADD_RELU_MODULE = None

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = padding if isinstance(padding, str) else _single(padding)
        dilation = _single(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, _pair(0), self.dilation,
                self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Padding in Conv1d is stored as (p, p), need to get (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)


class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """
    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv2d(
            input, self._packed_params, self.scale, self.zero_point)
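
    # Hedged sketch of the non-'zeros' padding path above (shapes and qparams are
    # illustrative): with padding_mode='reflect' the weights were prepacked with
    # zero padding, so forward() first reflect-pads the input explicitly and then
    # runs the unpadded quantized kernel.
    #     m = Conv2d(4, 4, 3, padding=1, padding_mode='reflect')
    #     qx = torch.quantize_per_tensor(torch.randn(1, 4, 8, 8), 0.1, 0, torch.quint8)
    #     m(qx).shape  # -> torch.Size([1, 4, 8, 8])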
    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)


class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """
    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
    _NNI_CONV_ADD_MODULE = None
    _NNI_CONV_ADD_RELU_MODULE = None

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, _triple(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod)

# === Transposed Convolutions ===
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvTransposeNd(_ConvNd):

    _FLOAT_MODULE = MOD

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, device=None, dtype=None):
        if padding_mode != 'zeros':
            raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, transposed, output_padding,
            groups, bias, padding_mode, **factory_kwargs)

    def _input_padding(self, kernel_size: List[int], dilation: List[int], padding: List[int]) -> List[int]:
        res = torch.jit.annotate(List[int], [])
        for kdx in range(len(kernel_size)):
            pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
            res.append(pad)
        return res
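
    # Hedged arithmetic example for _input_padding (values are illustrative): for
    # kernel_size=[3], dilation=[2], padding=[1] the equivalent input padding of
    # the transposed convolution is 2 * (3 - 1) - 1 = 3, so the method returns [3].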
    @classmethod
    def from_float(cls, mod):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + \
            cls._FLOAT_MODULE.__name__  # type: ignore[attr-defined]
        assert type(mod) == cls._FLOAT_MODULE, msg
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,  # type: ignore[call-arg]
                    mod.stride, mod.padding, mod.output_padding, mod.groups,
                    mod.bias is not None, mod.dilation, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if not hasattr(mod, "activation_post_process") or mod.activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv
    @staticmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module

        Args:
            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconvt.in_channels,
            ref_qconvt.out_channels,
            ref_qconvt.kernel_size,  # type: ignore[arg-type]
            ref_qconvt.stride,  # type: ignore[arg-type]
            ref_qconvt.padding,  # type: ignore[arg-type]
            ref_qconvt.output_padding,  # type: ignore[arg-type]
            ref_qconvt.groups,
            ref_qconvt.bias is not None,  # type: ignore[arg-type]
            ref_qconvt.dilation,  # type: ignore[arg-type]
            ref_qconvt.padding_mode,
            device=ref_qconvt.weight.device,
            dtype=ref_qconvt.weight.dtype)
        qweight = ref_qconvt.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconvt.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class ConvTranspose1d(_ConvTransposeNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'qnnpack'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-default stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2, padding=4)
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """
    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)


class ConvTranspose2d(_ConvTransposeNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """
    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)


class ConvTranspose3d(_ConvTransposeNd):
    r"""Applies a 3D transposed convolution operator over an input image
    composed of several input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'fbgemm'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """
    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
            w, b, self.stride, self.padding, self.output_padding, self.dilation,
            self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale, output_zero_point)
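

# ---------------------------------------------------------------------------
# Hedged, self-contained smoke test (not part of the original module; the
# shapes, qparams and engine choice below are illustrative assumptions). It
# only runs when the module is executed with `python -m` (the relative import
# at the top prevents running the file path directly), so importing the module
# is unchanged.
if __name__ == "__main__":
    supported = torch.backends.quantized.supported_engines
    engine = next((e for e in ('fbgemm', 'qnnpack') if e in supported), None)
    if engine is not None:
        torch.backends.quantized.engine = engine
        m = Conv2d(3, 8, 3, stride=1, padding=1)  # quantized Conv2d defined above
        x = torch.randn(1, 3, 16, 16)
        qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)
        qy = m(qx)  # quantized output tensor (quint8)
        print(qy.shape, qy.dtype)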