123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- import torch
- from .backend_config import (
- BackendConfig,
- BackendPatternConfig,
- DTypeConfig,
- ObservationType
- )
- from ._common_operator_config_utils import (
- _get_binary_op_configs,
- _get_linear_configs,
- _get_conv_configs,
- _get_share_qparams_op_configs,
- _get_tensor_info_op_configs,
- )
# Public API of this module: the TensorRT backend config, in object form and
# in dictionary form.
__all__ = ["get_tensorrt_backend_config", "get_tensorrt_backend_config_dict"]
def get_tensorrt_backend_config() -> BackendConfig:
    """
    Build and return the `BackendConfig` for the TensorRT backend.

    The returned config registers, in this order: conv patterns, `torch.addmm`,
    `torch.cat`, linear patterns, binary ops, ops that share qparams with
    their input, and tensor-info ops. Every pattern is given a qint8 dtype
    config.

    NOTE: Current api will change in the future, it's just to unblock
    experimentation for new backends, please don't use it right now.

    TODO: add a README when it's more stable
    """
    # dtype config for ops that carry a weight (and bias): qint8 activations
    # and weight, float bias.
    qint8_weighted = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )
    # dtype config for ops with activations only: qint8 in, qint8 out.
    qint8_activation_only = DTypeConfig(
        input_dtype=torch.qint8,
        output_dtype=torch.qint8,
    )

    # torch.addmm: the mapping below records which positional argument plays
    # the role of bias, input and weight.
    addmm_arg_positions = {
        "bias": 0,
        "input": 1,
        "weight": 2,
    }
    addmm_pattern = (
        BackendPatternConfig(torch.addmm)
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
        .add_dtype_config(qint8_weighted)
        ._set_input_type_to_index(addmm_arg_positions)
    )

    # torch.cat: the output shares its observer with the input.
    cat_pattern = (
        BackendPatternConfig(torch.cat)
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
        .add_dtype_config(qint8_activation_only)
    )

    # there might be things not supported in fx2trt, but it will error out
    # during fx2trt conversion and can support them after that
    #
    # Each helper receives its own freshly built single-element list, and the
    # registration order below is kept exactly as-is.
    trt_config = BackendConfig("tensorrt")
    trt_config = trt_config.set_backend_pattern_configs(
        _get_conv_configs([qint8_weighted])
    )
    trt_config = trt_config.set_backend_pattern_config(addmm_pattern)
    trt_config = trt_config.set_backend_pattern_config(cat_pattern)
    trt_config = trt_config.set_backend_pattern_configs(
        _get_linear_configs([qint8_weighted])
    )
    trt_config = trt_config.set_backend_pattern_configs(
        _get_binary_op_configs([qint8_weighted])
    )
    trt_config = trt_config.set_backend_pattern_configs(
        _get_share_qparams_op_configs([qint8_activation_only])
    )
    trt_config = trt_config.set_backend_pattern_configs(
        _get_tensor_info_op_configs([qint8_activation_only])
    )
    return trt_config
def get_tensorrt_backend_config_dict():
    """
    Return the TensorRT `BackendConfig` converted to its dictionary form.

    This is `get_tensorrt_backend_config()` followed by `.to_dict()` on the
    result.
    """
    trt_backend_config = get_tensorrt_backend_config()
    return trt_backend_config.to_dict()
|