linear.py

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
import torch.ao.nn.intrinsic as nni

__all__ = [
    "Linear",
]

class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensors as inputs and outputs.

    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    # The version used in this class is different from the parent class nnq.Linear.
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *Everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the serialization and
        # deserialization methods below.
        self.version = 4

    def forward(self, x):
        # Note that we can handle the self.bias == None case.
        if self._packed_params.dtype == torch.qint8:
            if self.version is None or self.version < 4:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params)
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True)
        elif self._packed_params.dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
        return Y.to(x.dtype)
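
    # A minimal usage sketch (not part of the original module). The usual entry
    # point that produces this module is `torch.ao.quantization.quantize_dynamic`,
    # which swaps eligible float modules for their dynamic quantized counterparts:
    #
    #   float_model = torch.nn.Sequential(torch.nn.Linear(20, 30))
    #   dq_model = torch.ao.quantization.quantize_dynamic(
    #       float_model, {torch.nn.Linear}, dtype=torch.qint8)
    #   out = dq_model(torch.randn(128, 20))  # fp32 in, fp32 out; weights are qint8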

    def _get_name(self):
        return 'DynamicQuantizedLinear'

    def extra_repr(self):
        extra_repr_str = 'in_features={}, out_features={}, dtype={}'.format(
            self.in_features, self.out_features, self._packed_params.dtype
        )
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += ', qscheme={}'.format(self.weight().qscheme())
        return extra_repr_str

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        self.version = version
        super()._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                      missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
        """
        float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
                         torch.ao.nn.intrinsic.modules.fused.LinearReLU, torch.ao.nn.qat.dynamic.Linear]

        assert type(mod) in float_modules, \
            'nn.quantized.dynamic.Linear.from_float only works for one of ' + \
            str([float_mod.__name__ for float_mod in float_modules])
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
        if type(mod) == nni.LinearReLU:
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We would have circular import issues if we imported the qconfig at the
            # top of this file: https://github.com/pytorch/pytorch/pull/24231. The
            # current workaround is to postpone the import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig
            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
            "dynamic quantized linear are qint8 and float16, got: {}".format(dtype)
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            raise RuntimeError('Unsupported dtype specified for dynamic quantized Linear!')
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear
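
    # Sketch of a by-hand conversion via from_float (illustrative only), assuming
    # the stock default_dynamic_qconfig is attached to the float module:
    #
    #   float_mod = torch.nn.Linear(20, 30)
    #   float_mod.qconfig = torch.ao.quantization.default_dynamic_qconfig
    #   dq_linear = Linear.from_float(float_mod)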

    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference
        quantized module.

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                                  torch.ao.quantization functions or provided by the user
        """
        qlinear = cls(ref_qlinear.in_features, ref_qlinear.out_features, dtype=ref_qlinear.weight_dtype)
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear
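
    # Sketch of the reference-module path (names assumed from the
    # torch.ao.nn.quantized.reference package; verify against your PyTorch version):
    #
    #   import torch.ao.nn.quantized.reference as nnqr
    #   weight_qparams = {"qscheme": torch.per_tensor_affine, "dtype": torch.qint8,
    #                     "scale": 1.0, "zero_point": 0}
    #   ref = nnqr.Linear.from_float(torch.nn.Linear(20, 30), weight_qparams)
    #   dq_linear = Linear.from_reference(ref)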