# exporter.py

from __future__ import annotations

import copy
import functools
import inspect
import itertools
import operator
import os
import re
import warnings

from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
import onnxscript  # type: ignore[import]
from onnxscript import evaluator  # type: ignore[import]
from onnxscript.function_libs.torch_aten import graph_building  # type: ignore[import]

import torch
import torch._C
import torch._decomp
import torch._dynamo
import torch._ops
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.fx.passes import fake_tensor_prop
from torch.nn.utils import stateless
from torch.onnx import _constants, _type_utils
from torch.onnx._internal import _beartype
from torch.onnx._internal.fx import diagnostics, function_dispatcher, options
from torch.utils import _pytree

# TODO: Separate into individual components.
# TODO: make_fx loses stack info https://github.com/pytorch/pytorch/issues/90276


def _onnx_function_diagnose_call_message_formatter(
    fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> str:
    if len(args) > 0 and isinstance(args[0], onnxscript.OnnxFunction):
        onnx_function: onnxscript.OnnxFunction = args[0]  # self
        return f"{onnx_function.name}: {onnx_function}"
    return f"{fn.__name__}: {fn}"


def _onnx_function_diagnose_call_append_symbolic_source_location(
    diagnostic: diagnostics.infra.Diagnostic,
    fn: Callable,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    return_values: Any,
) -> None:
    # TODO(bowbao): Record source location of symbolic.
    # Need this separate step because normally only the source location of
    # class `onnxscript.OnnxFunction.__call__` is recorded.
    pass


# TODO(bowbao): Delete this once diagnostics is introduced in onnxscript.
_diagnose_onnx_function = diagnostics.diagnose_call(
    rule=diagnostics.rules.atenlib_symbolic_function,
    diagnostic_message_formatter=_onnx_function_diagnose_call_message_formatter,
    diagnostic_modifier=_onnx_function_diagnose_call_append_symbolic_source_location,
)
for key, onnx_function in function_dispatcher._ATENLIB_FUNCTIONS.items():
    if isinstance(onnx_function, FunctionType):
        function_dispatcher._ATENLIB_FUNCTIONS[key] = _diagnose_onnx_function(
            onnx_function
        )
onnxscript.OnnxFunction.__call__ = _diagnose_onnx_function(
    onnxscript.OnnxFunction.__call__
)


class ModuleExpansionTracer(torch.fx._symbolic_trace.Tracer):
    """Tracer to create an ONNX-export-friendly FX graph.

    This tracer traces models into operators. That is, the traced graph
    mostly contains call_function nodes and has no call_module nodes.
    The call_module nodes are problematic for the use of make_fx(...)
    in the ONNX exporter.
    """

    @_beartype.beartype
    def is_leaf_module(
        self, module: torch.nn.Module, module_qualified_name: str
    ) -> bool:
        # This returns False so that all sub-modules are considered not leaves
        # and therefore expanded into operators in
        # torch.fx._symbolic_trace.Tracer.call_module.
        return False

    @_beartype.beartype
    def to_bool(self, obj: "torch.fx.Proxy") -> bool:
        # This is a hack to trace through Python if-else blocks.
        # It may generate incorrect ONNX graphs if the branch taken by an
        # if-else block depends on the input data, since tracing always
        # follows the False path here.
        return False


# Functions directly wrapped to produce torch.fx.Proxy so that symbolic
# data can flow through those functions. Python functions (e.g., `torch.arange`)
# not defined by pybind11 in C++ do not go through the Python dispatcher, so
# they are not automatically patched by FX's Python dispatcher.
# The list below means `torch.arange`, `torch.tensor`, and so on will be
# patched.
_TORCH_METHODS_TO_PATCH: Tuple[str, ...] = (
    "arange",
    "tensor",
    "finfo",
    "full",
    "empty",
)


def _wrap_for_symbolic_trace(target: Callable) -> Tuple[Callable, Callable]:
    """This function wraps ``target`` for symbolic tracing.

    This function wraps ``target`` so that its wrapper produces
    torch.fx.Proxy in symbolic computation. The returned values are
    the wrapper and the original function. Per `_TORCH_METHODS_TO_PATCH`,
    this function shall receive `torch.arange`, `torch.tensor`, etc. as inputs.
    """

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        proxy = None

        def check_has_proxy(v):
            if isinstance(v, torch.fx.Proxy):
                nonlocal proxy
                proxy = v

        torch.fx.node.map_aggregate(args, check_has_proxy)
        torch.fx.node.map_aggregate(kwargs, check_has_proxy)

        if proxy is not None:
            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
        else:
            return target(*args, **kwargs)

    return wrapper, target
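

# A minimal sketch of how the wrapper above behaves (not part of the exporter;
# `p` stands for a hypothetical torch.fx.Proxy captured during tracing):
#
#   wrapper, original = _wrap_for_symbolic_trace(torch.arange)
#   wrapper(5)   # no Proxy among the arguments -> plain torch.arange(5)
#   wrapper(p)   # a Proxy is found -> records a call_function node on p.tracer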


@_beartype.beartype
def _module_expansion_symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> "torch.fx.GraphModule":
    """Trace a callable into an FX graph.

    When "root" is torch.nn.Module, calls to its submodules (type: torch.nn.Module)
    will be expanded into operators (e.g., torch.matmul, torch.add, +, and -) to
    simplify the graph structure.
    """
    # For functions that don't support symbolic tracing, create wrappers
    # which produce symbolic results during tracing.
    patched_torch_methods = {
        target_name: _wrap_for_symbolic_trace(getattr(torch, target_name))
        for target_name in _TORCH_METHODS_TO_PATCH
    }

    # Set the symbolic-tracing friendly functions so that `tracer.trace` below
    # can work.
    for name, (wrapper, _) in patched_torch_methods.items():
        setattr(torch, name, wrapper)

    try:
        # Set up a tracer.
        tracer = ModuleExpansionTracer()
        # Trace the model.
        graph = tracer.trace(root, concrete_args)
        name = (
            root.__class__.__name__
            if isinstance(root, torch.nn.Module)
            else root.__name__
        )
        return torch.fx.GraphModule(tracer.root, graph, name)
    finally:
        # Revert the patches for symbolic tracing.
        for name, (_, wrapped) in patched_torch_methods.items():
            # wrapped is the original version of `torch.<name>`.
            setattr(torch, name, wrapped)
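

# Hedged usage sketch for _module_expansion_symbolic_trace; `Net` is a made-up
# example module, not something defined in this file:
#
#   class Net(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.linear = torch.nn.Linear(4, 4)
#
#       def forward(self, x):
#           return self.linear(x).relu()
#
#   gm = _module_expansion_symbolic_trace(Net())
#   # gm.graph contains call_function nodes only; there is no call_module
#   # node for `self.linear` because ModuleExpansionTracer expands it.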


def _retrieve_or_adapt_input_to_graph_set(fx_node_arg, fx_name_to_onnxscipt_value):
    """Map an FX value to a TorchScript value.

    When creating a TorchScript graph from an FX graph, we need a mapping from FX
    variables to TorchScript variables. This function maps an FX variable,
    fx_node_arg, to a torch.jit.Value.
    """
    onnx_tensor = fx_node_arg
    if isinstance(onnx_tensor, torch.fx.Node):
        # 1. fx_node_arg is a torch.fx.Node, which means
        #    fx_node_arg stands for the output of that torch.fx.Node.
        # 2. fx_node_arg (a variable in torch.fx.Graph) is mapped to
        #    torch.jit.Value, fx_name_to_onnxscipt_value[fx_node_arg.name],
        #    in the TorchScript graph.
        onnx_tensor = fx_name_to_onnxscipt_value[onnx_tensor.name]
    elif isinstance(onnx_tensor, torch.dtype):
        onnx_tensor = int(_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type())
    return onnx_tensor


def _filter_incompatible_kwargs(kwargs):
    """Filter out kwargs that are not supported by onnxscript."""
    filtered = {}
    for key, value in kwargs.items():
        if key in {
            "layout",
            "device",
            "requires_grad",
            "pin_memory",
            "memory_format",
        }:
            continue
        if key == "dtype":
            if value is None:
                filtered["dtype"] = -1
            else:
                filtered["dtype"] = int(
                    _type_utils.JitScalarType.from_dtype(value).onnx_type()
                )
            continue
        filtered[key] = value
    return filtered
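

# Illustrative input/output for _filter_incompatible_kwargs (an assumed
# example, following the mapping implemented above):
#
#   _filter_incompatible_kwargs(
#       {"dtype": torch.float32, "device": "cpu", "pin_memory": False, "alpha": 2}
#   )
#   # -> {"dtype": 1, "alpha": 2}; 1 is ONNX's FLOAT, and dtype=None maps to -1.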


def _wrap_fx_args_as_onnxscript_args(
    node: torch.fx.Node,
    fx_name_to_onnxscipt_value: Dict[
        str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
    ],
) -> Tuple[tuple, dict, tuple, dict]:
    """Map all FX arguments of a node to arguments in the TorchScript graph."""
    # This function assumes the order of arguments in the FX op is the
    # same as the order of arguments in the TorchScript op.

    # (1) Complete the arguments with default values.
    complete_args: List[Any] = []
    complete_kwargs: Dict[str, Any] = {}
    if inspect.isbuiltin(node.target):
        complete_args = list(node.args)
    else:
        for i, expected_arg in enumerate(node.target._schema.arguments):  # type: ignore[union-attr]
            if i < len(node.args):
                complete_args.append(node.args[i])
            else:
                if expected_arg.name in node.kwargs:
                    complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name]
                else:
                    # Get the default from the schema.
                    complete_kwargs[expected_arg.name] = expected_arg.default_value

    graph_args = tuple(
        _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscipt_value)
        for arg in complete_args
    )
    graph_kwargs = _filter_incompatible_kwargs(complete_kwargs)

    # Prepare torch-format args and kwargs for op-level validation.
    # Use the fake tensors to create real tensors to feed into the ops.
    torch_args = []
    for arg in complete_args:
        if isinstance(arg, torch.fx.Node):
            # Create a concrete test tensor based on the fake tensor.
            with torch.utils._mode_utils.no_dispatch():
                # TODO(titaiwang): improve engineering
                if isinstance(arg.meta["val"], list):
                    for meta_value in arg.meta["val"]:
                        torch_args.append(
                            torch.randn_like(meta_value, dtype=torch.float)
                        )
                else:
                    torch_args.append(
                        torch.randn_like(arg.meta["val"], dtype=torch.float)
                    )
        else:
            torch_args.append(arg)
    torch_kwargs = complete_kwargs
    return (graph_args, graph_kwargs, tuple(torch_args), torch_kwargs)


def _fill_tensor_meta(
    onnxscript_values,
    name: str,
    expected_values: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
):
    """Fill the meta information of onnxscript_values with that from the FX FakeTensor."""
    flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values)
    flat_expected_values, _ = _pytree.tree_flatten(expected_values)
    for i, (onnxscript_value, expected_value) in enumerate(
        zip(flat_onnxscript_values, flat_expected_values)
    ):
        # Only set the shape for now, as we don't need type information.
        onnxscript_value.shape = tuple(expected_value.size())
        if i > 0:
            onnxscript_value.name = f"{name}_{i}"
        else:
            onnxscript_value.name = name


def _location_from_fx_stack_trace(
    node_stack_trace: str,
) -> Optional[diagnostics.infra.Location]:
    """Extract location from an FX node's stack trace.

    Args:
        node_stack_trace: The stack trace of the FX node. Example:

            File "path/file.py", line 311, in <function>
                <code>
            | File "path/file2.py", line 389, in <function>
                <code>

    Returns:
        location: The location of the FX node.
    """
    if "File" not in node_stack_trace:
        return None

    lines = node_stack_trace.strip().split("\n")
    idx = 0
    while idx < len(lines) and "File" not in lines[idx]:
        idx += 1
    if idx + 1 >= len(lines):
        return None

    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
    matches = pattern.match(lines[idx].strip())
    if matches:
        uri = matches.group(1)
        line_number = int(matches.group(2))
        snippet = lines[idx + 1].strip()
        return diagnostics.infra.Location(uri=uri, line=line_number, snippet=snippet)
    return None
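

# Example of the parsing above (hypothetical trace, matching the regex):
#
#   trace = '  File "path/file.py", line 311, in forward\n    x = x + 1'
#   _location_from_fx_stack_trace(trace)
#   # -> Location(uri="path/file.py", line=311, snippet="x = x + 1")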


@_beartype.beartype
def _fx_node_to_onnx_message_formatter(
    fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> str:
    assert len(args) > 0
    node = args[0]
    assert isinstance(node, torch.fx.Node)
    return f"FX Node: {node.op}:{node.target}[name={node.name}]"


@_beartype.beartype
@diagnostics.diagnose_call(
    rule=diagnostics.rules.fx_node_to_onnx,
    exception_report_level=diagnostics.levels.ERROR,
    diagnostic_message_formatter=_fx_node_to_onnx_message_formatter,
)
def _export_fx_node_to_onnxscript(
    node: torch.fx.Node,
    onnxscript_graph: graph_building.TorchScriptGraph,
    fx_name_to_onnxscipt_value: Dict[
        str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
    ],
    onnxscript_value_name_to_real_tensor: Dict[
        str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
    ],
    tracer: graph_building.TorchScriptTracingEvaluator,
    fx_module_with_metadata: torch.fx.GraphModule,
    options: options.ExportOptions,
):
    # Record the stack trace of the node in the diagnostic.
    node_stack_trace = node.stack_trace
    if node_stack_trace:
        diagnostic = diagnostics.export_context().inflight_diagnostic(
            rule=diagnostics.rules.fx_node_to_onnx
        )
        diagnostic.with_additional_message(
            f"### PyTorch source information\n```\n{node_stack_trace}\n```"
        )
        location = _location_from_fx_stack_trace(node_stack_trace)
        if location is not None:
            diagnostic.with_location(location)

    if node.op == "placeholder":
        # Input of the graph.
        output = onnxscript_graph.add_input(
            input_name=node.name,
            # node.meta["val"] is generated by FakeTensorProp.
            input_value=node.meta["val"],
        )
        assert (
            output is not None
        ), f"Node creates None with target={node.target} and name={node.name}"
        assert isinstance(output, graph_building.TorchScriptTensor)
        assert isinstance(output, onnxscript.tensor.Tensor)
        fx_name_to_onnxscipt_value[node.name] = output
    elif node.op == "call_function":
        # aten ops and other stateless functions.
        if node.target == operator.getitem and isinstance(
            fx_name_to_onnxscipt_value[node.args[0].name], tuple  # type: ignore[union-attr]
        ):
            onnx_tensor_tuple = fx_name_to_onnxscipt_value[node.args[0].name]  # type: ignore[union-attr]
            index = node.args[1]
            output = onnx_tensor_tuple[index]  # type: ignore[index]
            assert (
                output is not None
            ), f"Node creates None with target={node.target} and name={node.name}"
            assert isinstance(output, (graph_building.TorchScriptTensor, tuple)), type(
                output
            )
            fx_name_to_onnxscipt_value[node.name] = output
            return

        if node.target == operator.getitem:
            # __getitem__ on a Tensor or a Sequence of tensors. Not a tuple.
            exporter_key = "getitem"
        elif (
            isinstance(node.target, torch._ops.OpOverload)
            and node.target in function_dispatcher._OP_OVERLOAD_TO_EXPORTER_KEY_TABLE
        ):
            exporter_key = function_dispatcher._OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[
                node.target
            ]
        else:
            raise RuntimeError(f"Unknown call_function target: {node.target}")
        # Only the latest opset version is supported in atenlib for now.
        symbolic_fn = function_dispatcher._ATENLIB_FUNCTIONS.get(exporter_key)
        if symbolic_fn is None:
            raise RuntimeError(f"Cannot find function for {exporter_key}")
        # Map FX inputs to ONNX inputs and fill optional inputs with default values.
        # torch_args and torch_kwargs are for op-level validation.
        (
            onnx_args,
            onnx_kwargs,
            torch_args,
            torch_kwargs,
        ) = _wrap_fx_args_as_onnxscript_args(node, fx_name_to_onnxscipt_value)
        with evaluator.default_as(tracer):
            output: Union[  # type: ignore[no-redef]
                graph_building.TorchScriptTensor,
                Tuple[graph_building.TorchScriptTensor],
            ] = symbolic_fn(*onnx_args, **onnx_kwargs)
        assert (
            output is not None
        ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
        # TODO(justinchuby): Add diagnostic information.
        # Assign the type and shape obtained from FakeTensorProp.
        _fill_tensor_meta(output, node.name, node.meta["val"])
        # One FX node could produce multiple outputs (e.g., a tuple of tensors); in
        # that case, output is a tuple of TorchScriptTensors.
        assert isinstance(output, (graph_building.TorchScriptTensor, tuple)), type(
            output
        )
        if options.op_level_debug:
            _validate_op_between_ort_torch(node, symbolic_fn, torch_args, torch_kwargs)
        fx_name_to_onnxscipt_value[node.name] = output
    elif node.op == "output":
        if isinstance(node.args[0], torch.fx.Node):
            onnx_tensor_or_tensor_tuple = fx_name_to_onnxscipt_value[node.args[0].name]
            onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
        else:
            # ONNX can't represent collection types (e.g., dictionary, tuple of tuple
            # of tensor, etc.), so we flatten the collection and register each element
            # as an output.
            flat_args, _ = _pytree.tree_flatten(node.args[0])
            for arg in flat_args:
                assert isinstance(
                    arg, torch.fx.Node
                ), f"arg must be a torch.fx.Node, not {type(arg)}"
                onnx_tensor_or_tensor_tuple = fx_name_to_onnxscipt_value[arg.name]
                onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
    elif node.op == "call_method":
        # TODO(wechi): Support call_method.
        raise RuntimeError("call_method is not supported yet.")
    elif node.op == "call_module":
        # TODO(wechi): Support call_module.
        raise RuntimeError("call_module is not supported yet.")
    elif node.op == "get_attr":
        current_attr = fx_module_with_metadata
        sub_attr_names = node.target.split(".")  # type: ignore[union-attr]
        # If node.target is "conv.weight", the following loop first
        # assigns fx_module_with_metadata.conv to current_attr, and then
        # fx_module_with_metadata.conv.weight to current_attr.
        while sub_attr_names:
            sub_attr_name = sub_attr_names.pop(0)
            if not hasattr(current_attr, sub_attr_name):
                raise AttributeError(
                    f"Attribute {sub_attr_name} is not found in {current_attr}."
                )
            current_attr = getattr(current_attr, sub_attr_name)

        input_ = onnxscript_graph.add_input(
            input_name=node.name, input_value=current_attr
        )
        assert isinstance(input_, graph_building.TorchScriptTensor)
        assert isinstance(input_, onnxscript.tensor.Tensor)
        fx_name_to_onnxscipt_value[node.name] = input_
        onnxscript_value_name_to_real_tensor[input_.name] = current_attr  # type: ignore[assignment]
    else:
        # TODO(wechi): Support get_attr, call_module, call_method.
        raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")


@diagnostics.diagnose_call(diagnostics.rules.atenlib_fx_to_onnx)
def _export_fx_to_onnxscript(
    fx_module_with_metadata: torch.fx.GraphModule, options: options.ExportOptions
):
    # Initialize the ONNX graph.
    onnxscript_graph = graph_building.TorchScriptGraph()
    tracer = graph_building.TorchScriptTracingEvaluator(onnxscript_graph)

    # In the following loop, a TorchScript graph is created to
    # represent the input FX graph with ONNX symbols (e.g., onnx::add).
    # To connect the values to nodes in the TorchScript graph, we maintain
    # fx_name_to_onnxscipt_value. Basically, we want to translate
    #   fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node)
    # to
    #   fx_name_to_onnxscipt_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscipt_value[fx_tensor_y.name]
    fx_name_to_onnxscipt_value: Dict[
        str, Union[torch._C.Value, Tuple[torch._C.Value, ...]]
    ] = {}
    # Similar to fx_name_to_onnxscipt_value, we need a mapping for real tensors
    # (usually tensor parameters in nn.Module). Note that TorchScript graphs cannot
    # store real tensors; TorchScript values are all symbolic. This is passed into
    # the ONNX ModelProto as the initializers.
    onnxscript_value_name_to_real_tensor: Dict[
        str, Union[torch.Tensor, Tuple[torch._C.Value, ...]]
    ] = {}
    for node in fx_module_with_metadata.graph.nodes:
        _export_fx_node_to_onnxscript(
            node,
            onnxscript_graph,
            fx_name_to_onnxscipt_value,
            onnxscript_value_name_to_real_tensor,
            tracer,
            fx_module_with_metadata,
            options,
        )

    # Apply TorchScript's type promotion code.
    # Ideally, we should implement our own type promotion, but
    # to save time, we just reuse it.
    onnxscript_graph.apply(
        torch._C._jit_pass_onnx_scalar_type_analysis,
        lowprecision_cast=True,
        opset_version=options.opset_version,
    )

    return onnxscript_graph, onnxscript_value_name_to_real_tensor


@_beartype.beartype
def _shape_inference_with_fake_tensor(decomposed_module: "torch.fx.GraphModule", *args):
    # Use this FakeTensorMode to
    # 1. convert nn.Parameter's in nn.Module to FakeTensor
    # 2. run FakeTensorProp
    # If (1) and (2) are done with different FakeTensorModes, undefined behavior
    # may happen.
    fake_tensor_mode = fake_tensor.FakeTensorMode()

    def to_fake_tensor(x):
        if isinstance(x, torch.Tensor) and not isinstance(x, fake_tensor.FakeTensor):
            return fake_tensor_mode.from_tensor(x)
        return x

    # "args" are FakeTensors in FakeTensorProp, so the parameters and buffers
    # in the model must be converted to FakeTensors as well.
    fake_parameters_and_buffers = {
        k: to_fake_tensor(v)
        for k, v in itertools.chain(
            decomposed_module.named_parameters(), decomposed_module.named_buffers()
        )
    }

    # Shape inference via FakeTensorProp
    with stateless._reparametrize_module(
        decomposed_module, fake_parameters_and_buffers
    ):
        # Assign output types and shapes to each node.
        # TODO(wechi): It's possible to get symbolic types (and shapes)
        # for each node's output. Consider setting "tracing_mode=symbolic"
        # when calling make_fx and then removing FakeTensorProp below.
        fake_tensor_prop.FakeTensorProp(decomposed_module, fake_tensor_mode).propagate(
            *args
        )

    return decomposed_module


@_beartype.beartype
def _rename_placeholder_targets(
    module: "torch.fx.GraphModule", reference_module: "torch.fx.GraphModule"
):
    """Align the argument names in module with those in reference_module.

    After calling this function, the two forward(...) methods in module and
    reference_module should have the same signature.
    """
    placeholders = [node for node in module.graph.nodes if node.op == "placeholder"]
    reference_placeholders = [
        node for node in reference_module.graph.nodes if node.op == "placeholder"
    ]

    for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
        placeholder.target = reference_placeholder.target
        placeholder.name = reference_placeholder.name

    module.recompile()


@_beartype.beartype
def _export(
    module: torch.fx.GraphModule,
    args,
    **kwargs,
) -> Union["onnx.ModelProto", bytes]:
    export_options = options.ExportOptions()
    export_options.update(**kwargs)
    # Apply the decomposition table to the input graph.
    # Make sure the fed-in "module" is stateless.
    decomposed_module = proxy_tensor.make_fx(
        module,
        decomposition_table=export_options.decomposition_table,
        tracing_mode="fake",
        _allow_non_fake_inputs=True,
    )(*args)
    # Rename placeholder targets to match the original module's signature since
    # we don't want to map forward(x, y, z) to forward(arg0, arg1, arg2).
    _rename_placeholder_targets(decomposed_module, module)
    # Run FakeTensorProp on decomposed_module.
    # The symbolic output of the i-th node can be accessed via
    # decomposed_module.graph.nodes[i].meta["val"].
    decomposed_module = _shape_inference_with_fake_tensor(decomposed_module, *args)

    # We want to pass lists of ints and floats to the TorchScript graph correctly
    # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, the graph
    # may receive FakeTensors, which results in a runtime error. In addition, the
    # TorchScript-based ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf
    # is not compatible with FakeTensorMode.
    with torch.utils._mode_utils.no_dispatch():
        onnxscript_graph, initializers = _export_fx_to_onnxscript(
            decomposed_module, export_options
        )
    # Export the TorchScript graph to an ONNX ModelProto.
    onnx_model = onnxscript_graph.to_model_proto(
        initializers, export_options.opset_version
    )

    if export_options.use_binary_format:
        # Return the ModelProto in binary format.
        return onnx_model.SerializeToString()
    # Return the ModelProto.
    return onnx_model


@_beartype.beartype
def export(
    fn: Union[torch.nn.Module, Callable],
    *args,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
) -> Union["onnx.ModelProto", bytes]:
    # args will be converted to symbolic tensors. Let's copy to avoid side effects.
    args = copy.deepcopy(args)

    # Translate the callable to an FX graph.
    #
    # TODO(wechi): There are several symbolic tracing mechanisms to convert
    # nn.Module to an FX graph. We should choose the right one after they
    # mature.
    graph_module, graph_guard = torch._dynamo.export(fn, *args, aten_graph=True)
    del graph_guard  # Unused

    # Export the FX graph to an ONNX ModelProto.
    #
    # Note that ALL kwargs are folded into constants in graph_module, so we don't
    # pass kwargs to _export.
    return _export(
        graph_module,
        args,
        opset_version=opset_version,
        decomposition_table=function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE,
        use_binary_format=use_binary_format,
        op_level_debug=op_level_debug,
    )
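

# Hedged usage sketch for export(...); `MyModel` and the input shape are assumed
# examples, not part of this module:
#
#   model = MyModel().eval()
#   x = torch.randn(2, 3)
#   model_proto = export(model, x, use_binary_format=False)
#   # With use_binary_format=True (the default), a serialized `bytes`
#   # ModelProto is returned instead of an onnx.ModelProto object.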


@_beartype.beartype
def export_without_kwargs(
    fn: Union[torch.nn.Module, Callable],
    *args,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
    **kwargs,
) -> Union["onnx.ModelProto", bytes]:
    if isinstance(fn, torch.nn.Module):
        signature = inspect.signature(fn.forward)
    else:
        signature = inspect.signature(fn)

    # We hope the input kwargs will be mapped to bound.args after binding.
    # If not, we will raise an error.
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    # kwargs are not handled.
    assert not bound.kwargs

    class Wrapper(torch.nn.Module):
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args):
            result, _ = _pytree.tree_flatten(self.fn(*args))
            return result

    # args will be converted to symbolic tensors. Let's copy to avoid side effects.
    bound_args = copy.deepcopy(bound.args)

    # Translate the callable to an FX graph.
    #
    # TODO(wechi): There are several symbolic tracing mechanisms to convert
    # nn.Module to an FX graph. We should choose the right one after they
    # mature.
    class GraphCaptureCompiler:
        def __init__(self):
            self.captured_graph: Optional["torch.fx.GraphModule"] = None
            self.captured_graph_count = 0

        def compile(self, graph_module: "torch.fx.GraphModule", _):
            assert self.captured_graph_count == 0
            self.captured_graph = graph_module
            self.captured_graph_count += 1
            return graph_module

    compiler = GraphCaptureCompiler()
    torch._dynamo.reset()
    torch._dynamo.optimize(compiler.compile, nopython=True)(Wrapper(fn))(*bound_args)
    torch._dynamo.reset()
    assert compiler.captured_graph

    # Export the FX graph to an ONNX ModelProto.
    return _export(
        compiler.captured_graph,
        # A function optimized by _dynamo doesn't have None in its args.
        tuple(arg for arg in bound_args if arg is not None),
        opset_version=opset_version,
        decomposition_table=function_dispatcher._ONNX_FRIENDLY_DECOMPOSITION_TABLE,
        use_binary_format=use_binary_format,
        op_level_debug=op_level_debug,
    )


@_beartype.beartype
def _move_placeholder_to_front(graph_module: "torch.fx.GraphModule") -> None:
    """Move all placeholder nodes to the front of the graph node list.

    In torch.fx.Graph, a placeholder is a special assignment node. If it's not
    executed at the beginning, it could overwrite values computed by upstream
    nodes.
    """
    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


@_beartype.beartype
def _replace_get_attr_with_placeholder(
    graph_module: "torch.fx.GraphModule",
) -> Tuple[torch.Tensor, ...]:
    """Replace get_attr nodes with placeholders.

    The parameters and buffers accessed by the original get_attr nodes are
    returned; they are useful when creating random inputs for the modified
    graph_module.
    """
    graph = graph_module.graph
    replaced_attrs: List[torch.Tensor] = []
    for node in graph.nodes:
        if node.op == "get_attr":
            replaced_attr: Optional[torch.Tensor] = None
            # get_attr could retrieve either a parameter or a buffer, so
            # we need to try both.
            try:
                replaced_attr = graph_module.get_parameter(node.target)
            except AttributeError:
                # It's possible that the model author used a buffer instead of
                # a parameter to store trainable weights. In this case,
                # 1. get_parameter will throw something like
                #    AttributeError: `bias` is not an nn.Parameter.
                # 2. get_buffer should work.
                replaced_attr = graph_module.get_buffer(node.target)

            # Reassign the op type so that the get_attr node becomes a placeholder node.
            node.op = "placeholder"
            # The target name in a placeholder must be a valid Python identifier.
            # Thus, we replace, e.g., "module.submodule.weight" with
            # "module_submodule_weight".
            node.target = node.target.replace(".", "_")
            # The default value is None. This is needed as long as "graph_module"
            # has optional inputs. Assume the original forward signature is
            #   def forward(self, x, y=None)
            # and the replaced get_attr node has target "z". Then, the modified
            # signature should be
            #   def forward(self, x, y=None, z=None)
            # Without the following line, the signature would be
            #   def forward(self, x, y=None, z)
            # which is not valid Python code.
            node.args = (None,)

            replaced_attrs.append(replaced_attr)

    return tuple(replaced_attrs)
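

# Sketch of the transformation above on a single node (an assumed example):
# a get_attr node with target "conv.weight" becomes a placeholder with target
# "conv_weight" and default value None, and the original parameter tensor is
# collected in replaced_attrs so callers can feed it back in as an input:
#
#   # Before: def forward(self, x)                 (graph reads "conv.weight" via get_attr)
#   # After:  def forward(self, x, conv_weight=None)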


@_beartype.beartype
def _trace_into_fx_graph_via_fx_symbolic_trace(
    module: torch.nn.Module,
    *args,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple["torch.fx.GraphModule", Tuple[Any, ...]]:
    signature = inspect.signature(module.forward)

    # We hope the input kwargs will be mapped to bound.args after binding.
    # If not, we will raise an error.
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    # After apply_defaults, all non-keyword-only arguments are in bound.args.
    # Because the code below does not support keyword-only arguments, bound.kwargs
    # must be empty.
    assert len(bound.kwargs) == 0, bound.kwargs

    # Create inputs to call symbolic trace (torch.fx.symbolic_trace).
    # Example content of concrete_args:
    #   concrete_args["x"] = torch.fx._symbolic_trace.PH
    #   concrete_args["b"] = 1
    # where "x" and "b" are argument names in "signature".
    concrete_args = {}
    for param_name, param_value in bound.arguments.items():
        if isinstance(param_value, torch.Tensor):
            # param_value can be, e.g., a real tensor or a fake tensor.
            # param_value is treated as a substitutable tensor symbol (aka placeholder).
            concrete_args[param_name] = torch.fx._symbolic_trace.PH
        else:
            concrete_args[param_name] = param_value

    return (
        _module_expansion_symbolic_trace(module, concrete_args=concrete_args),
        bound.args,
    )


@_beartype.beartype
def export_without_parameters_and_buffers(
    module: torch.nn.Module,
    *args,
    decomposition_table: Optional[Dict[torch._ops.OpOverload, Callable]] = None,
    use_binary_format: bool = True,
    opset_version: int = _constants.ONNX_DEFAULT_OPSET,
    op_level_debug: bool = False,
    # kwargs are the keyword arguments to call "module"; that is,
    # module(*args, **kwargs) must run.
    **kwargs,
) -> Tuple[
    Union["onnx.ModelProto", bytes],
    "torch.fx.GraphModule",
    Tuple[Any, ...],
    Tuple[Any, ...],
]:
    graph_module, bound_args = _trace_into_fx_graph_via_fx_symbolic_trace(
        module, *args, **kwargs
    )

    # Make sure all placeholder nodes are executed before get_attr nodes.
    # Otherwise, inputs can interleave with initializers in the final
    # ModelProto.graph.input. Basically, we want
    #   ModelProto.graph.input =
    #     [input_0, input_1, ..., input_n, weight_0, weight_1, ..., weight_m]
    # and we don't want
    #   ModelProto.graph.input =
    #     [input_0, weight_0, input_1, weight_1, ..., input_n, weight_n, ..., weight_m]
    _move_placeholder_to_front(graph_module)
    # To save memory, move get_attr to input so that the generated model doesn't
    # have weight tensors. "replaced_attrs" is the tuple of replaced weight tensors.
    replaced_attrs = _replace_get_attr_with_placeholder(graph_module)
    # Move all newly created placeholder nodes to the front of the graph.
    _move_placeholder_to_front(graph_module)
    # Finalize the graph editing.
    graph_module.recompile()
    return (
        _export(
            graph_module,
            (*bound_args, *replaced_attrs),
            opset_version=opset_version,
            decomposition_table=decomposition_table,
            use_binary_format=use_binary_format,
            op_level_debug=op_level_debug,
        ),
        graph_module,
        bound_args,
        replaced_attrs,
    )


@_beartype.beartype
def _create_tensor_proto_with_external_data(
    tensor: torch.Tensor, name: str, location: str, basepath: str
) -> "onnx.TensorProto":
    """Create a TensorProto with external data from a PyTorch tensor.

    The external data is saved to os.path.join(basepath, location).

    Args:
        tensor: Tensor to be saved.
        name: Name of the tensor (i.e., initializer name in the ONNX graph).
        location: Relative location of the external data file
            (e.g., "initializers/weight_0" when the model file is "/tmp/model_name.onnx").
        basepath: Base path of the external data file
            (e.g., "/tmp" when the model file is "/tmp/model_name.onnx").

    Reference for ONNX's external data format:
        How to load?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L187
        How to save?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L43
        How to set ONNX fields?
        https://github.com/onnx/onnx/blob/5dac81ac0707bdf88f56c35c0a5e8855d3534673/onnx/external_data_helper.py#L88
    """
    tensor_proto = onnx.TensorProto()
    tensor_proto.name = name
    tensor_proto.data_type = torch.onnx._type_utils._SCALAR_TYPE_TO_ONNX[  # type: ignore[assignment]
        torch.onnx._type_utils._DTYPE_TO_SCALAR_TYPE[tensor.dtype]
    ]
    tensor_proto.dims.extend(tensor.shape)
    tensor_proto.data_location = onnx.TensorProto.EXTERNAL

    # Settings for saving one tensor per file.
    # Offset is zero because there is no other tensor in the same file.
    key_value_pairs = {
        "location": location,
        "offset": 0,
        "length": tensor.untyped_storage().nbytes(),
    }
    for k, v in key_value_pairs.items():
        entry = tensor_proto.external_data.add()
        entry.key = k
        entry.value = str(v)

    # Actual path to write the content of the tensor.
    external_data_file_path = os.path.join(basepath, location)
    if os.path.exists(external_data_file_path):
        os.remove(external_data_file_path)

    # Create the external data's folder if it does not exist.
    external_data_dir_path = os.path.dirname(external_data_file_path)
    if not os.path.exists(external_data_dir_path):
        os.makedirs(external_data_dir_path)

    # Create a fresh file.
    with open(external_data_file_path, "xb") as data_file:
        # No need to call "seek" because the offset is 0.
        # data_file.seek(0)
        # Write the tensor content to the file.
        data_file.write(tensor.numpy().tobytes())

    return tensor_proto
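

# Hedged example of the resulting proto (the paths here are assumed, not
# produced by this module):
#
#   tp = _create_tensor_proto_with_external_data(
#       torch.zeros(2, 3), "weight_0", "initializers/weight_0", "/tmp/model_dir"
#   )
#   # tp.data_location == onnx.TensorProto.EXTERNAL
#   # tp.external_data holds location/offset/length entries, and the raw bytes
#   # are written to /tmp/model_dir/initializers/weight_0.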


@_beartype.beartype
def save_model_with_external_data(
    basepath: str,
    model_location: str,
    initializer_location: str,
    torch_load_paths: Tuple[str, ...],
    onnx_model: "onnx.ModelProto",
) -> None:
    """Load PyTorch tensors from files and add them to "onnx_model" as external initializers.

    Output files:
        ONNX model file path: os.path.join(basepath, model_location)
        ONNX initializer folder: os.path.join(basepath, initializer_location)

    After running this function, you can do
        ort_sess = onnxruntime.InferenceSession(os.path.join(basepath, model_location))
    to execute the model.

    Arguments:
        basepath: Base path of the external data file (e.g., "/tmp/large-onnx-model").
        model_location: Relative location of the ONNX model file.
            E.g., "model.onnx" so that the model file is saved to
            "/tmp/large-onnx-model/model.onnx".
        initializer_location: Relative location of the ONNX initializer folder.
            E.g., "initializers" so that the initializers are saved to
            "/tmp/large-onnx-model/initializers".
        torch_load_paths: Files containing serialized PyTorch tensors to be saved
            as ONNX initializers. They are loaded by torch.load.
        onnx_model: ONNX model to be saved with external initializers.
            If an input name matches a tensor loaded from "torch_load_paths",
            the tensor will be saved as that input's external initializer.
    """
    onnx_model_with_initializers = onnx.ModelProto()
    onnx_model_with_initializers.CopyFrom(onnx_model)
    onnx_input_names = [input.name for input in onnx_model.graph.input]

    for path in torch_load_paths:
        state_dict = torch.load(path)
        for name, tensor in state_dict.items():
            # Basically, "transformer.attention.self.query.weight" is mapped
            # to "transformer_attention_self_query_weight" to mimic the
            # name-modifying code in the FX-to-ONNX exporter.
            # See function _replace_get_attr_with_placeholder for details.
            refined_name = name.replace(".", "_")

            # For each refined PyTorch tensor name loaded by torch.load,
            # 1. Search for its best match in the ONNX model. E.g., the match of
            #    "transformer_attention_weight" could be "attention_weight".
            # 2. Set "tensor" as the initializer of the matched ONNX input.
            #    E.g., "tensor" is stored as the initializer of "attention_weight".
            # Step 1 is required because sometimes tensor names are stored with a
            # prefix in the dictionary loaded by torch.load.
            for onnx_input_name in onnx_input_names:
                if onnx_input_name.endswith(refined_name) or refined_name.endswith(
                    onnx_input_name
                ):
                    # Found a match. Change refined_name to the matched ONNX input
                    # name, so that we create the initializer with the right ONNX name.
                    refined_name = onnx_input_name
                    break

            relative_tensor_file_path = os.path.join(initializer_location, refined_name)
            # Create one file per tensor.
            # tensor_proto.raw_data is stored to an external file at
            # os.path.join(basepath, relative_tensor_file_path).
            tensor_proto = _create_tensor_proto_with_external_data(
                tensor, refined_name, relative_tensor_file_path, basepath
            )
            # Add the tensor_proto to the ONNX model as an initializer with external data.
            onnx_model_with_initializers.graph.initializer.append(tensor_proto)

    # model_location should be a pure file name such as "file_name.onnx", not
    # "folder/file_name.onnx".
    onnx.save(onnx_model_with_initializers, os.path.join(basepath, model_location))


# TODO(titaiwang): copied from ops_correctness_test.py, should have a common place?
TORCH_TYPE_TO_ONNX = {
    torch.bool: onnx.TensorProto.BOOL,
    torch.uint8: onnx.TensorProto.UINT8,
    torch.int8: onnx.TensorProto.INT8,
    torch.int16: onnx.TensorProto.INT16,
    torch.int32: onnx.TensorProto.INT32,
    torch.int64: onnx.TensorProto.INT64,
    torch.float16: onnx.TensorProto.FLOAT16,
    torch.float32: onnx.TensorProto.FLOAT,
    torch.float64: onnx.TensorProto.DOUBLE,
    torch.complex64: onnx.TensorProto.COMPLEX64,
    torch.complex128: onnx.TensorProto.COMPLEX128,
    torch.bfloat16: onnx.TensorProto.BFLOAT16,
}


# TODO(titaiwang): copied from ops_correctness_test.py, should have a common place?
def _convert_tensor_to_numpy(input: Any) -> Any:
    if isinstance(input, torch.Tensor):
        return input.detach().cpu().numpy()
    if isinstance(input, (tuple, list)):
        if len(input) == 0:
            return np.array((), dtype=np.int64)
        if isinstance(input[0], torch.Tensor):
            return [_convert_tensor_to_numpy(x) for x in input]
        if isinstance(input[0], bool):
            return np.array(input, dtype=np.bool_)

        # Just a sequence of numbers.
        if isinstance(input[0], int):
            return np.array(input, dtype=np.int64)
        if isinstance(input[0], float):
            return np.array(input)

    return input
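

# Examples of the conversions above (assumed inputs):
#
#   _convert_tensor_to_numpy(torch.tensor([1.0]))  # -> np.ndarray of float32
#   _convert_tensor_to_numpy([1, 2, 3])            # -> np.array([1, 2, 3], int64)
#   _convert_tensor_to_numpy([])                   # -> np.array((), int64)
#   _convert_tensor_to_numpy("same")               # -> "same" (passed through)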


# TODO(titaiwang): copied from ops_correctness_test.py, should have a common place?
def _convert_kwargs_for_onnx(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Convert kwargs to be compatible with ONNX Runtime.

    The "device" kwarg is dropped, and torch dtypes are mapped to their ONNX
    TensorProto counterparts via TORCH_TYPE_TO_ONNX.
    """
    new_kwargs = {}
    for key, value in kwargs.items():
        if key == "device":
            continue
        if key == "dtype":
            value = TORCH_TYPE_TO_ONNX[value]
        new_kwargs[key] = value
    return new_kwargs


@_beartype.beartype
def _validate_op_between_ort_torch(
    node: torch.fx.Node,
    symbolic_fn: onnxscript.OnnxFunction,
    torch_args: tuple,
    torch_kwargs: dict,
):
    """Validate the op between ONNX Runtime and PyTorch."""
    # op-level validation
    # symbolic_fn should produce the same outputs as node.target (the torch op).
    try:
        with evaluator.default_as(evaluator.ort_evaluator):
            expected_outputs = node.target(*torch_args, **torch_kwargs)  # type: ignore[operator]
            # TODO(titaiwang): Expose _convert_tensor_to_numpy and _convert_kwargs_for_onnx?
            input_onnx = [_convert_tensor_to_numpy(x) for x in torch_args]
            # Deal with dtype and device.
            kwargs_onnx = _convert_kwargs_for_onnx(torch_kwargs)
            ort_outputs = symbolic_fn(*input_onnx, **kwargs_onnx)
            for ort_output, expected_output in zip(ort_outputs, expected_outputs):
                try:
                    torch.testing.assert_close(
                        expected_output.numpy(),
                        ort_output,
                        check_device=False,
                        atol=10e-4,
                        rtol=10e-3,
                    )
                except AssertionError as e:
                    warnings.warn(
                        f"Suppressed AssertionError:\n{e}.\n"
                        f"Op {node.target} has mismatched outputs. "
                        f"Please check the implementation of {symbolic_fn}."
                    )
                    diagnostic = diagnostics.export_context().inflight_diagnostic()
                    diagnostic.with_additional_message(
                        f"### Validation failed\n"
                        f"{diagnostics.decorator.format_exception_in_markdown(e)}"
                    )
                    diagnostic.level = diagnostics.levels.ERROR
    except Exception as e:
        warnings.warn(f"ORT fails to run with error: {e}.")
        diagnostic = diagnostics.export_context().inflight_diagnostic()
        diagnostic.with_additional_message(
            f"### Validation failed\n"
            f"{diagnostics.decorator.format_exception_in_markdown(e)}"
        )
        diagnostic.level = diagnostics.levels.WARNING


# Register a few argument formatters