linear.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.nn.intrinsic import LinearReLU
from torch.nn.utils.parametrize import (
    is_parametrized,
    type_before_parametrizations,
    transfer_parametrizations_and_params,
)

__all__ = ["Linear"]


class Linear(nn.Linear):
    r"""
    A linear module with a FakeQuantize module attached to the weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
    for documentation.

    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE = nn.Linear

    def __init__(self, in_features, out_features, bias=True,
                 qconfig=None, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_features, out_features, bias, **factory_kwargs)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        return F.linear(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod):
        r"""Create a qat module from a float module or qparams_dict.

        Args: `mod` a float module, either produced by torch.ao.quantization
        utilities or provided directly by the user.
        """
        assert type_before_parametrizations(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if type_before_parametrizations(mod) == LinearReLU:
            # Fused LinearReLU: the underlying Linear is the first submodule.
            mod = mod[0]
        qconfig = mod.qconfig
        qat_linear = cls(mod.in_features, mod.out_features,
                         bias=mod.bias is not None, qconfig=qconfig)

        # Carry over the float module's parameters, preserving any
        # parametrizations (e.g. pruning, spectral norm) attached to them.
        if is_parametrized(mod, "weight"):
            transfer_parametrizations_and_params(mod, qat_linear, "weight")
        else:
            qat_linear.weight = mod.weight

        if is_parametrized(mod, "bias"):
            transfer_parametrizations_and_params(mod, qat_linear, "bias")
        else:
            qat_linear.bias = mod.bias

        return qat_linear

    def to_float(self):
        linear = torch.nn.Linear(self.in_features, self.out_features, self.bias is not None)
        linear.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            linear.bias = torch.nn.Parameter(self.bias.detach())
        linear.train(self.training)
        return linear
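

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): a minimal
# round trip through from_float / forward / to_float, assuming the default
# QAT qconfig from torch.ao.quantization. The backend name ("fbgemm") and the
# example shapes are arbitrary choices for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.ao.quantization import get_default_qat_qconfig

    # A plain float linear layer with a QAT qconfig attached.
    float_linear = nn.Linear(16, 8)
    float_linear.qconfig = get_default_qat_qconfig("fbgemm")

    # from_float copies the float parameters and attaches the weight
    # FakeQuantize module specified by the qconfig.
    qat_linear = Linear.from_float(float_linear)

    # The forward pass fake-quantizes the weight before the matmul.
    out = qat_linear(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 8])

    # to_float returns a plain nn.Linear carrying the (detached) weights.
    recovered = qat_linear.to_float()
    print(type(recovered))  # <class 'torch.nn.modules.linear.Linear'>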