_exporter_states.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from __future__ import annotations
  2. import enum
  3. from typing import Dict
  4. from torch import _C
  5. class ExportTypes:
  6. r"""Specifies how the ONNX model is stored."""
  7. PROTOBUF_FILE = "Saves model in the specified protobuf file."
  8. ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
  9. COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
  10. DIRECTORY = "Saves model in the specified folder."
  11. class SymbolicContext:
  12. """Extra context for symbolic functions.
  13. Args:
  14. params_dict (Dict[str, _C.IValue]): Mapping from graph initializer name to IValue.
  15. env (Dict[_C.Value, _C.Value]): Mapping from Torch domain graph Value to ONNX domain graph Value.
  16. cur_node (_C.Node): Current node being converted to ONNX domain.
  17. onnx_block (_C.Block): Current ONNX block that converted nodes are being appended to.
  18. """
  19. def __init__(
  20. self,
  21. params_dict: Dict[str, _C.IValue],
  22. env: dict,
  23. cur_node: _C.Node,
  24. onnx_block: _C.Block,
  25. ):
  26. self.params_dict: Dict[str, _C.IValue] = params_dict
  27. self.env: Dict[_C.Value, _C.Value] = env
  28. # Current node that is being converted.
  29. self.cur_node: _C.Node = cur_node
  30. # Current onnx block that converted nodes are being appended to.
  31. self.onnx_block: _C.Block = onnx_block
  32. @enum.unique
  33. class RuntimeTypeCheckState(enum.Enum):
  34. """Runtime type check state."""
  35. # Runtime type checking is disabled.
  36. DISABLED = enum.auto()
  37. # Runtime type checking is enabled but warnings are shown only.
  38. WARNINGS = enum.auto()
  39. # Runtime type checking is enabled.
  40. ERRORS = enum.auto()