__init__.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """ONNX exporter."""
  2. from torch import _C
  3. from torch._C import _onnx as _C_onnx
  4. from torch._C._onnx import (
  5. _CAFFE2_ATEN_FALLBACK,
  6. OperatorExportTypes,
  7. TensorProtoDataType,
  8. TrainingMode,
  9. )
  10. from . import ( # usort:skip. Keep the order instead of sorting lexicographically
  11. _deprecation,
  12. errors,
  13. symbolic_caffe2,
  14. symbolic_helper,
  15. symbolic_opset7,
  16. symbolic_opset8,
  17. symbolic_opset9,
  18. symbolic_opset10,
  19. symbolic_opset11,
  20. symbolic_opset12,
  21. symbolic_opset13,
  22. symbolic_opset14,
  23. symbolic_opset15,
  24. symbolic_opset16,
  25. symbolic_opset17,
  26. symbolic_opset18,
  27. utils,
  28. )
  29. # TODO(After 1.13 release): Remove the deprecated SymbolicContext
  30. from ._exporter_states import ExportTypes, SymbolicContext
  31. from ._type_utils import JitScalarType
  32. from .errors import CheckerError # Backwards compatibility
  33. from .utils import (
  34. _optimize_graph,
  35. _run_symbolic_function,
  36. _run_symbolic_method,
  37. export,
  38. export_to_pretty_string,
  39. is_in_onnx_export,
  40. register_custom_op_symbolic,
  41. select_model_mode_for_export,
  42. unregister_custom_op_symbolic,
  43. )
# Public API of ``torch.onnx``: the names bound by ``from torch.onnx import *``.
# Besides the functions and enums imported above, the symbolic helper and
# per-opset submodules are listed so they are re-exported as well.
# NOTE(review): the logging helpers ``is_onnx_log_enabled``, ``set_log_stream``
# and ``log`` defined further down are not listed here (only ``enable_log`` /
# ``disable_log`` are); confirm whether that omission is intentional.
__all__ = [
    # Modules
    "symbolic_helper",
    "utils",
    "errors",
    # All opsets
    "symbolic_caffe2",
    "symbolic_opset7",
    "symbolic_opset8",
    "symbolic_opset9",
    "symbolic_opset10",
    "symbolic_opset11",
    "symbolic_opset12",
    "symbolic_opset13",
    "symbolic_opset14",
    "symbolic_opset15",
    "symbolic_opset16",
    "symbolic_opset17",
    "symbolic_opset18",
    # Enums
    "ExportTypes",
    "OperatorExportTypes",
    "TrainingMode",
    "TensorProtoDataType",
    "JitScalarType",
    # Public functions
    "export",
    "export_to_pretty_string",
    "is_in_onnx_export",
    "select_model_mode_for_export",
    "register_custom_op_symbolic",
    "unregister_custom_op_symbolic",
    "disable_log",
    "enable_log",
    # Errors
    "CheckerError",  # Backwards compatibility
]
  81. # Set namespace for exposed private names
  82. ExportTypes.__module__ = "torch.onnx"
  83. JitScalarType.__module__ = "torch.onnx"
  84. producer_name = "pytorch"
  85. producer_version = _C_onnx.PRODUCER_VERSION
  86. @_deprecation.deprecated(
  87. since="1.12.0", removed_in="2.0", instructions="use `torch.onnx.export` instead"
  88. )
  89. def _export(*args, **kwargs):
  90. return utils._export(*args, **kwargs)
  91. # TODO(justinchuby): Deprecate these logging functions in favor of the new diagnostic module.
  92. # Returns True iff ONNX logging is turned on.
  93. is_onnx_log_enabled = _C._jit_is_onnx_log_enabled
  94. def enable_log() -> None:
  95. r"""Enables ONNX logging."""
  96. _C._jit_set_onnx_log_enabled(True)
  97. def disable_log() -> None:
  98. r"""Disables ONNX logging."""
  99. _C._jit_set_onnx_log_enabled(False)
  100. """Sets output stream for ONNX logging.
  101. Args:
  102. stream_name (str, default "stdout"): Only 'stdout' and 'stderr' are supported
  103. as ``stream_name``.
  104. """
  105. set_log_stream = _C._jit_set_onnx_log_output_stream
  106. """A simple logging facility for ONNX exporter.
  107. Args:
  108. args: Arguments are converted to string, concatenated together with a newline
  109. character appended to the end, and flushed to output stream.
  110. """
  111. log = _C._jit_onnx_log