import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, List
from torch.nn.common_types import _size_1_t

from .utils import ReferenceQuantizedModule

__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']

class _ConvNd(torch.nn.modules.conv._ConvNd, ReferenceQuantizedModule):
    """A reference version of the quantized conv modules (e.g. nn.quantized.Conv2d).

    We do not pack the parameters in this module, since weight packing is an
    optimization for the quantized backends supported in PyTorch (fbgemm/qnnpack);
    keeping the weight unpacked is useful when users want to run this module on
    other backends, such as Glow.
    """
    __annotations__ = {"bias": Optional[torch.Tensor]}
    _IS_REFERENCE = True
    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        qref_conv = cls(
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,  # type: ignore[arg-type]
            float_conv.stride,  # type: ignore[arg-type]
            float_conv.padding,  # type: ignore[arg-type]
            float_conv.dilation,  # type: ignore[arg-type]
            float_conv.groups,
            float_conv.bias is not None,  # type: ignore[arg-type]
            float_conv.padding_mode,
            device=float_conv.weight.device,
            dtype=float_conv.weight.dtype,
            weight_qparams=weight_qparams)
        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
        if float_conv.bias is not None:
            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
        return qref_conv
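    # A minimal usage sketch of `from_float` (an illustration, not part of the
    # module): the exact keys consumed from `weight_qparams` are defined by
    # `ReferenceQuantizedModule._init_weight_qparams` in `.utils`; the dict
    # below assumes per-tensor affine quantization with "qscheme"/"dtype"/
    # "scale"/"zero_point" entries.
    #
    #     float_conv = torch.nn.Conv2d(3, 8, kernel_size=3)
    #     weight_qparams = {
    #         "qscheme": torch.per_tensor_affine,
    #         "dtype": torch.qint8,
    #         "scale": 0.1,
    #         "zero_point": 0,
    #     }
    #     qref_conv = Conv2d.from_float(float_conv, weight_qparams)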

class Conv1d(_ConvNd, nn.Conv1d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.Conv1d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv1d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv1d
        """
        weight_quant_dequant = self.get_weight()
        result = F.conv1d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        return result
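    # `get_weight()` is provided by ReferenceQuantizedModule. Conceptually it is
    # a quantize -> dequantize round trip on the fp32 weight; a sketch, assuming
    # per-tensor affine qparams stored as `weight_scale`/`weight_zero_point`
    # buffers by `_init_weight_qparams` (the helper also handles per-channel):
    #
    #     w_q = torch.quantize_per_tensor(
    #         self.weight, self.weight_scale, self.weight_zero_point, self.weight_dtype)
    #     weight_quant_dequant = w_q.dequantize()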
    def _get_name(self):
        return "QuantizedConv1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)
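# Forward usage sketch (illustrative): the reference modules take and return
# fp32 tensors; only the weight goes through quant -> dequant inside forward.
# `weight_qparams` below is the same kind of dict shown above for `from_float`.
#
#     qconv = Conv1d(2, 2, kernel_size=3, weight_qparams=weight_qparams)
#     out = qconv(torch.randn(1, 2, 8))  # fp32 in, fp32 out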

class Conv2d(_ConvNd, nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.Conv2d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv2d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv2d
        """
        weight_quant_dequant = self.get_weight()
        result = F.conv2d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        return result

    def _get_name(self):
        return "QuantizedConv2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)


class Conv3d(_ConvNd, nn.Conv3d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode="zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.Conv3d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv3d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized conv3d
        """
        weight_quant_dequant = self.get_weight()
        result = F.conv3d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
        return result

    def _get_name(self):
        return "QuantizedConv3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvNd.from_float(cls, float_conv, weight_qparams)

class _ConvTransposeNd(_ConvNd, torch.nn.modules.conv._ConvTransposeNd):
    """A reference version of the quantized transposed conv modules
    (e.g. nn.quantized.ConvTranspose2d).

    We do not pack the parameters in this module, since weight packing is an
    optimization for the quantized backends supported in PyTorch (fbgemm/qnnpack);
    keeping the weight unpacked is useful when users want to run this module on
    other backends, such as Glow.
    """
    @staticmethod
    def from_float(cls, float_conv, weight_qparams):
        qref_conv = cls(
            float_conv.in_channels,
            float_conv.out_channels,
            float_conv.kernel_size,  # type: ignore[arg-type]
            float_conv.stride,  # type: ignore[arg-type]
            float_conv.padding,  # type: ignore[arg-type]
            float_conv.output_padding,  # type: ignore[arg-type]
            float_conv.groups,
            float_conv.bias is not None,  # type: ignore[arg-type]
            float_conv.dilation,  # type: ignore[arg-type]
            float_conv.padding_mode,
            device=float_conv.weight.device,
            dtype=float_conv.weight.dtype,
            weight_qparams=weight_qparams)
        qref_conv.weight = torch.nn.Parameter(float_conv.weight.detach())
        if float_conv.bias is not None:
            qref_conv.bias = torch.nn.Parameter(float_conv.bias.detach())
        return qref_conv
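    # Note the constructor argument order above: for the transposed variants,
    # `output_padding` comes before `groups`, and `dilation` comes after `bias`,
    # mirroring nn.ConvTransposeNd rather than nn.ConvNd. An equivalent
    # keyword-argument call (a sketch that sidesteps the ordering pitfall):
    #
    #     qref_conv = cls(
    #         float_conv.in_channels, float_conv.out_channels, float_conv.kernel_size,
    #         stride=float_conv.stride, padding=float_conv.padding,
    #         output_padding=float_conv.output_padding, groups=float_conv.groups,
    #         bias=float_conv.bias is not None, dilation=float_conv.dilation,
    #         padding_mode=float_conv.padding_mode,
    #         device=float_conv.weight.device, dtype=float_conv.weight.dtype,
    #         weight_qparams=weight_qparams)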

class ConvTranspose1d(_ConvTransposeNd, nn.ConvTranspose1d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 output_padding: _size_1_t = 0,
                 groups: int = 1,
                 bias: bool = True,
                 dilation: _size_1_t = 1,
                 padding_mode: str = "zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.ConvTranspose1d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)
    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv_transpose1d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv_transpose1d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose1d
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose1d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)
        return result
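    # Usage sketch (illustrative): `output_size` resolves the output-shape
    # ambiguity of transposed convolution, just as in nn.ConvTranspose1d.
    # For L_in=10, stride=2, kernel_size=3, valid output lengths are 21 or 22:
    #
    #     m = ConvTranspose1d(4, 4, kernel_size=3, stride=2,
    #                         weight_qparams=weight_qparams)
    #     y = m(torch.randn(1, 4, 10), output_size=[21])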
    def _get_name(self):
        return "QuantizedConvTranspose1d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)


class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0,
                 groups=1, bias=True, dilation=1,
                 padding_mode='zeros',
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.ConvTranspose2d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)
    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv_transpose2d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv_transpose2d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose2d
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose2d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)
        return result
    def _get_name(self):
        return "QuantizedConvTranspose2d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)


class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0,
                 groups=1, bias=True, dilation=1,
                 padding_mode="zeros",
                 device=None,
                 dtype=None,
                 weight_qparams: Optional[Dict[str, Any]] = None):
        nn.ConvTranspose3d.__init__(
            self, in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device, dtype)
        self._init_weight_qparams(weight_qparams, device)
    def forward(self, x: torch.Tensor, output_size: Optional[List[int]] = None) -> torch.Tensor:
        """
        we have:
        w(float) -- quant - dequant \
        x(float) ------------- F.conv_transpose3d ---
        In the full model, we will see
        w(float) -- quant - *dequant \
        x -- quant --- *dequant -- *F.conv_transpose3d --- *quant - dequant
        and the backend should be able to fuse the ops with `*` into a quantized
        conv_transpose3d
        """
        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        output_padding = self._output_padding(
            x, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore[arg-type]
        weight_quant_dequant = self.get_weight()
        result = F.conv_transpose3d(
            x, weight_quant_dequant, self.bias, self.stride,
            self.padding, output_padding, self.groups, self.dilation)
        return result
    def _get_name(self):
        return "QuantizedConvTranspose3d(Reference)"

    @classmethod
    def from_float(cls, float_conv, weight_qparams):
        return _ConvTransposeNd.from_float(cls, float_conv, weight_qparams)