import torch

__all__ = ['LayerNorm', 'GroupNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']

class LayerNorm(torch.nn.LayerNorm):
    r"""This is the quantized version of :class:`~torch.nn.LayerNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, normalized_shape, weight, bias, scale, zero_point, eps=1e-5,
                 elementwise_affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         **factory_kwargs)
        self.weight = weight
        self.bias = bias
        # Output qparams are registered as buffers so they travel with state_dict.
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.layer_norm(
            input, self.normalized_shape, weight=self.weight, bias=self.bias,
            eps=self.eps, output_scale=self.scale, output_zero_point=self.zero_point)

    def _get_name(self):
        return 'QuantizedLayerNorm'

    @classmethod
    def from_float(cls, mod):
        # The output qparams come from the observer attached during calibration.
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.normalized_shape, mod.weight, mod.bias, float(scale),
            int(zero_point), mod.eps, mod.elementwise_affine)
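
# A minimal usage sketch (illustrative, not part of the module itself): the
# module consumes a per-tensor quantized input and produces a quantized output
# carrying the `scale`/`zero_point` given at construction. The qparams below
# are placeholders, not calibrated values.
#
#     x = torch.randn(2, 4)
#     qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)
#     qln = LayerNorm([4],
#                     weight=torch.nn.Parameter(torch.ones(4)),
#                     bias=torch.nn.Parameter(torch.zeros(4)),
#                     scale=0.1, zero_point=0)
#     out = qln(qx)  # quantized output with scale=0.1, zero_point=0

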
class GroupNorm(torch.nn.GroupNorm):
    r"""This is the quantized version of :class:`~torch.nn.GroupNorm`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    __constants__ = ['num_groups', 'num_channels', 'eps', 'affine']

    def __init__(self, num_groups, num_channels, weight, bias, scale, zero_point, eps=1e-5,
                 affine=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.group_norm(
            input, self.num_groups, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedGroupNorm'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_groups, mod.num_channels, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod
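
# A minimal usage sketch (placeholder qparams): 4 channels normalized in
# 2 groups over an (N, C, L) input.
#
#     qx = torch.quantize_per_tensor(torch.randn(1, 4, 8), scale=0.1,
#                                    zero_point=0, dtype=torch.quint8)
#     qgn = GroupNorm(2, 4,
#                     weight=torch.nn.Parameter(torch.ones(4)),
#                     bias=torch.nn.Parameter(torch.zeros(4)),
#                     scale=0.1, zero_point=0)
#     out = qgn(qx)

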
class InstanceNorm1d(torch.nn.InstanceNorm1d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm1d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm1d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
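
# A minimal usage sketch (placeholder qparams); InstanceNorm2d/InstanceNorm3d
# below follow the same pattern with (N, C, H, W) / (N, C, D, H, W) inputs.
#
#     qx = torch.quantize_per_tensor(torch.randn(2, 3, 16), scale=0.1,
#                                    zero_point=0, dtype=torch.quint8)
#     qin = InstanceNorm1d(3,
#                          weight=torch.nn.Parameter(torch.ones(3)),
#                          bias=torch.nn.Parameter(torch.zeros(3)),
#                          scale=0.1, zero_point=0, affine=True)
#     out = qin(qx)

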
class InstanceNorm2d(torch.nn.InstanceNorm2d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm2d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm2d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)


class InstanceNorm3d(torch.nn.InstanceNorm3d):
    r"""This is the quantized version of :class:`~torch.nn.InstanceNorm3d`.

    Additional args:
        * **scale** - quantization scale of the output, type: double.
        * **zero_point** - quantization zero point of the output, type: long.
    """

    def __init__(self, num_features, weight, bias, scale, zero_point,
                 eps=1e-5, momentum=0.1, affine=False,
                 track_running_stats=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(num_features, eps, momentum, affine, track_running_stats, **factory_kwargs)
        self.weight = weight
        self.bias = bias
        self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
        self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))

    def forward(self, input):
        return torch.ops.quantized.instance_norm(
            input, self.weight, self.bias, self.eps, self.scale,
            self.zero_point)

    def _get_name(self):
        return 'QuantizedInstanceNorm3d'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
        return new_mod

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        return cls(
            mod.num_features, mod.weight, mod.bias, float(scale), int(zero_point),
            mod.eps, mod.affine)
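
# A sketch of the `from_float` path, assuming a float module that already
# carries a calibrated observer under `activation_post_process` (normally
# attached by the eager-mode prepare step; the manual wiring here is for
# illustration only, and `torch.ao.quantization` is `torch.quantization` on
# older releases).
#
#     float_ln = torch.nn.LayerNorm([4])
#     float_ln.activation_post_process = torch.ao.quantization.MinMaxObserver()
#     float_ln.activation_post_process(float_ln(torch.randn(8, 4)))  # calibrate
#     qln = LayerNorm.from_float(float_ln)  # reads scale/zero_point from the observer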