123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293 |
# TODO: rename executorch to qnnpack_executorch, since executorch is a general runtime,
# not a specific backend.
- import operator
- from typing import List
- import torch
- import torch.nn.functional as F
- import torch.nn as nn
- import torch.ao.nn.qat as nnqat
- import torch.ao.nn.quantized.reference as nnqr
- from .backend_config import (
- BackendConfig,
- BackendPatternConfig,
- DTypeConfig,
- ObservationType,
- )
- from .qnnpack import (
- qnnpack_weighted_op_qint8_symmetric_dtype_config,
- qnnpack_default_op_qint8_symmetric_dtype_config
- )
- from ._common_operator_config_utils import _Conv2dMetadata
- from ..fuser_method_mappings import _sequential_wrapper2
# Public API of this module: only the BackendConfig factory is exported.
__all__ = ["get_executorch_backend_config"]
- # ===================
- # | DTYPE CONFIGS |
- # ===================
# Static quantization for weighted ops: quint8 activations in and out,
# qint8 weight, float bias.
executorch_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# Static quantization for ops without weight/bias: quint8 in, quint8 out.
executorch_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# Dynamic quantization: quint8 input, qint8 weight, float bias, float output.
executorch_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Dynamic float16: float16 input and weight, float bias, float output.
executorch_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# Weight-only quantization: float activations in and out, quint8 weight.
executorch_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)
- # =============================
- # | BACKEND PATTERN CONFIGS |
- # =============================
def _get_linear_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern configs for linear layers.

    Two patterns are produced, in this order: the ``torch.nn.Linear`` module
    and the ``torch.nn.functional.linear`` op. Both use a separate output
    observer and accept the same dtype configs (qnnpack symmetric qint8,
    executorch static int8, dynamic int8 and dynamic float16).
    """
    output_observation = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    supported_dtypes = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        executorch_weighted_op_int8_dtype_config,
        executorch_default_dynamic_int8_dtype_config,
        executorch_default_dynamic_float16_dtype_config,
    ]

    # Module form: nn.Linear is its own root module, with explicit reference
    # quantized and QAT counterparts.
    module_config = (
        BackendPatternConfig(torch.nn.Linear)
        .set_observation_type(output_observation)
        .set_dtype_configs(supported_dtypes)
        .set_root_module(torch.nn.Linear)
        .set_reference_quantized_module(nnqr.Linear)
        .set_qat_module(nnqat.Linear)
    )

    # Functional form: weight and bias are identified by argument index.
    functional_config = (
        BackendPatternConfig(torch.nn.functional.linear)
        .set_observation_type(output_observation)
        .set_dtype_configs(supported_dtypes)
        ._set_input_type_to_index({"weight": 1, "bias": 2})
    )

    return [module_config, functional_config]
def _get_conv_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern configs for conv modules, functional convs and their
    ReLU combinations.

    For each conv metadata entry (currently only ``_Conv2dMetadata``) seven
    configs are emitted, in this order: root module, functional op,
    module + ``nn.ReLU``, module + ``F.relu``, fused conv-relu module,
    functional + ``nn.ReLU``, functional + ``F.relu``.
    """
    output_observation = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    supported_dtypes = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        executorch_weighted_op_int8_dtype_config,
    ]
    # ReLU can follow a conv either as a module or as a functional call.
    relu_variants = (nn.ReLU, F.relu)

    conv_configs: List[BackendPatternConfig] = []
    for conv_meta in (_Conv2dMetadata,):
        # Plain conv module.
        conv_configs.append(
            BackendPatternConfig(conv_meta.root)
            .set_observation_type(output_observation)
            .set_dtype_configs(supported_dtypes)
            .set_root_module(conv_meta.root)
            .set_reference_quantized_module(conv_meta.reference)
            .set_qat_module(conv_meta.qat)
        )
        # Plain functional conv; weight and bias are located by argument index.
        conv_configs.append(
            BackendPatternConfig(conv_meta.func)
            .set_observation_type(output_observation)
            .set_dtype_configs(supported_dtypes)
            ._set_input_type_to_index({"weight": 1, "bias": 2})
        )
        # Conv module followed by ReLU: these are fusion patterns, so they
        # carry a fuser method and the fused module type, and set no
        # observation type.
        for relu in relu_variants:
            conv_configs.append(
                BackendPatternConfig((conv_meta.root, relu))
                .set_dtype_configs(supported_dtypes)
                .set_fuser_method(_sequential_wrapper2(conv_meta.fused_conv_relu))
                .set_fused_module(conv_meta.fused_conv_relu)
            )
        # The already-fused conv-relu module.
        conv_configs.append(
            BackendPatternConfig(conv_meta.fused_conv_relu)
            .set_observation_type(output_observation)
            .set_dtype_configs(supported_dtypes)
            .set_root_module(conv_meta.root)
            .set_reference_quantized_module(conv_meta.reference)
            .set_qat_module(conv_meta.relu_qat)
        )
        # Functional conv followed by ReLU.
        for relu in relu_variants:
            conv_configs.append(
                BackendPatternConfig((conv_meta.func, relu))
                .set_observation_type(output_observation)
                .set_dtype_configs(supported_dtypes)
            )
    return conv_configs
def _get_binary_ops_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern configs for binary ops.

    One config is returned for each of ``operator.add`` and ``torch.add``;
    both share the same dtype configs and the same mapping from the number
    of tensor arguments to an observation type.
    """
    # NOTE(review): add takes no weight or bias, yet the *weighted* executorch
    # dtype config is listed here rather than the default-op one -- confirm
    # this is intended.
    supported_dtypes = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        executorch_weighted_op_int8_dtype_config,
    ]
    observation_by_num_tensor_args = {
        # TODO: this is not used right now since we have extra check in prepare
        # will need to change this to NO_OBSERVER later after we implemented
        # Tensor dtype inference properly
        0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
        1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
        2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
    }
    return [
        BackendPatternConfig(add_op)
        .set_dtype_configs(supported_dtypes)
        ._set_num_tensor_args_to_observation_type(observation_by_num_tensor_args)
        for add_op in (operator.add, torch.add)
    ]
def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern configs for ops whose output shares its observer with
    the input.

    These ops accept both float and quantized inputs; when the input is
    quantized, the output tensor reuses the input's quantization parameters.

    Example operators: adaptive_avg_pool2d, relu, reshape, permute, squeeze
    Example observed operator:
    observer_0 - adaptive_avg_pool2d - observer_0 (same observer instance as input)
    """
    supported_dtypes = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        executorch_default_op_quint8_dtype_config,
    ]
    # Callables and modules are matched directly; plain strings are matched
    # as method names.
    shared_observer_ops = [
        F.adaptive_avg_pool2d,
        F.relu,
        F.relu6,
        torch.nn.AdaptiveAvgPool2d,
        torch.squeeze,
        "permute",
        "reshape",
        "relu",
        "relu_",
        "squeeze",
        "squeeze_",
    ]
    return [
        BackendPatternConfig(shared_op)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .set_dtype_configs(supported_dtypes)
        for shared_op in shared_observer_ops
    ]
def _get_bn_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern config for batchnorm.

    Returns a single-element list holding the config for ``nn.BatchNorm2d``,
    whose output is observed separately from its input.
    """
    supported_dtypes = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        executorch_default_op_quint8_dtype_config,
    ]
    batchnorm2d_config = (
        BackendPatternConfig(nn.BatchNorm2d)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .set_dtype_configs(supported_dtypes)
    )
    return [batchnorm2d_config]
def _get_cat_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern config for ``torch.cat``.

    Returns a single-element list; the output of ``torch.cat`` shares its
    observer with the input.
    """
    supported_dtypes = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        executorch_default_op_quint8_dtype_config,
    ]
    cat_config = (
        BackendPatternConfig(torch.cat)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .set_dtype_configs(supported_dtypes)
    )
    return [cat_config]
def _get_embedding_op_configs() -> List[BackendPatternConfig]:
    """
    Build the pattern configs for embedding modules and the functional
    embedding op.

    All configs use the weight-only quint8 dtype config and observe the
    output separately from the input. For each of ``nn.Embedding`` and
    ``nn.EmbeddingBag`` two configs are produced: one keyed on the float
    module and one keyed on its QAT counterpart. A single config for
    ``torch.nn.functional.embedding`` follows.
    """
    observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
    dtype_configs = [
        executorch_weight_only_quint8_dtype_config,
    ]
    embedding_op_configs: List[BackendPatternConfig] = []
    for embedding_op, qat_embedding_op, ref_embedding_op in [
        (nn.Embedding, nnqat.Embedding, nnqr.Embedding),
        (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag),
    ]:
        # Config for the float module.
        embedding_op_configs.append(
            BackendPatternConfig(embedding_op)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_qat_module(qat_embedding_op)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )
        # Config for the QAT module; it maps back to the same root module
        # and reference quantized module.
        embedding_op_configs.append(
            BackendPatternConfig(qat_embedding_op)
            .set_observation_type(observation_type)
            .set_dtype_configs(dtype_configs)
            .set_root_module(embedding_op)
            .set_reference_quantized_module(ref_embedding_op)
        )

    # Config for functional embedding. Built once, outside the per-module
    # loop: it does not depend on the module type, and a second copy would
    # only register the same pattern again.
    embedding_op_configs.append(
        BackendPatternConfig(torch.nn.functional.embedding)
        .set_observation_type(observation_type)
        .set_dtype_configs(dtype_configs)
        ._set_input_type_to_index({"weight": 1})
    )
    return embedding_op_configs
- # =====================
- # | BACKEND CONFIGS |
- # =====================
def get_executorch_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for backends PyTorch lowers to through the Executorch stack.

    The config is named ``"executorch"`` and is populated, in order, with the
    linear, conv, binary-op, share-qparams, batchnorm, cat and embedding
    pattern configs.
    """
    pattern_config_getters = (
        _get_linear_configs,
        _get_conv_configs,
        _get_binary_ops_configs,
        _get_share_qparams_ops_configs,
        _get_bn_configs,
        _get_cat_configs,
        _get_embedding_op_configs,
    )
    backend_config = BackendConfig("executorch")
    for get_pattern_configs in pattern_config_getters:
        backend_config.set_backend_pattern_configs(get_pattern_configs())
    return backend_config
|