options.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435
  1. """Options for FX exporter."""
  2. from __future__ import annotations
  3. import dataclasses
  4. from typing import Callable, Dict
  5. import torch
  6. from torch.onnx import _constants
  7. from torch.onnx._internal.fx import function_dispatcher
  @dataclasses.dataclass
  class ExportOptions:
      """Options for FX-ONNX export.

      Attributes:
          opset_version: The ONNX opset version to export to.
          use_binary_format: Whether to return the ModelProto in binary format.
          op_level_debug: Whether to export the model with op-level debug
              information, using the onnxruntime evaluator.
          decomposition_table: The decomposition table for graph ops. The default
              (for torch ops, including aten and prim) is a fresh copy of
              ``function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE``.
      """

      opset_version: int = _constants.ONNX_DEFAULT_OPSET
      use_binary_format: bool = True
      op_level_debug: bool = False
      # Copy the module-level table: returning the global dict itself would make
      # every instance share it, so mutating one instance's table would silently
      # change the default for all others (and the global).
      decomposition_table: Dict[torch._ops.OpOverload, Callable] = dataclasses.field(
          default_factory=lambda: dict(
              function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE
          )
      )

      def update(self, **kwargs) -> None:
          """Update option fields in place from keyword arguments.

          Keyword arguments whose value is ``None`` are ignored, leaving the
          current value of that field unchanged.

          Raises:
              KeyError: If any keyword does not name a dataclass field of
                  ``ExportOptions``. All keys are validated before anything is
                  assigned, so a failing call leaves the options unchanged.
          """
          # Check against declared fields rather than hasattr(), which would also
          # accept methods (e.g. ``update``) and dunder attributes.
          field_names = {field.name for field in dataclasses.fields(self)}
          for key in kwargs:
              if key not in field_names:
                  raise KeyError(f"ExportOptions has no attribute {key}")
          for key, value in kwargs.items():
              if value is not None:
                  setattr(self, key, value)