1234567891011121314151617181920212223242526272829303132333435 |
- """Options for FX exporter."""
- from __future__ import annotations
- import dataclasses
- from typing import Callable, Dict
- import torch
- from torch.onnx import _constants
- from torch.onnx._internal.fx import function_dispatcher
@dataclasses.dataclass
class ExportOptions:
    """Options for FX-ONNX export.

    Attributes:
        opset_version: The export ONNX version.
        use_binary_format: Whether to return the ModelProto in binary format.
        decomposition_table: The decomposition table for graph ops. Default is
            for torch ops, including aten and prim.
        op_level_debug: Whether to export the model with op level debug
            information with onnxruntime evaluator.
    """

    opset_version: int = _constants.ONNX_DEFAULT_OPSET
    use_binary_format: bool = True
    op_level_debug: bool = False
    # NOTE: the factory returns the module-level table object as-is (no copy),
    # so every instance built with the default refers to that same table.
    decomposition_table: Dict[torch._ops.OpOverload, Callable] = dataclasses.field(
        default_factory=lambda: function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE
    )

    def update(self, **kwargs) -> None:
        """Overwrite options in place from keyword arguments.

        Keys are processed in the order given. A value of ``None`` means
        "leave this option unchanged" and is skipped.

        Args:
            **kwargs: Option names mapped to their new values. Each name must
                be one of the dataclass fields declared on this class.

        Raises:
            KeyError: If a key is not a declared option. Options that appear
                before the offending key have already been applied.
        """
        # Validate against the declared dataclass fields rather than
        # ``hasattr``: ``hasattr`` is also true for methods and dunder
        # attributes, which would let e.g. ``update(update=...)`` silently
        # clobber this method on the instance.
        option_names = {field.name for field in dataclasses.fields(self)}
        for key, value in kwargs.items():
            if key not in option_names:
                raise KeyError(f"ExportOptions has no attribute {key}")
            if value is not None:
                setattr(self, key, value)
|