1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050 |
- """Functions to export models into the ONNX IR format.
- These models can be loaded with the ONNX library and then
- converted to models which run on other deep learning frameworks.
- """
- from __future__ import annotations
- import contextlib
- import copy
- import inspect
- import io
- import re
- import textwrap
- import typing
- import warnings
- from typing import (
- Any,
- Callable,
- cast,
- Collection,
- Dict,
- List,
- Mapping,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
- )
- import torch
- import torch._C._onnx as _C_onnx
- import torch.jit._trace
- import torch.serialization
- from torch import _C
- from torch.onnx import ( # noqa: F401
- _constants,
- _exporter_states,
- errors,
- symbolic_caffe2,
- symbolic_helper,
- )
- from torch.onnx._globals import GLOBALS
- from torch.onnx._internal import (
- _beartype,
- diagnostics,
- jit_utils,
- onnx_proto_utils,
- registration,
- )
- __all__ = [
- "is_in_onnx_export",
- "select_model_mode_for_export",
- "disable_apex_o2_state_dict_hook",
- "setup_onnx_logging",
- "exporter_context",
- "export",
- "warn_on_static_input_change",
- "unpack_quantized_tensor",
- "export_to_pretty_string",
- "unconvertible_ops",
- "register_custom_op_symbolic",
- "unregister_custom_op_symbolic",
- ]
- def is_in_onnx_export() -> bool:
- """Returns whether it is in the middle of ONNX export."""
- return GLOBALS.in_onnx_export
- # TODO(justinchuby): Remove dependency to this global variable from constant_fold.cpp
- # Skip check due to cannot import IValue from torch._C
- _params_dict = {} # type: ignore[var-annotated]
- @contextlib.contextmanager
- @_beartype.beartype
- def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
- r"""A context manager to temporarily set the training mode of ``model``
- to ``mode``, resetting it when we exit the with-block.
- Args:
- model: Same type and meaning as ``model`` arg to :func:`export`.
- mode: Same type and meaning as ``training`` arg to :func:`export`.
- """
- if not isinstance(mode, _C_onnx.TrainingMode):
- raise TypeError(
- f"'mode' should be a torch.onnx.TrainingMode enum, but got '{type(mode)}'."
- )
- originally_training: bool = False
- if hasattr(model, "training"):
- originally_training = model.training
- # ONNX opset 12 has better support for training amenable models, with updated
- # versions of the dropout and batch_norm operators
- if mode == _C_onnx.TrainingMode.TRAINING or (
- mode == _C_onnx.TrainingMode.PRESERVE and originally_training
- ):
- GLOBALS.export_training = True
- if GLOBALS.export_onnx_opset_version < 12:
- warnings.warn(
- "You are exporting the model in training mode with onnx opset "
- f"version {GLOBALS.export_onnx_opset_version}. "
- "Opset versions lower than opset 12 will not be able to export "
- "nodes such as Dropout and BatchNorm correctly."
- )
- else:
- GLOBALS.export_training = False
- GLOBALS.training_mode = mode
- if mode == _C_onnx.TrainingMode.TRAINING:
- model.train(True)
- elif mode == _C_onnx.TrainingMode.EVAL:
- model.train(False)
- # else mode == _C_onnx.TrainingMode.PRESERVE, do nothing
- try:
- yield
- finally:
- if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
- model.train(originally_training)
- @contextlib.contextmanager
- @_beartype.beartype
- def disable_apex_o2_state_dict_hook(
- model: Union[torch.nn.Module, torch.jit.ScriptFunction]
- ):
- # Apex O2 hook state_dict to return fp16 weights as fp32.
- # Exporter cannot identify them as same tensors.
- # Since this hook is only used by optimizer, it is safe to
- # remove this hook while exporting.
- if not isinstance(model, torch.jit.ScriptFunction):
- model_hooks = {} # type: ignore[var-annotated]
- for module in model.modules():
- for key, hook in module._state_dict_hooks.items():
- if type(hook).__name__ == "O2StateDictHook":
- if module not in model_hooks:
- model_hooks[module] = {}
- model_hooks[module][key] = hook
- if module in model_hooks:
- for key in model_hooks[module]:
- module._state_dict_hooks.pop(key)
- try:
- yield
- finally:
- # Add the hooks back
- for module, m_map in model_hooks.items():
- for key, hook in m_map.items():
- module._state_dict_hooks[key] = hook
- else:
- try:
- yield
- finally:
- pass
- @contextlib.contextmanager
- @_beartype.beartype
- def setup_onnx_logging(verbose: bool):
- is_originally_enabled = torch.onnx.is_onnx_log_enabled()
- if is_originally_enabled or verbose:
- torch.onnx.enable_log()
- try:
- yield
- finally:
- if not is_originally_enabled:
- torch.onnx.disable_log()
- @contextlib.contextmanager
- @_beartype.beartype
- def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
- with select_model_mode_for_export(
- model, mode
- ) as mode_ctx, disable_apex_o2_state_dict_hook(
- model
- ) as apex_ctx, setup_onnx_logging(
- verbose
- ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
- yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)
- @_beartype.beartype
- def export(
- model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
- args: Union[Tuple[Any, ...], torch.Tensor],
- f: Union[str, io.BytesIO],
- export_params: bool = True,
- verbose: bool = False,
- training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
- input_names: Optional[Sequence[str]] = None,
- output_names: Optional[Sequence[str]] = None,
- operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,
- opset_version: Optional[int] = None,
- do_constant_folding: bool = True,
- dynamic_axes: Optional[
- Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]
- ] = None,
- keep_initializers_as_inputs: Optional[bool] = None,
- custom_opsets: Optional[Mapping[str, int]] = None,
- export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
- ) -> None:
- r"""Exports a model into ONNX format.
- If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
- :class:`torch.jit.ScriptFunction`, this runs
- ``model`` once in order to convert it to a TorchScript graph to be exported
- (the equivalent of :func:`torch.jit.trace`). Thus this has the same limited support
- for dynamic control flow as :func:`torch.jit.trace`.
- Args:
- model (:class:`torch.nn.Module`, :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`):
- the model to be exported.
- args (tuple or torch.Tensor):
- args can be structured either as:
- 1. ONLY A TUPLE OF ARGUMENTS::
- args = (x, y, z)
- The tuple should contain model inputs such that ``model(*args)`` is a valid
- invocation of the model. Any non-Tensor arguments will be hard-coded into the
- exported model; any Tensor arguments will become inputs of the exported model,
- in the order they occur in the tuple.
- 2. A TENSOR::
- args = torch.Tensor([1])
- This is equivalent to a 1-ary tuple of that Tensor.
- 3. A TUPLE OF ARGUMENTS ENDING WITH A DICTIONARY OF NAMED ARGUMENTS::
- args = (
- x,
- {
- "y": input_y,
- "z": input_z
- }
- )
- All but the last element of the tuple will be passed as non-keyword arguments,
- and named arguments will be set from the last element. If a named argument is
- not present in the dictionary, it is assigned the default value, or None if a
- default value is not provided.
- .. note::
- If a dictionary is the last element of the args tuple, it will be
- interpreted as containing named arguments. In order to pass a dict as the
- last non-keyword arg, provide an empty dict as the last element of the args
- tuple. For example, instead of::
- torch.onnx.export(
- model,
- (
- x,
- # WRONG: will be interpreted as named arguments
- {y: z}
- ),
- "test.onnx.pb"
- )
- Write::
- torch.onnx.export(
- model,
- (
- x,
- {y: z},
- {}
- ),
- "test.onnx.pb"
- )
- f: a file-like object (such that ``f.fileno()`` returns a file descriptor)
- or a string containing a file name. A binary protocol buffer will be written
- to this file.
- export_params (bool, default True): if True, all parameters will
- be exported. Set this to False if you want to export an untrained model.
- In this case, the exported model will first take all of its parameters
- as arguments, with the ordering as specified by ``model.state_dict().values()``
- verbose (bool, default False): if True, prints a description of the
- model being exported to stdout. In addition, the final ONNX graph will include the
- field ``doc_string``` from the exported model which mentions the source code locations
- for ``model``. If True, ONNX exporter logging will be turned on.
- training (enum, default TrainingMode.EVAL):
- * ``TrainingMode.EVAL``: export the model in inference mode.
- * ``TrainingMode.PRESERVE``: export the model in inference mode if model.training is
- False and in training mode if model.training is True.
- * ``TrainingMode.TRAINING``: export the model in training mode. Disables optimizations
- which might interfere with training.
- input_names (list of str, default empty list): names to assign to the
- input nodes of the graph, in order.
- output_names (list of str, default empty list): names to assign to the
- output nodes of the graph, in order.
- operator_export_type (enum, default OperatorExportTypes.ONNX):
- * ``OperatorExportTypes.ONNX``: Export all ops as regular ONNX ops
- (in the default opset domain).
- * ``OperatorExportTypes.ONNX_FALLTHROUGH``: Try to convert all ops
- to standard ONNX ops in the default opset domain. If unable to do so
- (e.g. because support has not been added to convert a particular torch op to ONNX),
- fall back to exporting the op into a custom opset domain without conversion. Applies
- to `custom ops <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
- as well as ATen ops. For the exported model to be usable, the runtime must support
- these non-standard ops.
- * ``OperatorExportTypes.ONNX_ATEN``: All ATen ops (in the TorchScript namespace "aten")
- are exported as ATen ops (in opset domain "org.pytorch.aten").
- `ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library, so
- this instructs the runtime to use PyTorch's implementation of these ops.
- .. warning::
- Models exported this way are probably runnable only by Caffe2.
- This may be useful if the numeric differences in implementations of operators are
- causing large differences in behavior between PyTorch and Caffe2 (which is more
- common on untrained models).
- * ``OperatorExportTypes.ONNX_ATEN_FALLBACK``: Try to export each ATen op
- (in the TorchScript namespace "aten") as a regular ONNX op. If we are unable to do so
- (e.g. because support has not been added to convert a particular torch op to ONNX),
- fall back to exporting an ATen op. See documentation on OperatorExportTypes.ONNX_ATEN for
- context.
- For example::
- graph(%0 : Float):
- %3 : int = prim::Constant[value=0]()
- # conversion unsupported
- %4 : Float = aten::triu(%0, %3)
- # conversion supported
- %5 : Float = aten::mul(%4, %0)
- return (%5)
- Assuming ``aten::triu`` is not supported in ONNX, this will be exported as::
- graph(%0 : Float):
- %1 : Long() = onnx::Constant[value={0}]()
- # not converted
- %2 : Float = aten::ATen[operator="triu"](%0, %1)
- # converted
- %3 : Float = onnx::Mul(%2, %0)
- return (%3)
- If PyTorch was built with Caffe2 (i.e. with ``BUILD_CAFFE2=1``), then
- Caffe2-specific behavior will be enabled, including special support
- for ops are produced by the modules described in
- `Quantization <https://pytorch.org/docs/stable/quantization.html>`_.
- .. warning::
- Models exported this way are probably runnable only by Caffe2.
- opset_version (int, default 14): The version of the
- `default (ai.onnx) opset <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
- to target. Must be >= 7 and <= 16.
- do_constant_folding (bool, default True): Apply the constant-folding optimization.
- Constant-folding will replace some of the ops that have all constant inputs
- with pre-computed constant nodes.
- dynamic_axes (dict[string, dict[int, string]] or dict[string, list(int)], default empty dict):
- By default the exported model will have the shapes of all input and output tensors
- set to exactly match those given in ``args``. To specify axes of tensors as
- dynamic (i.e. known only at run-time), set ``dynamic_axes`` to a dict with schema:
- * KEY (str): an input or output name. Each name must also be provided in ``input_names`` or
- ``output_names``.
- * VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a
- list, each element is an axis index.
- For example::
- class SumModule(torch.nn.Module):
- def forward(self, x):
- return torch.sum(x, dim=1)
- torch.onnx.export(
- SumModule(),
- (torch.ones(2, 2),),
- "onnx.pb",
- input_names=["x"],
- output_names=["sum"]
- )
- Produces::
- input {
- name: "x"
- ...
- shape {
- dim {
- dim_value: 2 # axis 0
- }
- dim {
- dim_value: 2 # axis 1
- ...
- output {
- name: "sum"
- ...
- shape {
- dim {
- dim_value: 2 # axis 0
- ...
- While::
- torch.onnx.export(
- SumModule(),
- (torch.ones(2, 2),),
- "onnx.pb",
- input_names=["x"],
- output_names=["sum"],
- dynamic_axes={
- # dict value: manually named axes
- "x": {0: "my_custom_axis_name"},
- # list value: automatic names
- "sum": [0],
- }
- )
- Produces::
- input {
- name: "x"
- ...
- shape {
- dim {
- dim_param: "my_custom_axis_name" # axis 0
- }
- dim {
- dim_value: 2 # axis 1
- ...
- output {
- name: "sum"
- ...
- shape {
- dim {
- dim_param: "sum_dynamic_axes_1" # axis 0
- ...
- keep_initializers_as_inputs (bool, default None): If True, all the
- initializers (typically corresponding to parameters) in the
- exported graph will also be added as inputs to the graph. If False,
- then initializers are not added as inputs to the graph, and only
- the non-parameter inputs are added as inputs.
- This may allow for better optimizations (e.g. constant folding) by
- backends/runtimes.
- If ``opset_version < 9``, initializers MUST be part of graph
- inputs and this argument will be ignored and the behavior will be
- equivalent to setting this argument to True.
- If None, then the behavior is chosen automatically as follows:
- * If ``operator_export_type=OperatorExportTypes.ONNX``, the behavior is equivalent
- to setting this argument to False.
- * Else, the behavior is equivalent to setting this argument to True.
- custom_opsets (dict[str, int], default empty dict): A dict with schema:
- * KEY (str): opset domain name
- * VALUE (int): opset version
- If a custom opset is referenced by ``model`` but not mentioned in this dictionary,
- the opset version is set to 1. Only custom opset domain name and version should be
- indicated through this argument.
- export_modules_as_functions (bool or set of type of nn.Module, default False): Flag to enable
- exporting all ``nn.Module`` forward calls as local functions in ONNX. Or a set to indicate the
- particular types of modules to export as local functions in ONNX.
- This feature requires ``opset_version`` >= 15, otherwise the export will fail. This is because
- ``opset_version`` < 15 implies IR version < 8, which means no local function support.
- Module variables will be exported as function attributes. There are two categories of function
- attributes.
- 1. Annotated attributes: class variables that have type annotations via
- `PEP 526-style <https://www.python.org/dev/peps/pep-0526/#class-and-instance-variable-annotations>`_
- will be exported as attributes.
- Annotated attributes are not used inside the subgraph of ONNX local function because
- they are not created by PyTorch JIT tracing, but they may be used by consumers
- to determine whether or not to replace the function with a particular fused kernel.
- 2. Inferred attributes: variables that are used by operators inside the module. Attribute names
- will have prefix "inferred::". This is to differentiate from predefined attributes retrieved from
- python module annotations. Inferred attributes are used inside the subgraph of ONNX local function.
- * ``False`` (default): export ``nn.Module`` forward calls as fine grained nodes.
- * ``True``: export all ``nn.Module`` forward calls as local function nodes.
- * Set of type of nn.Module: export ``nn.Module`` forward calls as local function nodes,
- only if the type of the ``nn.Module`` is found in the set.
- Raises:
- :class:`torch.onnx.errors.CheckerError`: If the ONNX checker detects an invalid ONNX graph.
- :class:`torch.onnx.errors.UnsupportedOperatorError`: If the ONNX graph cannot be exported because it
- uses an operator that is not supported by the exporter.
- :class:`torch.onnx.errors.OnnxExporterError`: Other errors that can occur during export.
- All errors are subclasses of :class:`errors.OnnxExporterError`.
- """
- _export(
- model,
- args,
- f,
- export_params,
- verbose,
- training,
- input_names,
- output_names,
- operator_export_type=operator_export_type,
- opset_version=opset_version,
- do_constant_folding=do_constant_folding,
- dynamic_axes=dynamic_axes,
- keep_initializers_as_inputs=keep_initializers_as_inputs,
- custom_opsets=custom_opsets,
- export_modules_as_functions=export_modules_as_functions,
- )
- @_beartype.beartype
- def _is_constant_tensor_list(node):
- if node.kind() != "prim::Constant":
- return False
- output_type = node.output().type()
- if output_type.isSubtypeOf(_C.ListType.ofTensors()):
- return True
- if output_type.isSubtypeOf(_C.ListType(_C.OptionalType.ofTensor())):
- return True
- # ONNX can't handle constants that are lists of tensors, which can
- # get generated in constant prop. So we split them back into prim::ListConstructs
- @_beartype.beartype
- def _split_tensor_list_constants(g, block):
- for node in block.nodes():
- for subblock in node.blocks():
- _split_tensor_list_constants(g, subblock)
- if _is_constant_tensor_list(node):
- inputs = []
- for val in node.output().toIValue():
- input = g.insertConstant(val)
- input.node().moveBefore(node)
- input.node().copyMetadata(node)
- inputs.append(input)
- lc = (
- g.create("prim::ListConstruct", inputs)
- .insertBefore(node)
- .output()
- .setType(_C.ListType.ofTensors())
- )
- lc.node().copyMetadata(node)
- node.output().replaceAllUsesWith(lc)
- @_beartype.beartype
- def _optimize_graph(
- graph: _C.Graph,
- operator_export_type: _C_onnx.OperatorExportTypes,
- _disable_torch_constant_prop: bool = False,
- fixed_batch_size: bool = False,
- params_dict=None,
- dynamic_axes=None,
- input_names=None,
- module=None,
- ):
- if params_dict is None:
- params_dict = {}
- # Inline everything
- _C._jit_pass_inline(graph)
- # Remove fork/wait nodes
- _C._jit_pass_inline_fork_wait(graph)
- _C._jit_pass_lint(graph)
- _C._jit_pass_onnx_autograd_function_process(graph)
- _C._jit_pass_lower_all_tuples(graph)
- # we now record some ops like ones/zeros
- # into a trace where we previously recorded constants.
- # use constant prop to maintain our current level of onnx support
- # without implementing symbolics for all of them
- if _disable_torch_constant_prop is False:
- _C._jit_pass_constant_propagation(graph)
- _split_tensor_list_constants(graph, graph)
- # run dce to eliminate dead parts of the graph that might have been
- # left behind by things like symbolic_override
- _C._jit_pass_dce(graph)
- _C._jit_pass_lint(graph)
- # CSE should improve perf when Autocast is used with disabled cache
- # Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092
- # Must run before _C._jit_pass_erase_number_types to prevent type substitution
- if _C._jit_pass_cse(graph):
- _C._jit_pass_onnx_lint(graph)
- _C._jit_pass_canonicalize_graph_fuser_ops(graph)
- _C._jit_pass_lint(graph)
- _C._jit_pass_peephole(graph, True)
- _C._jit_pass_fuse_addmm(graph)
- _C._jit_pass_lint(graph)
- _C._jit_pass_peephole(graph, True)
- _C._jit_pass_lower_all_tuples(graph)
- # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
- # However, there are nodes that cannot be converted without additional context.
- # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
- # until the point where it is unpacked by listUnpack node.
- # This pass does a preprocess, and prepares the nodes such that enough context can be received
- # by the symbolic function.
- _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
- _C._jit_pass_onnx_preprocess(graph)
- # onnx does not support tuples, so try to remove them
- _C._jit_pass_lint(graph)
- # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
- _C._jit_pass_prepare_division_for_onnx(graph)
- _C._jit_pass_onnx_remove_print(graph)
- _C._jit_pass_onnx_preprocess_caffe2(graph)
- symbolic_helper._quantized_ops.clear()
- # Unpack quantized weights for conv and linear ops and insert into graph.
- _C._jit_pass_onnx_unpack_quantized_weights(
- graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
- )
- if symbolic_helper.is_caffe2_aten_fallback():
- # Insert permutes before and after each conv op to ensure correct order.
- _C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
- # Find consecutive permutes that are no-ops and remove them.
- _C._jit_pass_custom_pattern_based_rewrite_graph(
- textwrap.dedent(
- """\
- graph(%Pi):
- %Pq = quantized::nhwc2nchw(%Pi)
- %Pr = quantized::nchw2nhwc(%Pq)
- return (%Pr)"""
- ),
- textwrap.dedent(
- """\
- graph(%Ri):
- return (%Ri)"""
- ),
- graph,
- )
- # onnx only supports tensors, so we turn all out number types into tensors
- _C._jit_pass_erase_number_types(graph)
- if GLOBALS.onnx_shape_inference:
- input_names = [] if input_names is None else input_names
- dynamic_axes = {} if dynamic_axes is None else dynamic_axes
- _C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
- _C._jit_pass_onnx_lint(graph)
- graph = _C._jit_pass_onnx(graph, operator_export_type)
- _C._jit_pass_onnx_lint(graph)
- _C._jit_pass_lint(graph)
- _C._jit_pass_onnx_scalar_type_analysis(
- graph, True, GLOBALS.export_onnx_opset_version
- )
- _C._jit_pass_lint(graph)
- _C._jit_pass_onnx_peephole(
- graph, GLOBALS.export_onnx_opset_version, fixed_batch_size
- )
- _C._jit_pass_lint(graph)
- # graph is not a valid jit graph anymore because types have been replaced
- # (e.g. int with Tensor), so it now contains operators that don't actually
- # exist. We can't run normal dead code elimination because it'd fail trying
- # to look up if an operator has side effects, but we can run a dead code
- # elimination variant that doesn't need to look up if an op has side effects.
- _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
- _C._jit_pass_lint(graph)
- graph = _C._jit_pass_canonicalize(graph)
- _C._jit_pass_lint(graph)
- if GLOBALS.onnx_shape_inference:
- _C._jit_pass_onnx_graph_shape_type_inference(
- graph, params_dict, GLOBALS.export_onnx_opset_version
- )
- return graph
- @_beartype.beartype
- def warn_on_static_input_change(input_states):
- """Warns that changes to input dictionaries and strings won't take effect in the traced ONNX graph.
- We accept dictionaries and strings as ONNX inputs, but they should be only for
- configuration use. we detect here if these inputs are modified, and if so we warn
- the user that the changes won't take effect in the traced ONNX graph.
- """
- for input, traced_input in zip(input_states[0], input_states[1]):
- if isinstance(input, dict):
- if list(input.keys()) != list(traced_input.keys()):
- warning = (
- "We detected that you are modifying a dictionary that is an input to your "
- "model. "
- "Note that dictionaries are allowed as inputs in ONNX but they should be "
- "handled with care. "
- "Usages of dictionaries is not recommended, and should not be used except "
- "for configuration use. "
- "Also note that the order and values of the keys must remain the same. "
- )
- warnings.warn(warning)
- elif isinstance(input, str):
- if input != traced_input:
- warning = (
- "The model seems to have string inputs/outputs. "
- "Note that strings will not appear as inputs/outputs of the ONNX graph. "
- )
- warnings.warn(warning)
- @_beartype.beartype
- def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
- """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
- if (
- operator_export_type is not operator_export_type.ONNX
- and _C_onnx._CAFFE2_ATEN_FALLBACK
- ):
- if arg_value is True:
- warnings.warn(
- f"'{arg_name}' can be set to True only when 'operator_export_type' is "
- "`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
- f"'{arg_name}' argument will be ignored."
- )
- arg_value = False
- return arg_value
- @_beartype.beartype
- def _decide_keep_init_as_input(
- keep_initializers_as_inputs: Optional[bool],
- operator_export_type: _C_onnx.OperatorExportTypes,
- opset_version: int,
- ):
- """Decides whether the initializers in the graph should be listed as ONNX graph inputs.
- This method encapsulates the logic to decide whether the initializers in the graph
- should be listed as ONNX graph inputs (i.e., whether to choose ONNX IR v3 or v4).
- If keep_initializers_as_inputs is not specified (None), then we decide whether to keep
- initializers as graph inputs (val_keep_init_as_ip) based on export type. If export type
- is ONNX, then do not keep initializers as input (val_keep_init_as_ip=False). For all other
- export types keep initializers as input (val_keep_init_as_ip=True).
- If keep_initializers_as_inputs is specified, then respect it. Unless opset version <= 8,
- in which case it must be ignored because for opset version <= 8, all initializers MUST be
- part of graph input (only ONNX IR v3 is allowed), i.e. val_keep_init_as_ip=True.
- Special handling is needed for opset version 8 or lower, because irrespective
- of user input for keep_initializers_as_inputs, the graph must follow ONNX IR v3
- semantics, i.e. all initializers must be listed as ONNX graph input.
- """
- if opset_version < 9:
- if keep_initializers_as_inputs is False:
- warnings.warn(
- "Setting 'keep_initializers_as_inputs=False' for opset version"
- "8 or lower would lead to an invalid ONNX graph. Therefore, "
- "'keep_initializers_as_inputs=False' is ignored during export."
- "Exported model will have initializers as graph inputs (compliant "
- " to ONNX IR v3)."
- )
- return True # i.e. True == initializers are part of graph input (ONNX IR v3)
- val_keep_init_as_ip = (
- True if keep_initializers_as_inputs is None else keep_initializers_as_inputs
- )
- if (
- keep_initializers_as_inputs is None
- and operator_export_type is _C_onnx.OperatorExportTypes.ONNX
- ):
- val_keep_init_as_ip = False
- return val_keep_init_as_ip
- @_beartype.beartype
- def _decide_add_node_names(add_node_names, operator_export_type):
- return _resolve_args_by_export_type(
- "add_node_names", add_node_names, operator_export_type
- )
- @_beartype.beartype
- def _decide_constant_folding(do_constant_folding, operator_export_type, training):
- do_constant_folding = _resolve_args_by_export_type(
- "do_constant_folding", do_constant_folding, operator_export_type
- )
- if do_constant_folding and (
- training is not None and training is not _C_onnx.TrainingMode.EVAL
- ):
- warnings.warn(
- "It is recommended that constant folding be turned off ('do_constant_folding=False') "
- "when exporting the model in training-amenable mode, i.e. with 'training=TrainingMode.TRAIN' "
- "or 'training=TrainingMode.PRESERVE' (when model is in training mode). Otherwise, some "
- "learnable model parameters may not translate correctly in the exported ONNX model "
- "because constant folding mutates model parameters. Please consider "
- "turning off constant folding or setting the training=TrainingMode.EVAL."
- )
- return do_constant_folding
- @_beartype.beartype
- def _signature(model) -> inspect.Signature:
- should_be_callable = getattr(model, "forward", model)
- if callable(should_be_callable):
- return inspect.signature(should_be_callable)
- raise ValueError("model has no forward method and is not callable")
- @_beartype.beartype
- def _decide_input_format(model, args):
- try:
- sig = _signature(model)
- except ValueError as e:
- warnings.warn(f"{e}, skipping _decide_input_format")
- return args
- try:
- ordered_list_keys = list(sig.parameters.keys())
- if ordered_list_keys[0] == "self":
- ordered_list_keys = ordered_list_keys[1:]
- args_dict: Dict = {}
- if isinstance(args, list):
- args_list = args
- elif isinstance(args, tuple):
- args_list = list(args)
- else:
- args_list = [args]
- if isinstance(args_list[-1], dict):
- args_dict = args_list[-1]
- args_list = args_list[:-1]
- n_nonkeyword = len(args_list)
- for optional_arg in ordered_list_keys[n_nonkeyword:]:
- if optional_arg in args_dict:
- args_list.append(args_dict[optional_arg])
- # Check if this arg has a default value
- else:
- param = sig.parameters[optional_arg]
- if param.default != param.empty:
- args_list.append(param.default)
- args = args_list if isinstance(args, list) else tuple(args_list)
- # Cases of models with no input args
- except IndexError:
- warnings.warn("No input args, skipping _decide_input_format")
- except Exception as e:
- warnings.warn(f"Skipping _decide_input_format\n {e.args[0]}")
- return args
- @_beartype.beartype
- def _trace(func, args, operator_export_type, return_outs=False):
- # Special case for common case of passing a single Tensor
- if isinstance(args, torch.Tensor):
- args = (args,)
- trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
- func,
- args,
- strict=False,
- _force_outplace=False,
- _return_inputs_states=True,
- )
- warn_on_static_input_change(inputs_states)
- trace_graph = _optimize_graph(trace_graph, operator_export_type, params_dict={})
- if return_outs:
- return trace_graph, torch_out
- return trace_graph
- @_beartype.beartype
- def _trace_and_get_graph_from_model(model, args):
- # A basic sanity check: make sure the state_dict keys are the same
- # before and after running the model. Fail fast!
- orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
- # Disable Autocast cache because it replaces kernel's weight and bias
- # by (undesired) constants.
- # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
- # TODO: https://github.com/pytorch/pytorch/issues/84092
- prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
- torch.set_autocast_cache_enabled(False)
- trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
- model,
- args,
- strict=False,
- _force_outplace=False,
- _return_inputs_states=True,
- )
- torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
- warn_on_static_input_change(inputs_states)
- if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
- raise RuntimeError(
- "state_dict changed after running the tracer; "
- "something weird is happening in your model!"
- )
- return trace_graph, torch_out
- @_beartype.beartype
- def _get_param_count_list(method_graph, args_params):
- param_count_list = []
- for input_, arg_params_ in zip(method_graph.inputs(), args_params):
- if "PackedParams" in str(input_.type()):
- in_vars, _ = torch.jit._flatten(arg_params_)
- param_count_list.append(len(in_vars))
- else:
- param_count_list.append(arg_params_ is not None)
- return param_count_list
- @_beartype.beartype
- def _check_flatten_did_not_remove(original, jit_flattened):
- """torch.jit._flatten removes None. Check if it did so in this case."""
- @_beartype.beartype
- def flatten(x):
- if isinstance(x, (list, tuple)):
- for inner in x:
- yield from flatten(inner)
- elif isinstance(x, dict):
- for inner in x.values():
- yield from flatten(inner)
- else:
- yield x
- flattened_with_none = list(flatten(original))
- num_none = len(flattened_with_none) - len(jit_flattened)
- assert num_none >= 0
- if num_none:
- raise ValueError(
- f"args contained {num_none} None's after flattening. "
- "When exporting a ScriptModule or ScriptFunction, no args may "
- "be None because that breaks type propagation."
- )
- def _create_jit_graph(
- model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
- ) -> Tuple[_C.Graph, List[_C.IValue], Optional[Any], Optional[_C.ScriptModule]]:
- if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
- flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
- _check_flatten_did_not_remove(args, flattened_args)
- torch_out = None
- if isinstance(model, torch.jit.ScriptModule):
- try:
- graph = model.forward.graph # type: ignore[attr-defined]
- except AttributeError as e:
- raise RuntimeError("'forward' method must be a script method") from e
- _C._jit_pass_onnx_function_substitution(graph)
- freezed_module = _C._freeze_module(
- cast(_C.ScriptModule, model._c), preserveParameters=True
- )
- module, params = _C._jit_onnx_list_model_parameters(freezed_module)
- method_graph = module._get_method("forward").graph
- args_params = tuple(args) + tuple(params)
- param_count_list = _get_param_count_list(method_graph, args_params)
- in_vars, _ = torch.jit._flatten(args_params)
- graph = _C._propagate_and_assign_input_shapes(
- method_graph, tuple(in_vars), param_count_list, False, False
- )
- return graph, params, torch_out, module
- # torch.jit.ScriptFunction
- params = []
- graph = model.graph
- _C._jit_pass_onnx_function_substitution(graph)
- param_count_list = _get_param_count_list(graph, args)
- graph = _C._propagate_and_assign_input_shapes(
- graph, flattened_args, param_count_list, False, False
- )
- return graph, params, torch_out, None
- graph, torch_out = _trace_and_get_graph_from_model(model, args)
- _C._jit_pass_onnx_lint(graph)
- state_dict = torch.jit._unique_state_dict(model)
- params = list(state_dict.values())
- graph_inputs = list(graph.inputs())
- user_input_num = len(graph_inputs) - len(state_dict)
- param_names = list(state_dict.keys())
- for i, inp in enumerate(graph_inputs):
- if i >= user_input_num:
- inp.setDebugName(param_names[i - user_input_num])
- _C._jit_pass_onnx_function_substitution(graph)
- return graph, params, torch_out, None
- @_beartype.beartype
- def _get_named_param_dict(graph, params):
- input_and_param_names = [val.debugName() for val in graph.inputs()]
- param_names = input_and_param_names[len(input_and_param_names) - len(params) :]
- _params_dict = dict(zip(param_names, params))
- return _params_dict
- @_beartype.beartype
- def _get_example_outputs(model, args):
- input_args = copy.deepcopy(args)
- input_kwargs = {}
- if input_args and isinstance(input_args[-1], dict):
- input_kwargs = input_args[-1]
- input_args = input_args[:-1]
- example_outputs = model(*input_args, **input_kwargs)
- if isinstance(example_outputs, list):
- example_outputs = [example_outputs]
- elif not isinstance(example_outputs, tuple):
- example_outputs = (example_outputs,)
- return example_outputs
- _qtype_vtype_map = {
- torch.quint8: torch.uint8,
- torch.qint8: torch.int8,
- torch.qint32: torch.int32,
- torch.quint4x2: torch.int8,
- }
- @_beartype.beartype
- def unpack_quantized_tensor(value, cast_onnx_accepted=True):
- if isinstance(value, torch.Tensor) and value.dtype in _qtype_vtype_map:
- q_value_dequantize = value.dequantize()
- q_scale = (
- torch.tensor(value.q_scale(), dtype=torch.double)
- if cast_onnx_accepted
- else torch.tensor(value.q_scale(), dtype=torch.float32)
- )
- q_zero_point = (
- torch.tensor(value.q_zero_point(), dtype=torch.int64)
- if cast_onnx_accepted
- else torch.tensor(value.q_zero_point(), dtype=_qtype_vtype_map[value.dtype])
- )
- q_value = q_value_dequantize / q_scale + q_zero_point
- q_value = q_value.to(dtype=_qtype_vtype_map[value.dtype])
- return q_value, q_scale, q_zero_point
- else:
- return (value,)
- @_beartype.beartype
- def _pre_trace_quant_model(model, args):
- r"""Returns `torch.jit.trace(model, args)` if model is quantized. Otherwise do nothing and return
- original model.
- This is due to https://github.com/pytorch/pytorch/issues/75761.
- """
- if any(
- hasattr(m, "_packed_params") for m in getattr(model, "modules", list)()
- ) or any(getattr(arg, "is_quantized", False) for arg in args):
- return torch.jit.trace(model, args)
- return model
- @_beartype.beartype
- def _model_to_graph(
- model,
- args,
- verbose=False,
- input_names=None,
- output_names=None,
- operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
- do_constant_folding=True,
- _disable_torch_constant_prop=False,
- fixed_batch_size=False,
- training=_C_onnx.TrainingMode.EVAL,
- dynamic_axes=None,
- ) -> Tuple[
- _C.Graph,
- Dict[str, torch.Tensor],
- Optional[
- Union[
- torch.Tensor,
- Tuple[torch.Tensor, ...],
- List[torch.Tensor],
- Dict[str, torch.Tensor],
- Any, # Can be nested tuples etc.
- ]
- ],
- ]:
- """Converts model into an ONNX graph.
- Returns:
- graph: A TorchScript IR Graph with ONNX nodes.
- params_dict: Dict from input param name to param value.
- torch_out: The output tensors resulting from the trace of ``model``.
- If ``model`` is a :class:`torch.jit.ScriptModule` or :class:`torch.jit.ScriptFunction`,
- this will be None, since we are not doing any tracing.
- """
- # TODO: can we simplify this to always return a tuple of Tensor or None?
- # Special case for common case of passing a single Tensor
- if isinstance(args, (torch.Tensor, int, float, bool)):
- args = (args,)
- model = _pre_trace_quant_model(model, args)
- graph, params, torch_out, module = _create_jit_graph(model, args)
- params_dict = _get_named_param_dict(graph, params)
- try:
- graph = _optimize_graph(
- graph,
- operator_export_type,
- _disable_torch_constant_prop=_disable_torch_constant_prop,
- fixed_batch_size=fixed_batch_size,
- params_dict=params_dict,
- dynamic_axes=dynamic_axes,
- input_names=input_names,
- module=module,
- )
- except Exception as e:
- torch.onnx.log("Torch IR graph at exception: ", graph)
- raise
- is_script = isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule))
- if is_script:
- example_outputs = _get_example_outputs(model, args)
- example_outputs_final = ()
- for example_output in example_outputs:
- example_outputs_final += unpack_quantized_tensor(example_output)
- out_vars, desc = torch.jit._flatten(example_outputs_final)
- _C._jit_pass_onnx_assign_output_shape(
- graph,
- out_vars,
- desc,
- GLOBALS.onnx_shape_inference,
- is_script,
- GLOBALS.export_onnx_opset_version,
- )
- # NB: ONNX requires complete information about output types, which might be
- # erased by some optimizations, so we need to set it explicitly again.
- else:
- if not isinstance(torch_out, (list, tuple)):
- output_wrapped = [torch_out]
- else:
- output_wrapped = torch_out # type: ignore[assignment]
- output_tensors, out_desc = torch.jit._flatten(tuple(output_wrapped))
- # assign_output_shape pass is not compatible with quantized outputs.
- # Quantized outputs are flattened to 3 values in ONNX, while packed as
- # single value in PyTorch.
- if not any(getattr(out, "is_quantized", False) for out in output_tensors):
- _C._jit_pass_onnx_assign_output_shape(
- graph,
- output_tensors,
- out_desc,
- GLOBALS.onnx_shape_inference,
- is_script,
- GLOBALS.export_onnx_opset_version,
- )
- _set_input_and_output_names(graph, input_names, output_names)
- params_dict = _get_named_param_dict(graph, params)
- if training is None or training == _C_onnx.TrainingMode.EVAL:
- params_dict = _C._jit_pass_onnx_eval_peephole(graph, params_dict)
- if (
- do_constant_folding
- and GLOBALS.export_onnx_opset_version
- >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
- ):
- params_dict = _C._jit_pass_onnx_constant_fold(
- graph, params_dict, GLOBALS.export_onnx_opset_version
- )
- _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
- if GLOBALS.onnx_shape_inference:
- _C._jit_pass_onnx_graph_shape_type_inference(
- graph, params_dict, GLOBALS.export_onnx_opset_version
- )
- params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
- # For ONNX opset < 9, constants only have three data types: float16, float, double.
- # In this pass transform constants of other data types to float/double + cast operator.
- if GLOBALS.export_onnx_opset_version < 9:
- _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
- params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
- _C._jit_decay_packed_param_input_types(graph)
- # If output names lack a proper name and are identified only by their unique
- # give them a legible name for debugging purposes
- _apply_friendly_debug_names(graph, params_dict)
- return graph, params_dict, torch_out
- @_beartype.beartype
- def export_to_pretty_string(
- model,
- args,
- export_params=True,
- verbose=False,
- training=_C_onnx.TrainingMode.EVAL,
- input_names=None,
- output_names=None,
- operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
- export_type=None,
- google_printer=False,
- opset_version=None,
- keep_initializers_as_inputs=None,
- custom_opsets=None,
- add_node_names=True,
- do_constant_folding=True,
- dynamic_axes=None,
- ):
- r"""
- Similar to :func:`export`, but returns a text representation of the ONNX
- model. Only differences in args listed below. All other args are the same
- as :func:`export`.
- Args:
- add_node_names (bool, default True): Whether or not to set
- NodeProto.name. This makes no difference unless
- ``google_printer=True``.
- google_printer (bool, default False): If False, will return a custom,
- compact representation of the model. If True will return the
- protobuf's `Message::DebugString()`, which is more verbose.
- Returns:
- A UTF-8 str containing a human-readable representation of the ONNX model.
- """
- if opset_version is None:
- opset_version = _constants.ONNX_DEFAULT_OPSET
- if custom_opsets is None:
- custom_opsets = {}
- GLOBALS.export_onnx_opset_version = opset_version
- GLOBALS.operator_export_type = operator_export_type
- with exporter_context(model, training, verbose):
- val_keep_init_as_ip = _decide_keep_init_as_input(
- keep_initializers_as_inputs, operator_export_type, opset_version
- )
- val_add_node_names = _decide_add_node_names(
- add_node_names, operator_export_type
- )
- val_do_constant_folding = _decide_constant_folding(
- do_constant_folding, operator_export_type, training
- )
- args = _decide_input_format(model, args)
- graph, params_dict, torch_out = _model_to_graph(
- model,
- args,
- verbose,
- input_names,
- output_names,
- operator_export_type,
- val_do_constant_folding,
- training=training,
- dynamic_axes=dynamic_axes,
- )
- return graph._pretty_print_onnx( # type: ignore[attr-defined]
- params_dict,
- opset_version,
- False,
- operator_export_type,
- google_printer,
- val_keep_init_as_ip,
- custom_opsets,
- val_add_node_names,
- )
- @_beartype.beartype
- def unconvertible_ops(
- model,
- args,
- training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
- opset_version: Optional[int] = None,
- ) -> Tuple[_C.Graph, List[str]]:
- """Returns an approximated list of all ops that are yet supported by :mod:`torch.onnx`.
- The list is approximated because some ops may be removed during the conversion
- process and don't need to be converted. Some other ops may have partial support
- that will fail conversion with particular inputs. Please open a Github Issue
- for op support requests.
- Args:
- model: Same as the `model` parameter in :func:`torch.onnx.export`.
- args: Same as the `args` parameter in :func:`torch.onnx.export`.
- training: Same as the `training` parameter in :func:`torch.onnx.export`.
- opset_version: Same as the `opset_version` parameter in :func:`torch.onnx.export`.
- Returns:
- The JIT graph and a list of unconvertible ops in the format of "domain::op".
- """
- opset_version = opset_version or _constants.ONNX_DEFAULT_OPSET
- GLOBALS.export_onnx_opset_version = opset_version
- try:
- with exporter_context(model, training, verbose=False):
- # Create a mostly clean JIT graph that contains the plain aten and
- # other ops we can check with the symbolic registry.
- # NOTE: We don't want to actually convert any ops to ONNX or run any
- # symbolic functions because there is a higher chance that a pass
- # fails or an unconvertible op messes up the graph during ONNX conversion.
- # This way we can always generate a list just by looking at the names
- # of the ops in the graph.
- args = _decide_input_format(model, args)
- model = _pre_trace_quant_model(model, args)
- graph, _, _, module = _create_jit_graph(model, args)
- _C._jit_pass_inline(graph)
- _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
- _C._jit_pass_erase_number_types(graph)
- _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
- except Exception as e:
- raise errors.OnnxExporterError(
- "Failed to discover unconvertible ops because of errors during the JIT graph "
- "generation process."
- ) from e
- unsupported_ops = []
- for node in graph.nodes():
- domain_op = node.kind()
- if domain_op.startswith("onnx::") or domain_op.startswith("prim::"):
- # We consider onnx and prim ops as supported ops, even though some "prim"
- # ops are not implemented as symbolic functions, because they may be
- # eliminated in the conversion passes. Users may still see errors caused
- # by prim ops even though they don't show up in the list.
- continue
- if not registration.registry.is_registered_op(
- domain_op.rstrip("_"), opset_version
- ):
- # We consider all registered ops supported, even though some of them are
- # only partially supported, because there is not yet a good way to check
- # if an op is fully supported.
- # TODO(justinchuby): Create a way to check if an op is fully supported.
- unsupported_ops.append(domain_op)
- return graph, unsupported_ops
- @_beartype.beartype
- def _setup_trace_module_map(
- model: Union[torch.nn.Module, torch.jit.ScriptModule],
- export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
- ) -> Set[str]:
- def __register_attribute_hook():
- attr_name = "_onnx_attrs"
- def _track_module_attributes_forward_pre_hook(module, input):
- setattr(module, attr_name, _get_module_attributes(module))
- def _track_module_attributes_forward_hook(module, input, output):
- tracing_state = _C._get_tracing_state()
- if not tracing_state:
- return
- graph = tracing_state.graph()
- onnx_attrs = {}
- if hasattr(module, attr_name):
- onnx_attrs = getattr(module, attr_name)
- delattr(module, attr_name)
- _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
- for m in model.modules():
- m.register_forward_hook(_track_module_attributes_forward_hook)
- m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
- def _unqualified_variable_name(qualified_name: str) -> str:
- """
- Parse qualified variable name and return the unqualified version.
- Pure numeric atoms are considered inadequate, so this function will look past them,
- and start from the first non-numeric atom.
- Example:
- >>> _unqualified_variable_name('__main__.Foo.bar')
- 'bar'
- >>> _unqualified_variable_name('__main__.Foo.bar.0')
- 'bar.0'
- """
- name_atoms = qualified_name.split(".")
- for i, atom in reversed(list(enumerate(name_atoms))):
- if not atom.isnumeric():
- return ".".join(name_atoms[i:])
- return qualified_name
- trace_module_map = {
- _m: torch._C._jit_onnx_create_full_scope_name(
- torch.typename(type(_m)), _unqualified_variable_name(_n)
- )
- for _n, _m in model.named_modules()
- }
- torch.jit._trace._trace_module_map = trace_module_map
- if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
- module_typenames = {torch.typename(type(module)) for module in trace_module_map}
- elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
- def _find_typename(v):
- if isinstance(v, type):
- return torch.typename(v)
- else:
- raise RuntimeError(
- "Only type of the `nn.Module` should be "
- "passed in the set for argument `export_modules_as_functions`. "
- "Got `%s`." % (type(v).__name__)
- )
- module_typenames = {_find_typename(v) for v in export_modules_as_functions}
- else:
- module_typenames = set()
- if module_typenames:
- __register_attribute_hook()
- return module_typenames
- @_beartype.beartype
- def _reset_trace_module_map():
- torch.jit._trace._trace_module_map = None
- _C._jit_pass_onnx_clear_scope_records()
- @_beartype.beartype
- def _get_module_attributes(module):
- annotations = typing.get_type_hints(type(module))
- base_m_annotations = typing.get_type_hints(torch.nn.Module)
- [annotations.pop(k, None) for k in base_m_annotations]
- return {k: getattr(module, k) for k in annotations}
- @_beartype.beartype
- def _export(
- model,
- args,
- f,
- export_params=True,
- verbose=False,
- training=_C_onnx.TrainingMode.EVAL,
- input_names=None,
- output_names=None,
- operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
- export_type=None,
- opset_version=None,
- do_constant_folding=True,
- dynamic_axes=None,
- keep_initializers_as_inputs=None,
- fixed_batch_size=False,
- custom_opsets=None,
- add_node_names=True,
- onnx_shape_inference=True,
- export_modules_as_functions=False,
- ):
- assert GLOBALS.in_onnx_export is False
- if export_type is None:
- export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
- # Discussed deprecation with Nikita Shulga and Sergii Dymchenko from Meta
- if _C_onnx._CAFFE2_ATEN_FALLBACK:
- warnings.warn(
- "Caffe2 ONNX exporter is deprecated in version 2.0 and will be "
- "removed in 2.2. Please use PyTorch 2.1 or older for this capability.",
- category=FutureWarning,
- stacklevel=2,
- )
- if isinstance(model, torch.nn.DataParallel):
- raise ValueError(
- "torch.nn.DataParallel is not supported by ONNX "
- "exporter, please use 'attribute' module to "
- "unwrap model from torch.nn.DataParallel. Try "
- "torch.onnx.export(model.module, ...)"
- )
- GLOBALS.onnx_shape_inference = onnx_shape_inference
- if opset_version is None:
- opset_version = _constants.ONNX_DEFAULT_OPSET
- if export_modules_as_functions and opset_version < 15:
- raise ValueError(
- "`export_modules_as_functions` is not supported for `opset_version` < 15."
- "This is because `opset_version` < 15 implies IR version < 8, which means "
- "no local function support. "
- )
- if not operator_export_type:
- if _C_onnx._CAFFE2_ATEN_FALLBACK:
- operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
- else:
- operator_export_type = _C_onnx.OperatorExportTypes.ONNX
- # By default, training=TrainingMode.EVAL,
- # which is good because running a model in training mode could result in
- # internal buffers getting updated, dropout getting applied, etc.
- # If you really know what you're doing, you can turn
- # training=TrainingMode.TRAINING or training=TrainingMode.PRESERVE,
- # (to preserve whatever the original training mode was.)
- GLOBALS.export_onnx_opset_version = opset_version
- GLOBALS.operator_export_type = operator_export_type
- try:
- GLOBALS.in_onnx_export = True
- module_typenames_to_export_as_functions: Set[str] = set()
- if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
- module_typenames_to_export_as_functions = _setup_trace_module_map(
- model, export_modules_as_functions
- )
- with exporter_context(model, training, verbose):
- val_keep_init_as_ip = _decide_keep_init_as_input(
- keep_initializers_as_inputs,
- operator_export_type,
- opset_version,
- )
- val_add_node_names = _decide_add_node_names(
- add_node_names, operator_export_type
- )
- val_do_constant_folding = _decide_constant_folding(
- do_constant_folding, operator_export_type, training
- )
- # Normally f can be a file-like object, but for large models, the external data format requires a
- # valid `model_file_location`. Code in export.cpp will enforce this.
- if isinstance(f, str):
- model_file_location = f
- else:
- model_file_location = ""
- args = _decide_input_format(model, args)
- if dynamic_axes is None:
- dynamic_axes = {}
- _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
- graph, params_dict, torch_out = _model_to_graph(
- model,
- args,
- verbose,
- input_names,
- output_names,
- operator_export_type,
- val_do_constant_folding,
- fixed_batch_size=fixed_batch_size,
- training=training,
- dynamic_axes=dynamic_axes,
- )
- # TODO: Don't allocate a in-memory string for the protobuf
- defer_weight_export = (
- export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE
- )
- if custom_opsets is None:
- custom_opsets = {}
- _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
- node_attr_to_name = {} # type: ignore[var-annotated]
- if module_typenames_to_export_as_functions:
- # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
- node_attr_to_name = _C._jit_pass_onnx_function_extraction(
- graph,
- module_typenames_to_export_as_functions,
- list(params_dict.keys()),
- )
- params_dict = _C._jit_pass_onnx_deduplicate_initializers( # type: ignore[assignment]
- graph, params_dict, getattr(model, "training", False) # type: ignore[arg-type]
- )
- _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph)
- if export_params:
- (
- proto,
- export_map,
- val_use_external_data_format,
- node_names,
- ) = graph._export_onnx( # type: ignore[attr-defined]
- params_dict,
- opset_version,
- dynamic_axes,
- defer_weight_export,
- operator_export_type,
- not verbose,
- val_keep_init_as_ip,
- custom_opsets,
- val_add_node_names,
- model_file_location,
- node_attr_to_name,
- )
- else:
- (
- proto,
- export_map,
- val_use_external_data_format,
- node_names,
- ) = graph._export_onnx( # type: ignore[attr-defined]
- {},
- opset_version,
- dynamic_axes,
- False,
- operator_export_type,
- not verbose,
- val_keep_init_as_ip,
- custom_opsets,
- val_add_node_names,
- model_file_location,
- node_attr_to_name,
- )
- # insert function_proto into model_proto.
- proto = onnx_proto_utils._add_onnxscript_fn(
- proto,
- custom_opsets,
- )
- if verbose:
- torch.onnx.log("Exported graph: ", graph)
- onnx_proto_utils._export_file(proto, f, export_type, export_map)
- # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX,
- # we can skip this check.
- # If large model format export is enabled, proto will only contain data location instead of
- # raw data and _check_onnx_proto() will fail because it can only handle the raw ONNX proto
- # string in memory.
- if (operator_export_type is _C_onnx.OperatorExportTypes.ONNX) and (
- not val_use_external_data_format
- ):
- try:
- _C._check_onnx_proto(proto)
- except RuntimeError as e:
- raise errors.CheckerError(e) from e
- finally:
- assert GLOBALS.in_onnx_export
- GLOBALS.in_onnx_export = False
- _reset_trace_module_map()
- return torch_out
- @_beartype.beartype
- def _apply_friendly_debug_names(graph, params):
- for n in graph.nodes():
- for v in n.inputs():
- old_name = v.debugName()
- if old_name != str(v.unique()):
- continue
- new_name = f"{n.kind()}_{v.unique()}"
- v.setDebugName(new_name)
- if old_name in params:
- params[new_name] = params.pop(old_name)
- @_beartype.beartype
- def _set_input_and_output_names(graph, input_names, output_names):
- @_beartype.beartype
- def set_names(node_list, name_list, descriptor):
- if name_list is None:
- return
- if len(name_list) > len(node_list):
- raise RuntimeError(
- "number of %s names provided (%d) exceeded number of %ss (%d)"
- % (descriptor, len(name_list), descriptor, len(node_list))
- )
- # Mark if the output node DebugName is set before.
- output_node_set = set()
- for i, (name, node) in enumerate(zip(name_list, node_list)):
- # Duplicated output node, insert onnx::Identity to avoid setting the same DebugName after setDebugName().
- if descriptor == "output":
- if node in output_node_set:
- identity_node = graph.create("onnx::Identity")
- identity_node.insertAfter(node.node())
- identity_node.addInput(node)
- identity_node.output().setType(node.type())
- graph.return_node().replaceInput(i, identity_node.output())
- node = identity_node.output()
- output_node_set.add(node)
- if node.debugName() != name:
- node.setDebugName(name)
- set_names(list(graph.inputs()), input_names, "input")
- set_names(list(graph.outputs()), output_names, "output")
- @_beartype.beartype
- def _run_symbolic_method(g, op_name, symbolic_fn, args):
- r"""
- This trampoline function gets invoked for every symbolic method
- call from C++.
- """
- try:
- graph_context = jit_utils.GraphContext(
- graph=g,
- block=g.block(),
- opset=GLOBALS.export_onnx_opset_version,
- original_node=None, # type: ignore[arg-type]
- params_dict=_params_dict,
- env={},
- )
- return symbolic_fn(graph_context, *args)
- except TypeError as e:
- # Handle the specific case where we didn't successfully dispatch
- # to symbolic_fn. Otherwise, the backtrace will have the clues
- # you need.
- e.args = (f"{e.args[0]} (occurred when translating {op_name})",)
- raise
- @_beartype.beartype
- def _add_block(node: _C.Node) -> _C.Block:
- return node.addBlock()
- @_beartype.beartype
- def _add_input_to_block(block: _C.Block):
- return block.addInputToBlock() # type: ignore[attr-defined]
- @_beartype.beartype
- def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:
- return block.registerOutput(value)
- @_beartype.beartype
- def _should_aten_fallback(
- name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
- ):
- # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
- # an aten::ATen operator is created regardless of symbolics existence
- # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available
- is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
- is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
- is_aten_fallback_export = (
- operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
- )
- is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK
- if not name.startswith("aten::"):
- return False
- if is_caffe2_build:
- if (
- is_onnx_aten_export or is_aten_fallback_export
- ) and not is_exportable_aten_op:
- return True
- else:
- if is_onnx_aten_export or (
- is_aten_fallback_export and not is_exportable_aten_op
- ):
- return True
- return False
- @_beartype.beartype
- def _need_symbolic_context(symbolic_fn: Callable) -> bool:
- """Checks if the first argument to symbolic_fn is annotated as type `torch.onnx.SymbolicContext`."""
- params = tuple(inspect.signature(symbolic_fn).parameters.values())
- # When the annotation is postpone-evaluated, the annotation is a string
- # and not a type. We need to use get_type_hints to get the real type.
- if not params:
- return False
- first_param_name = params[0].name
- type_hints = typing.get_type_hints(symbolic_fn)
- if first_param_name not in type_hints:
- return False
- param_type = type_hints[first_param_name]
- return issubclass(param_type, _exporter_states.SymbolicContext)
- @_beartype.beartype
- def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
- """Decorator that provides the symbolic context to the symbolic function if needed."""
- if _need_symbolic_context(symbolic_fn):
- # TODO(justinchuby): Update the module name of GraphContext when it is public
- warnings.warn(
- "The first argument to symbolic functions is deprecated in 1.13 and will be "
- "removed in the future. Please annotate treat the first argument (g) as GraphContext "
- "and use context information from the object instead.",
- category=FutureWarning,
- )
- def wrapper(graph_context: jit_utils.GraphContext, *args, **kwargs):
- symbolic_context = _exporter_states.SymbolicContext(
- params_dict=graph_context.params_dict,
- env=graph_context.env,
- cur_node=graph_context.original_node,
- onnx_block=graph_context.block,
- )
- return symbolic_fn(symbolic_context, graph_context, *args, **kwargs)
- return wrapper
- return symbolic_fn
- @_beartype.beartype
- def _get_aten_op_overload_name(n: _C.Node) -> str:
- # Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
- schema = n.schema()
- if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
- return ""
- return _C.parse_schema(schema).overload_name
- @_beartype.beartype
- def _run_symbolic_function(
- graph: _C.Graph,
- block: _C.Block,
- node: _C.Node,
- inputs: Any,
- env: Dict[_C.Value, _C.Value],
- operator_export_type=_C_onnx.OperatorExportTypes.ONNX,
- ) -> Optional[Union[_C.Value, Sequence[Optional[_C.Value]]]]:
- """Runs a symbolic function.
- The function is used in C++ to export the node to ONNX.
- Returns:
- A single or a tuple of Values.
- None when the node gets cloned as is into the new graph.
- """
- opset_version = GLOBALS.export_onnx_opset_version
- # See Note [Export inplace]
- node_kind = node.kind()
- if node_kind.endswith("_"):
- # Treat relu_ -> relu; add_ -> add etc.
- ns_op_name = node_kind[:-1]
- else:
- ns_op_name = node_kind
- namespace, op_name = jit_utils.parse_node_kind(ns_op_name)
- graph_context = jit_utils.GraphContext(
- graph=graph,
- block=block,
- opset=opset_version,
- original_node=node,
- params_dict=_params_dict,
- env=env,
- )
- # Direct ATen export requested
- if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
- attrs = {
- k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
- for k in node.attributeNames()
- }
- outputs = node.outputsSize()
- attrs["outputs"] = outputs
- return graph_context.aten_op(
- op_name,
- *inputs,
- overload_name=_get_aten_op_overload_name(node),
- **attrs,
- )
- try:
- # Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
- if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
- symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
- if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
- domain = "caffe2"
- else:
- domain = namespace
- symbolic_function_name = f"{domain}::{op_name}"
- symbolic_function_group = registration.registry.get_function_group(
- symbolic_function_name
- )
- if symbolic_function_group is not None:
- symbolic_fn = symbolic_function_group.get(opset_version)
- if symbolic_fn is not None:
- # TODO Wrap almost identical attrs assignment or comment the difference.
- attrs = {
- k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
- }
- return symbolic_fn(graph_context, *inputs, **attrs)
- attrs = {
- k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
- for k in node.attributeNames()
- }
- if namespace == "onnx":
- # Clone node to trigger ONNX shape inference
- return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]
- raise errors.UnsupportedOperatorError(
- symbolic_function_name,
- opset_version,
- symbolic_function_group.get_min_supported()
- if symbolic_function_group
- else None,
- )
- except RuntimeError:
- if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
- return None
- elif (
- operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
- and not symbolic_helper.is_caffe2_aten_fallback()
- ):
- # Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
- attrs = {
- k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
- for k in node.attributeNames()
- }
- return graph_context.aten_op(
- op_name,
- *inputs,
- overload_name=_get_aten_op_overload_name(node),
- **attrs,
- )
- raise
- except TypeError as e:
- # Handle the specific case where we didn't successfully dispatch.
- # Otherwise, the backtrace will have the clues you need.
- e.args = (f"{e.args[0]} \n(Occurred when translating {op_name}).",)
- raise
- @_beartype.beartype
- def _verify_custom_op_name(symbolic_name: str):
- if not re.match(r"^[a-zA-Z0-9-_]+::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name):
- raise errors.OnnxExporterError(
- f"Failed to register operator {symbolic_name}. "
- "The symbolic name must match the format domain::name, "
- "and should start with a letter and contain only "
- "alphanumerical characters"
- )
- ns, _ = jit_utils.parse_node_kind(symbolic_name)
- if ns == "onnx":
- raise ValueError(
- f"Failed to register operator {symbolic_name}. {ns} domain cannot be modified."
- )
- @_beartype.beartype
- def register_custom_op_symbolic(
- symbolic_name: str,
- symbolic_fn: Callable,
- opset_version: int,
- ):
- """Registers a symbolic function for a custom operator.
- When the user registers symbolic for custom/contrib ops,
- it is highly recommended to add shape inference for that operator via setType API,
- otherwise the exported graph may have incorrect shape inference in some extreme cases.
- An example of setType is `test_aten_embedding_2` in `test_operators.py`.
- See "Custom Operators" in the module documentation for an example usage.
- Args:
- symbolic_name (str): The name of the custom operator in "<domain>::<op>"
- format.
- symbolic_fn (Callable): A function that takes in the ONNX graph and
- the input arguments to the current operator, and returns new
- operator nodes to add to the graph.
- opset_version (int): The ONNX opset version in which to register.
- """
- if symbolic_name.startswith("::"):
- symbolic_name = f"aten{symbolic_name}"
- _verify_custom_op_name(symbolic_name)
- registration.custom_onnx_symbolic(
- symbolic_name,
- opset_version,
- decorate=[
- _symbolic_context_handler,
- ],
- )(symbolic_fn)
- @_beartype.beartype
- def unregister_custom_op_symbolic(symbolic_name: str, opset_version: int):
- """Unregisters ``symbolic_name``.
- See "Custom Operators" in the module documentation for an example usage.
- Args:
- symbolic_name (str): The name of the custom operator in "<domain>::<op>"
- format.
- opset_version (int): The ONNX opset version in which to unregister.
- """
- if symbolic_name.startswith("::"):
- symbolic_name = f"aten{symbolic_name}"
- _verify_custom_op_name(symbolic_name)
- registration.registry.unregister(symbolic_name, opset_version)
- @_beartype.beartype
- def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
- """Ensures dynamic axes argument is follows the expected format."""
- if len(dynamic_axes) == 0:
- return
- if hasattr(model, "graph"):
- # Extracting set of valid input/output names that shall be used for dynamic_axes
- if (input_names is None) or len(input_names) == 0:
- input_names = [x.debugName() for x in model.graph.inputs()]
- if (output_names is None) or len(output_names) == 0:
- output_names = [y.debugName() for y in model.graph.outputs()]
- valid_names = set((input_names or []) + (output_names or []))
- # If dynamic axes are provided as a list rather than dictionary, they should
- # first get converted to a dictionary in expected format. If desired axes names
- # are not provided for dynamic axes, automatic names shall be generated for
- # provided dynamic axes of specified input/output
- for key, value in dynamic_axes.items():
- if key not in valid_names:
- warnings.warn(
- f"Provided key {key} for dynamic axes is not a valid input/output name"
- )
- if isinstance(value, list):
- warnings.warn(
- "No names were found for specified dynamic axes of provided input."
- f"Automatically generated names will be applied to each dynamic axes of input {key}"
- )
- value_dict = {}
- for i, x in enumerate(value):
- if not isinstance(x, int):
- raise ValueError(
- "The type of axis index is expected to be an integer"
- )
- if x in value_dict:
- warnings.warn(
- f"Duplicate dynamic axis index {x} was provided for input {key}."
- )
- else:
- value_dict[x] = str(key) + "_dynamic_axes_" + str(i + 1)
- dynamic_axes[key] = value_dict
|