_experimental.py 1.0 KB

12345678910111213141516171819202122232425262728
  1. """Experimental classes and functions used by ONNX export."""
  2. import dataclasses
  3. from typing import Mapping, Optional, Sequence, Set, Type, Union
  4. import torch
  5. import torch._C._onnx as _C_onnx
  6. @dataclasses.dataclass
  7. class ExportOptions:
  8. """Arguments used by :func:`torch.onnx.export`.
  9. TODO: Adopt this in `torch.onnx.export` api to replace keyword arguments.
  10. """
  11. export_params: bool = True
  12. verbose: bool = False
  13. training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL
  14. input_names: Optional[Sequence[str]] = None
  15. output_names: Optional[Sequence[str]] = None
  16. operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX
  17. opset_version: Optional[int] = None
  18. do_constant_folding: bool = True
  19. dynamic_axes: Optional[Mapping[str, Union[Mapping[int, str], Sequence[int]]]] = None
  20. keep_initializers_as_inputs: Optional[bool] = None
  21. custom_opsets: Optional[Mapping[str, int]] = None
  22. export_modules_as_functions: Union[bool, Set[Type[torch.nn.Module]]] = False