12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- from __future__ import annotations
- import enum
- from typing import Dict
- from torch import _C
class ExportTypes:
    """Specifies how the ONNX model is stored.

    Each attribute is a plain ``str`` constant whose value is a short,
    human-readable description of the corresponding storage format.
    """

    # A single protobuf file.
    PROTOBUF_FILE = "Saves model in the specified protobuf file."
    # A ZIP archive written without compression.
    ZIP_ARCHIVE = "Saves model in the specified ZIP file (uncompressed)."
    # A ZIP archive written with compression.
    COMPRESSED_ZIP_ARCHIVE = "Saves model in the specified ZIP file (compressed)."
    # A folder on disk.
    DIRECTORY = "Saves model in the specified folder."
class SymbolicContext:
    """Extra context for symbolic functions.

    Args:
        params_dict (Dict[str, _C.IValue]): Mapping from graph initializer name to IValue.
        env (Dict[_C.Value, _C.Value]): Mapping from Torch domain graph Value to ONNX domain graph Value.
        cur_node (_C.Node): Current node being converted to ONNX domain.
        onnx_block (_C.Block): Current ONNX block that converted nodes are being appended to.
    """

    def __init__(
        self,
        params_dict: Dict[str, _C.IValue],
        # Annotated to match ``self.env`` and the class docstring; it was
        # previously a bare ``dict``.
        env: Dict[_C.Value, _C.Value],
        cur_node: _C.Node,
        onnx_block: _C.Block,
    ) -> None:
        # Arguments are stored by reference (no copies are made), so later
        # mutations of the caller's dicts are visible through this context,
        # and vice versa.
        self.params_dict: Dict[str, _C.IValue] = params_dict
        self.env: Dict[_C.Value, _C.Value] = env
        # Current node that is being converted.
        self.cur_node: _C.Node = cur_node
        # Current onnx block that converted nodes are being appended to.
        self.onnx_block: _C.Block = onnx_block
@enum.unique
class RuntimeTypeCheckState(enum.Enum):
    """Runtime type check state.

    Members:
        DISABLED: runtime type checking is disabled.
        WARNINGS: runtime type checking is enabled, but only warnings are shown.
        ERRORS: runtime type checking is enabled (presumably mismatches are
            raised as errors; confirm against the consumers of this enum).
    """

    # Explicit values match what ``enum.auto()`` produces for a plain Enum
    # (counting up from 1 in declaration order).
    DISABLED = 1
    WARNINGS = 2
    ERRORS = 3
|