import operator

import torch
import torch.nn as nn
import torch.nn.functional as F

toq = torch.ops.quantized

import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd

from torch.ao.quantization.backend_config import get_native_backend_config
import torch.ao.quantization.fx._lower_to_native_backend as \
    _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings

from .ns_types import NSNodeTargetType

from typing import Callable, Dict, List, Optional, Set, Tuple


def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
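    """
    Returns a map from an arbitrary base name (a string counter such as "0")
    to a set of "related" ops, where related means different versions of the
    same underlying computation: for example, the float, QAT, fused, and
    quantized module variants of conv2d all end up in the same set. The map
    is seeded with the hardcoded sets below, then extended with connections
    derived from the native backend_config and the default lowering maps.

    Illustrative shape of the result (exact base names depend on iteration
    order, and exact members depend on the backend_config at runtime):

        {
            "0": {nn.Conv1d, nnq.Conv1d, nnqat.Conv1d, nni.ConvReLU1d, ...},
            "1": {nn.Conv2d, nnq.Conv2d, nnqat.Conv2d, nni.ConvReLU2d, ...},
            ...
        }
    """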
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            'relu',
            'relu_',
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            'sigmoid',
            'sigmoid_',
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        },
        # F.hardsigmoid
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
    ]

    # for each floating point op, add versions of the op added by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():

        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]
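        # e.g. for a reversed fusion pattern such as
        # (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)), first_element resolves to
        # nn.Conv2d (illustrative example, not a literal backend_config entry)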

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from default lowering path
    #

    for source, (target1, target2) in \
            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps. Ideally this could be removed in the future, once
    # the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    # add the new connections from backend_config
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break
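
    # at this point each set pairs a float op with its quantized, QAT, fused,
    # and reference counterparts, e.g. (illustratively)
    # {nn.Conv2d, nnqat.Conv2d, nnq.Conv2d, nni.ConvReLU2d, ...}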

    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}
    for counter, set_of_related_ops in enumerate(sets_of_related_ops):
        base_name_to_sets_of_related_ops[str(counter)] = set_of_related_ops
    return base_name_to_sets_of_related_ops


def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
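    """
    Returns the base name whose set of related ops contains `op`, or None
    if `op` is not found in any set.

    Illustrative usage (the returned base name is the arbitrary counter
    string assigned by get_base_name_to_sets_of_related_ops):

        base_name_to_sets = get_base_name_to_sets_of_related_ops()
        get_base_name_for_op(base_name_to_sets, nn.Conv2d)  # e.g. "1"
    """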
    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
        if op in set_of_related_ops:
            return base_name
    return None


def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
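    """
    Adds `op` to the set which contains `related_op`, raising AssertionError
    if no set contains `related_op`. If `related_op` is None, creates a new
    singleton set {op} under the next unused counter-style base name.
    """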
    if related_op is not None:
        for set_of_related_ops in base_name_to_sets_of_related_ops.values():
            if related_op in set_of_related_ops:
                set_of_related_ops.add(op)
                return
        # if we got here, related_op was not found
        raise AssertionError(f"{related_op} was not found")
    else:
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}


# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
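    """
    Returns a map from category name to the set of node targets in that
    category. Categories are keyed by node kind (functions, modules,
    methods) and by the dtype of the op's I/O: fp32, fp16, int8, or
    "fp32_or_int8" for ops which pass their input dtype through, e.g.
    'funs_io_type_fp32' and 'mods_io_type_int8'.
    """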
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }
    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }
    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }


def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
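    """
    Returns a map from category name ('funs_unmatchable', 'mods_unmatchable',
    'meths_unmatchable') to the set of node targets which should not be
    matched between models, such as quantize_per_tensor and
    shape-manipulation methods.
    """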
    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }

    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
        nn.Identity,
    }

    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
        'to',
        'dequantize',
        'reshape',
        'view',
        'unsqueeze_',
        'unsqueeze',
        'transpose',
        'squeeze_',
        'squeeze',
        'size',
        'shape',
        'resize_',
        'repeat_interleave',
        'repeat',
        'permute',
        'numel',
        'mean',
        'detach_',
        'detach',
        'contiguous',
        'clamp',
        'chunk',
    }

    return {
        'funs_unmatchable': FUNS_UNMATCHABLE,
        'mods_unmatchable': MODS_UNMATCHABLE,
        'meths_unmatchable': METHS_UNMATCHABLE,
    }