tensorrt.py
  1. import torch
  2. from .backend_config import (
  3. BackendConfig,
  4. BackendPatternConfig,
  5. DTypeConfig,
  6. ObservationType
  7. )
  8. from ._common_operator_config_utils import (
  9. _get_binary_op_configs,
  10. _get_linear_configs,
  11. _get_conv_configs,
  12. _get_share_qparams_op_configs,
  13. _get_tensor_info_op_configs,
  14. )
  15. __all__ = [
  16. "get_tensorrt_backend_config",
  17. "get_tensorrt_backend_config_dict",
  18. ]
  19. def get_tensorrt_backend_config() -> BackendConfig:
  20. """
  21. Return the `BackendConfig` for the TensorRT backend.
  22. NOTE: Current api will change in the future, it's just to unblock experimentation for
  23. new backends, please don't use it right now.
  24. TODO: add a README when it's more stable
  25. """
  26. # dtype configs
  27. weighted_op_qint8_dtype_config = DTypeConfig(
  28. input_dtype=torch.qint8,
  29. output_dtype=torch.qint8,
  30. weight_dtype=torch.qint8,
  31. bias_dtype=torch.float,
  32. )
  33. non_weighted_op_qint8_dtype_config = DTypeConfig(
  34. input_dtype=torch.qint8,
  35. output_dtype=torch.qint8,
  36. )
  37. addmm_config = BackendPatternConfig(torch.addmm) \
  38. .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
  39. .add_dtype_config(weighted_op_qint8_dtype_config) \
  40. ._set_input_type_to_index({
  41. "bias": 0,
  42. "input": 1,
  43. "weight": 2,
  44. })
  45. cat_config = BackendPatternConfig(torch.cat) \
  46. .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
  47. .add_dtype_config(non_weighted_op_qint8_dtype_config)
  48. conv_dtype_configs = [
  49. weighted_op_qint8_dtype_config,
  50. ]
  51. linear_dtype_configs = [
  52. weighted_op_qint8_dtype_config,
  53. ]
  54. binary_op_dtype_configs = [
  55. weighted_op_qint8_dtype_config,
  56. ]
  57. share_qparams_op_dtype_configs = [
  58. non_weighted_op_qint8_dtype_config,
  59. ]
  60. tensor_info_op_dtype_configs = [
  61. non_weighted_op_qint8_dtype_config,
  62. ]
  63. # there might be things not supported in fx2trt, but it will error out
  64. # during fx2trt conversion and can support them after that
  65. return BackendConfig("tensorrt") \
  66. .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
  67. .set_backend_pattern_config(addmm_config) \
  68. .set_backend_pattern_config(cat_config) \
  69. .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
  70. .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
  71. .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
  72. .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs))
  73. def get_tensorrt_backend_config_dict():
  74. """
  75. Return the `BackendConfig` for the TensorRT backend in dictionary form.
  76. """
  77. return get_tensorrt_backend_config().to_dict()