normalization.py

import torch

__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']


class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         **factory_kwargs)
        self.weight = weight
        self.bias = bias
        # The output qparams are registered as buffers so they are serialized
        # with the state_dict and follow the module across devices.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        # Build the quantized module from an observed float module; the output
        # qparams come from the observer attached during prepare().
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        # Build the quantized module from a reference module, with the output
        # qparams supplied by the caller.
        return cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
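
# Illustrative usage sketch (not part of the upstream file): the module
# consumes a quantized tensor and emits a quantized tensor carrying the
# configured output qparams. Shapes and qparam values below are arbitrary
# assumptions.
#
#     >>> w = torch.nn.Parameter(torch.ones(4))
#     >>> b = torch.nn.Parameter(torch.zeros(4))
#     >>> ln = LayerNorm([4], w, b, scale=0.1, zero_point=0)
#     >>> xq = torch.quantize_per_tensor(torch.rand(2, 4), scale=0.05,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> out = ln(xq)     # quantized tensor
#     >>> out.q_scale()    # 0.1, the output scale passed above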


class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """
    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
                 affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedGroupNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod
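
# Illustrative usage sketch (not part of the upstream file); values are
# arbitrary assumptions. GroupNorm expects (N, C, ...) input with C equal to
# num_channels, and num_channels divisible by num_groups.
#
#     >>> w = torch.nn.Parameter(torch.ones(4))
#     >>> b = torch.nn.Parameter(torch.zeros(4))
#     >>> gn = GroupNorm(2, 4, w, b, scale=0.1, zero_point=0)
#     >>> xq = torch.quantize_per_tensor(torch.rand(2, 4, 8), scale=0.05,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> out = gn(xq)     # quantized tensor with scale 0.1, zero_point 0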


class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm1d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
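
# Illustrative usage sketch (not part of the upstream file); InstanceNorm2d
# and InstanceNorm3d below follow the same pattern with (N, C, H, W) and
# (N, C, D, H, W) inputs. Values are arbitrary assumptions; weight and bias
# may also be None when affine=False.
#
#     >>> w, b = torch.ones(4), torch.zeros(4)
#     >>> inorm = InstanceNorm1d(4, w, b, scale=0.1, zero_point=0)
#     >>> xq = torch.quantize_per_tensor(torch.rand(2, 4, 8), scale=0.05,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> out = inorm(xq)  # per-instance, per-channel normalization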


class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm2d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)


class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm3d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
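
# Sketch of how these modules are usually created in practice, via the
# eager-mode post-training quantization workflow (the toy model below is an
# illustrative assumption): `convert` swaps each observed float module for its
# quantized counterpart by calling the matching `from_float`.
#
#     >>> import torch.ao.quantization as tq
#     >>> model = torch.nn.Sequential(
#     ...     tq.QuantStub(), torch.nn.LayerNorm([4]), tq.DeQuantStub())
#     >>> model.qconfig = tq.default_qconfig
#     >>> tq.prepare(model, inplace=True)
#     >>> _ = model(torch.rand(8, 4))       # calibration populates the observers
#     >>> tq.convert(model, inplace=True)   # LayerNorm.from_float runs here
#     >>> model[1]._get_name()
#     'QuantizedLayerNorm'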