_globals.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """Globals used internally by the ONNX exporter.
  2. Do not use this module outside of `torch.onnx` and its tests.
  3. Be very judicious when adding any new global variables. Do not create new global
  4. variables unless they are absolutely necessary.
  5. """
  6. import os
  7. import torch._C._onnx as _C_onnx
# This module should only depend on _constants and _exporter_states and nothing
# else in torch.onnx, to keep the dependency direction clean.
  10. from torch.onnx import _constants, _exporter_states
  11. class _InternalGlobals:
  12. """Globals used internally by ONNX exporter.
  13. NOTE: Be very judicious when adding any new variables. Do not create new
  14. global variables unless they are absolutely necessary.
  15. """
  16. def __init__(self):
  17. self._export_onnx_opset_version = _constants.ONNX_DEFAULT_OPSET
  18. self._training_mode: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
  19. self._in_onnx_export: bool = False
  20. # Whether the user's model is training during export
  21. self.export_training: bool = False
  22. self.operator_export_type: _C_onnx.OperatorExportTypes = (
  23. _C_onnx.OperatorExportTypes.ONNX
  24. )
  25. self.onnx_shape_inference: bool = True
  26. # Internal feature flags
  27. if os.getenv("TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK") == "WARNINGS":
  28. self.runtime_type_check_state = (
  29. _exporter_states.RuntimeTypeCheckState.WARNINGS
  30. )
  31. elif os.getenv("TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK") == "DISABLED":
  32. self.runtime_type_check_state = (
  33. _exporter_states.RuntimeTypeCheckState.DISABLED
  34. )
  35. else:
  36. self.runtime_type_check_state = (
  37. _exporter_states.RuntimeTypeCheckState.ERRORS
  38. )
  39. @property
  40. def training_mode(self):
  41. """The training mode for the exporter."""
  42. return self._training_mode
  43. @training_mode.setter
  44. def training_mode(self, training_mode: _C_onnx.TrainingMode):
  45. if not isinstance(training_mode, _C_onnx.TrainingMode):
  46. raise TypeError(
  47. "training_mode must be of type 'torch.onnx.TrainingMode'. This is "
  48. "likely a bug in torch.onnx."
  49. )
  50. self._training_mode = training_mode
  51. @property
  52. def export_onnx_opset_version(self) -> int:
  53. """Opset version used during export."""
  54. return self._export_onnx_opset_version
  55. @export_onnx_opset_version.setter
  56. def export_onnx_opset_version(self, value: int):
  57. supported_versions = range(
  58. _constants.ONNX_MIN_OPSET, _constants.ONNX_MAX_OPSET + 1
  59. )
  60. if value not in supported_versions:
  61. raise ValueError(f"Unsupported ONNX opset version: {value}")
  62. self._export_onnx_opset_version = value
  63. @property
  64. def in_onnx_export(self) -> bool:
  65. """Whether it is in the middle of ONNX export."""
  66. return self._in_onnx_export
  67. @in_onnx_export.setter
  68. def in_onnx_export(self, value: bool):
  69. if type(value) is not bool:
  70. raise TypeError("in_onnx_export must be a boolean")
  71. self._in_onnx_export = value
# Module-level instance, constructed once when this module is first imported.
# Because construction happens at import time, the
# TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK environment variable is read then;
# changing it afterwards does not affect runtime_type_check_state.
# Per the module docstring, use only within torch.onnx and its tests.
GLOBALS = _InternalGlobals()