import torch
import torch.nn as nn
from torch.nn.modules.utils import _single, _pair, _triple
from torch.ao.nn.intrinsic import _FusedModule
from typing import Tuple, TypeVar, Union
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t

__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d"
]

MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvNd(nn.modules.conv._ConvNd):

    _FLOAT_MODULE = MOD

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Tuple[int, ...],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool,
                 padding_mode: str,
                 qconfig=None,
                 device=None,
                 dtype=None) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(self, in_channels, out_channels, kernel_size,
                                         stride, padding, dilation, transposed,
                                         output_padding, groups, bias, padding_mode,
                                         **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @staticmethod
    def from_float(cls, mod):
        r"""Creates a QAT module from a float module.

        Args:
            mod: a float module, either produced by torch.ao.quantization
                utilities or provided directly by the user
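
        A minimal usage sketch (assumes the default QAT qconfig from
        torch.ao.quantization; the ``"fbgemm"`` backend name is only
        illustrative)::

            >>> float_conv = torch.nn.Conv2d(16, 33, 3)
            >>> float_conv.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
            >>> qat_conv = Conv2d.from_float(float_conv)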
- """
- assert type(mod) == cls._FLOAT_MODULE, (
- "qat."
- + cls.__name__
- + ".from_float only works for "
- + cls._FLOAT_MODULE.__name__ # type: ignore[attr-defined]
- )
- assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
- assert mod.qconfig, 'Input float module must have a valid qconfig'
- if issubclass(type(mod), _FusedModule):
- mod = mod[0] # type: ignore[index]
- qconfig = mod.qconfig
- qat_conv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
- stride=mod.stride, padding=mod.padding, dilation=mod.dilation,
- groups=mod.groups, bias=mod.bias is not None,
- padding_mode=mod.padding_mode, qconfig=qconfig)
- qat_conv.weight = mod.weight
- qat_conv.bias = mod.bias
- return qat_conv

    def to_float(self):
        """Converts this QAT module back to a floating point module.

        Works both for a standalone QAT conv module and for fused
        QAT conv-relu modules.
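
        A minimal sketch, where ``qat_conv`` is assumed to be an instance of
        one of the QAT conv (or fused conv-relu) classes in this file::

            >>> float_mod = qat_conv.to_float()  # plain float conv, or the fused float module for conv-relu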
- """
- cls = type(self)
- conv = cls._FLOAT_CONV_MODULE( # type: ignore[attr-defined, operator]
- self.in_channels,
- self.out_channels,
- self.kernel_size, # type: ignore[arg-type]
- self.stride, # type: ignore[arg-type]
- self.padding, # type: ignore[arg-type]
- self.dilation, # type: ignore[arg-type]
- self.groups,
- self.bias is not None,
- self.padding_mode)
- conv.weight = torch.nn.Parameter(self.weight.detach())
- if self.bias is not None:
- conv.bias = torch.nn.Parameter(self.bias.detach())
- # conv relu
- if issubclass(cls, _FusedModule):
- modules = [conv]
- assert hasattr(cls, "_FLOAT_RELU_MODULE")
- relu = cls._FLOAT_RELU_MODULE() # type: ignore[attr-defined]
- modules.append(relu)
- fused = cls._FLOAT_MODULE(*modules) # type: ignore[arg-type, attr-defined, operator]
- fused.train(self.training)
- return fused
- else:
- return conv


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized
    to default.

    Attributes:
        weight_fake_quant: fake quant module for weight
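
    A minimal usage sketch (assumes this class is importable as
    ``torch.ao.nn.qat.Conv1d`` and that the default QAT qconfig is used;
    the ``"fbgemm"`` backend name is only illustrative)::

        >>> import torch
        >>> from torch.ao.quantization import get_default_qat_qconfig
        >>> qconfig = get_default_qat_qconfig("fbgemm")
        >>> m = torch.ao.nn.qat.Conv1d(16, 33, 3, stride=2, qconfig=qconfig)
        >>> input = torch.randn(20, 16, 50)
        >>> output = m(input)  # weight is fake-quantized before the convolution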
- """
- _FLOAT_MODULE = nn.Conv1d
- _FLOAT_CONV_MODULE = nn.Conv1d
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: _size_1_t,
- stride: _size_1_t = 1,
- padding: Union[str, _size_1_t] = 0,
- dilation: _size_1_t = 1,
- groups: int = 1,
- bias: bool = True,
- padding_mode: str = 'zeros',
- qconfig=None,
- device=None,
- dtype=None) -> None:
- kernel_size_ = _single(kernel_size)
- stride_ = _single(stride)
- padding_ = padding if isinstance(padding, str) else _single(padding)
- dilation_ = _single(dilation)
- super().__init__(
- in_channels,
- out_channels,
- kernel_size_,
- stride=stride_,
- padding=padding_,
- dilation=dilation_,
- transposed=False,
- output_padding=_single(0),
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- qconfig=qconfig,
- device=device,
- dtype=dtype)
- @classmethod
- def from_float(cls, mod):
- return super().from_float(cls, mod)


class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
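
    A minimal usage sketch (same assumptions as the ``Conv1d`` example above)::

        >>> m = torch.ao.nn.qat.Conv2d(16, 33, 3, stride=2, qconfig=qconfig)
        >>> input = torch.randn(20, 16, 50, 100)
        >>> output = m(input)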
- """
- _FLOAT_MODULE = nn.Conv2d
- _FLOAT_CONV_MODULE = nn.Conv2d
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: _size_2_t,
- stride: _size_2_t = 1,
- padding: Union[str, _size_2_t] = 0,
- dilation: _size_2_t = 1,
- groups: int = 1,
- bias: bool = True,
- padding_mode: str = 'zeros',
- qconfig=None,
- device=None,
- dtype=None) -> None:
- kernel_size_ = _pair(kernel_size)
- stride_ = _pair(stride)
- padding_ = padding if isinstance(padding, str) else _pair(padding)
- dilation_ = _pair(dilation)
- super().__init__(
- in_channels,
- out_channels,
- kernel_size_,
- stride=stride_,
- padding=padding_,
- dilation=dilation_,
- transposed=False,
- output_padding=_pair(0),
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- qconfig=qconfig,
- device=device,
- dtype=dtype)
- def forward(self, input):
- return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
- @classmethod
- def from_float(cls, mod):
- return super().from_float(cls, mod)


class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`, please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
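
    A minimal usage sketch (same assumptions as the ``Conv1d`` example above)::

        >>> m = torch.ao.nn.qat.Conv3d(16, 33, 3, stride=2, qconfig=qconfig)
        >>> input = torch.randn(20, 16, 10, 50, 50)
        >>> output = m(input)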
- """
- _FLOAT_MODULE = nn.Conv3d
- _FLOAT_CONV_MODULE = nn.Conv3d
- def __init__(self,
- in_channels: int,
- out_channels: int,
- kernel_size: _size_3_t,
- stride: _size_3_t = 1,
- padding: Union[str, _size_3_t] = 0,
- dilation: _size_3_t = 1,
- groups: int = 1,
- bias: bool = True,
- padding_mode: str = 'zeros',
- qconfig=None,
- device=None,
- dtype=None) -> None:
- kernel_size_ = _triple(kernel_size)
- stride_ = _triple(stride)
- padding_ = padding if isinstance(padding, str) else _triple(padding)
- dilation_ = _triple(dilation)
- super().__init__(
- in_channels,
- out_channels,
- kernel_size_,
- stride=stride_,
- padding=padding_,
- dilation=dilation_,
- transposed=False,
- output_padding=_triple(0),
- groups=groups,
- bias=bias,
- padding_mode=padding_mode,
- qconfig=qconfig,
- device=device,
- dtype=dtype)
- def forward(self, input):
- return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)
- @classmethod
- def from_float(cls, mod):
- return super().from_float(cls, mod)