conv.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from typing import Optional, Dict, Any, List
  5. from torch.nn.common_types import _size_1_t
  6. from .utils import ReferenceQuantizedModule
  7. __all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']
  8. class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
  9. """ A reference version of nn.quantized.Conv2d
  10. we will not pack the parameters in this module, since weight packing is an
  11. optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
  12. this is useful when user want to use this module in other backends like Glow.
  13. """
  14. __annotations__ = {"bias": Optional[torch.Tensor]}
  15. _IS_REFERENCE = True
  16. @staticmethod
  17. def from_float(cls, float_conv, weight_qparams):
  18. qref_conv = cls(
  19. float_conv.in_channels,
  20. float_conv.out_channels,
  21. float_conv.kernel_size, # type: ignore[arg-type]
  22. float_conv.stride, # type: ignore[arg-type]
  23. float_conv.padding, # type: ignore[arg-type]
  24. float_conv.dilation, # type: ignore[arg-type]
  25. float_conv.groups,
  26. float_conv.bias is not None, # type: ignore[arg-type]
  27. float_conv.padding_mode,
  28. device=float_conv.weight.device,
  29. dtype=float_conv.weight.dtype,
  30. weight_qparams=weight_qparams)
  31. qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
  32. if float_conv.bias is not None:
  33. qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
  34. return qref_conv
  35. class Conv1d(_ConvNd, nn.Conv1d):
  36. def __init__(self,
  37. in_channels: int,
  38. out_channels: int,
  39. kernel_size: _size_1_t,
  40. stride: _size_1_t = 1,
  41. padding: _size_1_t = 0,
  42. dilation: _size_1_t = 1,
  43. groups: int = 1,
  44. bias: bool = True,
  45. padding_mode: str = "zeros",
  46. device=None,
  47. dtype=None,
  48. weight_qparams: Optional[Dict[str, Any]] = None):
  49. nn.Conv1d.__init__(
  50. self, in_channels, out_channels, kernel_size, stride, padding, dilation,
  51. groups, bias, padding_mode, device, dtype)
  52. self._init_weight_qparams(weight_qparams, device)
  53. def forward(self, x: torch.Tensor) -> torch.Tensor:
  54. """
  55. we have:
  56. w(float) -- quant - dequant \
  57. x(float) ------------- F.conv1d ---
  58. In the full model, we will see
  59. w(float) -- quant - *dequant \
  60. x -- quant --- *dequant -- *F.conv1d --- *quant - dequant
  61. and the backend should be able to fuse the ops with `*` into a quantized conv1d
  62. """
  63. weight_quant_dequant = self.get_weight()
  64. result = F.conv1d(
  65. x, weight_quant_dequant, self.bias, self.stride,
  66. self.padding, self.dilation, self.groups)
  67. return result
  68. def _get_name(self):
  69. return "QuantizedConv1d(Reference)"
  70. @classmethod
  71. def from_float(cls, float_conv, weight_qparams):
  72. return _ConvNd.from_float(cls, float_conv, weight_qparams)
  73. class Conv2d(_ConvNd, nn.Conv2d):
  74. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  75. padding=0, dilation=1, groups=1, bias=True,
  76. padding_mode='zeros',
  77. device=None,
  78. dtype=None,
  79. weight_qparams: Optional[Dict[str, Any]] = None):
  80. nn.Conv2d.__init__(
  81. self, in_channels, out_channels, kernel_size, stride, padding, dilation,
  82. groups, bias, padding_mode, device, dtype)
  83. self._init_weight_qparams(weight_qparams, device)
  84. def forward(self, x: torch.Tensor) -> torch.Tensor:
  85. """
  86. we have:
  87. w(float) -- quant - dequant \
  88. x(float) ------------- F.conv2d ---
  89. In the full model, we will see
  90. w(float) -- quant - *dequant \
  91. x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
  92. and the backend should be able to fuse the ops with `*` into a quantized conv2d
  93. """
  94. weight_quant_dequant = self.get_weight()
  95. result = F.conv2d(
  96. x, weight_quant_dequant, self.bias, self.stride,
  97. self.padding, self.dilation, self.groups)
  98. return result
  99. def _get_name(self):
  100. return "QuantizedConv2d(Reference)"
  101. @classmethod
  102. def from_float(cls, float_conv, weight_qparams):
  103. return _ConvNd.from_float(cls, float_conv, weight_qparams)
  104. class Conv3d(_ConvNd, nn.Conv3d):
  105. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  106. padding=0, dilation=1, groups=1, bias=True,
  107. padding_mode="zeros",
  108. device=None,
  109. dtype=None,
  110. weight_qparams: Optional[Dict[str, Any]] = None):
  111. nn.Conv3d.__init__(
  112. self, in_channels, out_channels, kernel_size, stride, padding, dilation,
  113. groups, bias, padding_mode, device, dtype)
  114. self._init_weight_qparams(weight_qparams, device)
  115. def forward(self, x: torch.Tensor) -> torch.Tensor:
  116. """
  117. we have:
  118. w(float) -- quant - dequant \
  119. x(float) ------------- F.conv3d ---
  120. In the full model, we will see
  121. w(float) -- quant - *dequant \
  122. x -- quant --- *dequant -- *F.conv3d --- *quant - dequant
  123. and the backend should be able to fuse the ops with `*` into a quantized conv3d
  124. """
  125. weight_quant_dequant = self.get_weight()
  126. result = F.conv3d(
  127. x, weight_quant_dequant, self.bias, self.stride,
  128. self.padding, self.dilation, self.groups)
  129. return result
  130. def _get_name(self):
  131. return "QuantizedConv3d(Reference)"
  132. @classmethod
  133. def from_float(cls, float_conv, weight_qparams):
  134. return _ConvNd.from_float(cls, float_conv, weight_qparams)
  135. class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
  136. """ A reference version of nn.quantized.ConvTranspose2d
  137. we will not pack the parameters in this module, since weight packing is an
  138. optimization for quantized backends supported in PyTorch (fbgemm/qnnpack),
  139. this is useful when user want to use this module in other backends like Glow.
  140. """
  141. @staticmethod
  142. def from_float(cls, float_conv, weight_qparams):
  143. qref_conv = cls(
  144. float_conv.in_channels,
  145. float_conv.out_channels,
  146. float_conv.kernel_size, # type: ignore[arg-type]
  147. float_conv.stride, # type: ignore[arg-type]
  148. float_conv.padding, # type: ignore[arg-type]
  149. float_conv.output_padding, # type: ignore[arg-type]
  150. float_conv.groups,
  151. float_conv.bias is not None, # type: ignore[arg-type]
  152. float_conv.dilation, # type: ignore[arg-type]
  153. float_conv.padding_mode,
  154. device=float_conv.weight.device,
  155. dtype=float_conv.weight.dtype,
  156. weight_qparams=weight_qparams)
  157. qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
  158. if float_conv.bias is not None:
  159. qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
  160. return qref_conv
  161. class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
  162. def __init__(self,
  163. in_channels: int,
  164. out_channels: int,
  165. kernel_size: _size_1_t,
  166. stride: _size_1_t = 1,
  167. padding: _size_1_t = 0,
  168. output_padding: _size_1_t = 0,
  169. groups: int = 1,
  170. bias: bool = True,
  171. dilation: _size_1_t = 1,
  172. padding_mode: str = "zeros",
  173. device=None,
  174. dtype=None,
  175. weight_qparams: Optional[Dict[str, Any]] = None):
  176. nn.ConvTranspose1d.__init__(
  177. self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
  178. groups, bias, dilation, padding_mode, device, dtype)
  179. self._init_weight_qparams(weight_qparams, device)
  180. def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
  181. """
  182. we have:
  183. w(float) -- quant - dequant \
  184. x(float) ------------- F.convTranspose1d ---
  185. In the full model, we will see
  186. w(float) -- quant - *dequant \
  187. x -- quant --- *dequant -- *F.convTranspose1d --- *quant - dequant
  188. and the backend should be able to fuse the ops with `*` into a quantized conv1d
  189. """
  190. assert isinstance(self.padding, tuple)
  191. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  192. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  193. output_padding = self._output_padding(
  194. input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
  195. weight_quant_dequant = self.get_weight()
  196. result = F.conv_transpose1d(
  197. x, weight_quant_dequant, self.bias, self.stride,
  198. self.padding, output_padding, self.groups, self.dilation)
  199. return result
  200. def _get_name(self):
  201. return "QuantizedConvTranspose1d(Reference)"
  202. @classmethod
  203. def from_float(cls, float_conv, weight_qparams):
  204. return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
  205. class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
  206. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  207. padding=0, output_padding=0,
  208. groups=1, bias=True, dilation=1,
  209. padding_mode='zeros',
  210. device=None,
  211. dtype=None,
  212. weight_qparams: Optional[Dict[str, Any]] = None):
  213. nn.ConvTranspose2d.__init__(
  214. self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
  215. groups, bias, dilation, padding_mode, device, dtype)
  216. self._init_weight_qparams(weight_qparams, device)
  217. def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
  218. """
  219. we have:
  220. w(float) -- quant - dequant \
  221. x(float) ------------- F.convTranspose2d ---
  222. In the full model, we will see
  223. w(float) -- quant - *dequant \
  224. x -- quant --- *dequant -- *F.convTranspose2d --- *quant - dequant
  225. and the backend should be able to fuse the ops with `*` into a quantized conv2d
  226. """
  227. assert isinstance(self.padding, tuple)
  228. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  229. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  230. output_padding = self._output_padding(
  231. input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
  232. weight_quant_dequant = self.get_weight()
  233. result = F.conv_transpose2d(
  234. x, weight_quant_dequant, self.bias, self.stride,
  235. self.padding, output_padding, self.groups, self.dilation)
  236. return result
  237. def _get_name(self):
  238. return "QuantizedConvTranspose2d(Reference)"
  239. @classmethod
  240. def from_float(cls, float_conv, weight_qparams):
  241. return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)
  242. class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
  243. def __init__(self, in_channels, out_channels, kernel_size, stride=1,
  244. padding=0, output_padding=0,
  245. groups=1, bias=True, dilation=1,
  246. padding_mode="zeros",
  247. device=None,
  248. dtype=None,
  249. weight_qparams: Optional[Dict[str, Any]] = None):
  250. nn.ConvTranspose3d.__init__(
  251. self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
  252. groups, bias, dilation, padding_mode, device, dtype)
  253. self._init_weight_qparams(weight_qparams, device)
  254. def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
  255. """
  256. we have:
  257. w(float) -- quant - dequant \
  258. x(float) ------------- F.convTranspose3d ---
  259. In the full model, we will see
  260. w(float) -- quant - *dequant \
  261. x -- quant --- *dequant -- *F.convTranspose3d --- *quant - dequant
  262. and the backend should be able to fuse the ops with `*` into a quantized conv3d
  263. """
  264. assert isinstance(self.padding, tuple)
  265. # One cannot replace List by Tuple or Sequence in "_output_padding" because
  266. # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
  267. output_padding = self._output_padding(
  268. input, output_size, self.stride, self.padding, self.kernel_size, self.dilation) # type: ignore[arg-type]
  269. weight_quant_dequant = self.get_weight()
  270. result = F.conv_transpose3d(
  271. x, weight_quant_dequant, self.bias, self.stride,
  272. self.padding, output_padding, self.groups, self.dilation)
  273. return result
  274. def _get_name(self):
  275. return "QuantizedConvTranspose3d(Reference)"
  276. @classmethod
  277. def from_float(cls, float_conv, weight_qparams):
  278. return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)