verification.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884
  1. """Functions to verify exported ONNX model is functionally equivalent to original PyTorch model.
  2. ONNX Runtime is required, and is used as the ONNX backend for export verification.
  3. """
  4. from __future__ import annotations
  5. import contextlib
  6. import copy
  7. import dataclasses
  8. import datetime
  9. import difflib
  10. import enum
  11. import functools
  12. import io
  13. import itertools
  14. import os
  15. import tempfile
  16. import warnings
  17. from typing import (
  18. Any,
  19. Callable,
  20. Collection,
  21. Dict,
  22. FrozenSet,
  23. List,
  24. Mapping,
  25. Optional,
  26. Sequence,
  27. Set,
  28. Tuple,
  29. Union,
  30. )
  31. import numpy as np
  32. import torch
  33. import torch._C._onnx as _C_onnx
  34. from torch import _C
  35. from torch.onnx import _constants, _experimental, _exporter_states, utils
  36. from torch.onnx._globals import GLOBALS
  37. from torch.onnx._internal import _beartype, onnx_proto_utils
  38. from torch.types import Number
  39. _ORT_PROVIDERS = ("CPUExecutionProvider",)
  40. _NumericType = Union[Number, torch.Tensor, np.ndarray]
  41. _ModelType = Union[torch.nn.Module, torch.jit.ScriptModule]
  42. _InputArgsType = Union[torch.Tensor, Tuple[Any, ...]]
  43. _InputKwargsType = Mapping[str, Any]
  44. _OutputsType = Union[Sequence[_NumericType], Sequence]
  45. class OnnxBackend(enum.Enum):
  46. """Enum class for ONNX backend used for export verification."""
  47. REFERENCE = "ONNXReferenceEvaluator"
  48. ONNX_RUNTIME_CPU = "CPUExecutionProvider"
  49. ONNX_RUNTIME_CUDA = "CUDAExecutionProvider"
  50. @dataclasses.dataclass
  51. class VerificationOptions:
  52. """Options for ONNX export verification.
  53. Attributes:
  54. flatten: If True, unpack nested list/tuple/dict inputs into a flattened list of
  55. Tensors for ONNX. Set this to False if nested structures are to be preserved
  56. for ONNX, which is usually the case with exporting ScriptModules. Default True.
  57. ignore_none: Whether to ignore None type in torch output, which is usually the
  58. case with tracing. Set this to False, if torch output should keep None type,
  59. which is usually the case with exporting ScriptModules. Default to True.
  60. check_shape: Whether to check the shapes between PyTorch and ONNX Runtime outputs
  61. are exactly the same. Set this to False to allow output shape broadcasting.
  62. Default to True.
  63. check_dtype: Whether to check the dtypes between PyTorch and ONNX Runtime outputs
  64. are consistent. Default to True.
  65. backend: ONNX backend for verification. Default to OnnxBackend.ONNX_RUNTIME_CPU.
  66. rtol: relative tolerance in comparison between ONNX and PyTorch outputs.
  67. atol: absolute tolerance in comparison between ONNX and PyTorch outputs.
  68. remained_onnx_input_idx: If provided, only the specified inputs will be passed
  69. to the ONNX model. Supply a list when there are unused inputs in the model.
  70. Since unused inputs will be removed in the exported ONNX model, supplying
  71. all inputs will cause an error on unexpected inputs. This parameter tells
  72. the verifier which inputs to pass into the ONNX model.
  73. acceptable_error_percentage: acceptable percentage of element mismatches in comparison.
  74. It should be a float of value between 0.0 and 1.0.
  75. """
  76. flatten: bool = True
  77. ignore_none: bool = True
  78. check_shape: bool = True
  79. check_dtype: bool = True
  80. backend: OnnxBackend = OnnxBackend.ONNX_RUNTIME_CPU
  81. rtol: float = 1e-3
  82. atol: float = 1e-7
  83. remained_onnx_input_idx: Optional[Sequence[int]] = None
  84. acceptable_error_percentage: Optional[float] = None
  85. @_beartype.beartype
  86. def _flatten_tuples(elem):
  87. flattened = []
  88. for t in elem:
  89. if isinstance(t, tuple):
  90. flattened.extend(_flatten_tuples(t))
  91. else:
  92. flattened.append(t)
  93. return flattened
  94. # TODO(justinchuby): Add type checking by narrowing down the return type when input is None
  95. def _to_numpy(elem) -> Union[list, np.ndarray]:
  96. if isinstance(elem, torch.Tensor):
  97. if elem.requires_grad:
  98. return elem.detach().cpu().numpy()
  99. else:
  100. return elem.cpu().numpy()
  101. elif isinstance(elem, (list, tuple)):
  102. return [_to_numpy(inp) for inp in elem]
  103. elif isinstance(elem, (bool, int, float)):
  104. return np.array(elem)
  105. elif isinstance(elem, dict):
  106. flattened = []
  107. for k in elem:
  108. flattened.extend([_to_numpy(k), _to_numpy(elem[k])])
  109. return flattened
  110. return elem
  111. @_beartype.beartype
  112. def _inline_flatten_list(inputs, res_list) -> list:
  113. for i in inputs:
  114. res_list.append(i) if not isinstance(
  115. i, (list, tuple)
  116. ) else _inline_flatten_list(i, res_list)
  117. return res_list
  118. @_beartype.beartype
  119. def _unpack_to_numpy(values, cast_onnx_accepted=True) -> list:
  120. value_unpacked = []
  121. for value in values:
  122. value_unpacked.extend(
  123. utils.unpack_quantized_tensor(value, cast_onnx_accepted=cast_onnx_accepted)
  124. )
  125. return [_to_numpy(v) for v in value_unpacked]
  126. @_beartype.beartype
  127. def _run_onnx(onnx_session, inputs) -> _OutputsType:
  128. kw_inputs = {}
  129. if inputs and isinstance(inputs[-1], dict):
  130. kw_inputs = inputs[-1]
  131. inputs = inputs[:-1]
  132. inputs = _unpack_to_numpy(_flatten_tuples(inputs))
  133. ort_inputs = {}
  134. for input_name, input in kw_inputs.items():
  135. ort_inputs[input_name] = _to_numpy(input)
  136. inputs = _to_numpy(inputs)
  137. if hasattr(onnx_session, "get_inputs"):
  138. # onnxruntime.InferenceSession
  139. input_names = [i.name for i in onnx_session.get_inputs()]
  140. elif hasattr(onnx_session, "input_names"):
  141. # onnx.reference.ReferenceEvaluator
  142. input_names = onnx_session.input_names
  143. else:
  144. raise ValueError(f"Unknown ONNX backend type: {type(onnx_session)}.")
  145. for i, input in enumerate(inputs):
  146. if i == len(input_names) or input_names[i] in ort_inputs:
  147. raise ValueError(
  148. f"got too many positional inputs. inputs: {inputs}. kw_inputs: {kw_inputs}. "
  149. f"input names: {input_names}."
  150. )
  151. ort_inputs[input_names[i]] = input
  152. onnx_outs = onnx_session.run(None, ort_inputs)
  153. return onnx_outs
  154. @_beartype.beartype
  155. def _ort_session(
  156. model: Union[str, io.BytesIO], ort_providers: Sequence[str] = _ORT_PROVIDERS
  157. ):
  158. try:
  159. import onnxruntime # type: ignore[import]
  160. except ImportError as e:
  161. raise ImportError("onnxruntime is required for export verification.") from e
  162. if ort_providers is None:
  163. ort_providers = _ORT_PROVIDERS
  164. session_options = onnxruntime.SessionOptions()
  165. # suppress ort warnings.
  166. # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
  167. session_options.log_severity_level = 3
  168. ort_session = onnxruntime.InferenceSession(
  169. model if isinstance(model, str) else model.getvalue(),
  170. session_options,
  171. providers=ort_providers,
  172. )
  173. return ort_session
  174. @_beartype.beartype
  175. def _onnx_reference_evaluator_session(model: Union[str, io.BytesIO]):
  176. try:
  177. import onnx
  178. from onnx import reference as onnx_reference
  179. except ImportError:
  180. raise ImportError("onnx >= 1.13 is required for reference evaluator.")
  181. proto = (
  182. onnx.load(model)
  183. if isinstance(model, str)
  184. else onnx.load_model_from_string(model.getvalue())
  185. )
  186. onnx_session = onnx_reference.ReferenceEvaluator(proto)
  187. return onnx_session
  188. @_beartype.beartype
  189. def _onnx_backend_session(model: Union[str, io.BytesIO], backend: OnnxBackend):
  190. if backend == OnnxBackend.REFERENCE:
  191. onnx_session = _onnx_reference_evaluator_session(model)
  192. elif backend in {OnnxBackend.ONNX_RUNTIME_CPU, OnnxBackend.ONNX_RUNTIME_CUDA}:
  193. onnx_session = _ort_session(model, (backend.value,))
  194. else:
  195. raise ValueError(f"Unsupported backend: {backend}")
  196. return onnx_session
  197. @_beartype.beartype
  198. def _compare_onnx_pytorch_outputs_in_np(
  199. onnx_outs: _OutputsType,
  200. pt_outs: _OutputsType,
  201. options: VerificationOptions,
  202. ):
  203. assert len(onnx_outs) == len(
  204. pt_outs
  205. ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})"
  206. acceptable_error_percentage = options.acceptable_error_percentage
  207. if acceptable_error_percentage and (
  208. acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0
  209. ):
  210. raise ValueError(
  211. "If set, acceptable_error_percentage should be between 0.0 and 1.0"
  212. )
  213. for ort_out, pt_out in zip(onnx_outs, pt_outs):
  214. try:
  215. # TODO: Remove `check_shape` option once every shape inconsistent issue is addressed.
  216. if not options.check_shape:
  217. # Allow different but broadcastable output shapes.
  218. ort_out, pt_out = np.broadcast_arrays(ort_out, pt_out)
  219. torch.testing.assert_close(
  220. ort_out,
  221. pt_out,
  222. rtol=options.rtol,
  223. atol=options.atol,
  224. check_dtype=options.check_dtype,
  225. equal_nan=True,
  226. )
  227. except AssertionError as e:
  228. if acceptable_error_percentage:
  229. error_percentage = 1 - np.sum(
  230. np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
  231. ) / np.prod(ort_out.shape)
  232. if error_percentage <= acceptable_error_percentage:
  233. warnings.warn(
  234. f"Suppressed AssertionError:\n{e}.\n"
  235. f"Error percentage {error_percentage} "
  236. f"within acceptable range {acceptable_error_percentage}."
  237. )
  238. continue
  239. if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
  240. warnings.warn("ONNX output is quantized")
  241. if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
  242. warnings.warn("PyTorch output is quantized")
  243. raise
  244. @_beartype.beartype
  245. def _compare_onnx_pytorch_outputs(
  246. onnx_outs: _OutputsType,
  247. pt_outs: Any,
  248. options: VerificationOptions,
  249. ):
  250. """
  251. Compare ONNX and PyTorch outputs.
  252. Args:
  253. onnx_outs: outputs from ONNX backend.
  254. pt_outs: outputs from PyTorch.
  255. options: options for verification.
  256. Raises:
  257. AssertionError: if outputs from ONNX model and PyTorch model are not
  258. equal up to specified precision.
  259. ValueError: if arguments provided are invalid.
  260. """
  261. if options.ignore_none:
  262. # torch.jit._flatten filters None type
  263. pt_outs, _ = torch.jit._flatten(pt_outs)
  264. else:
  265. pt_outs = _inline_flatten_list([pt_outs], [])
  266. pt_outs_np = _unpack_to_numpy(pt_outs, cast_onnx_accepted=False)
  267. onnx_outs = _inline_flatten_list(onnx_outs, [])
  268. _compare_onnx_pytorch_outputs_in_np(onnx_outs, pt_outs_np, options)
  269. @_beartype.beartype
  270. def _prepare_input_for_pytorch(args, kwargs):
  271. """Prepare input for PyTorch model execution.
  272. Any future changes/formatting to the input before dispatching to the PyTorch
  273. model should be made in this function.
  274. Args:
  275. args: positional arguments for PyTorch model forward method.
  276. kwargs: keyword arguments for PyTorch model forward method.
  277. Returns:
  278. args: positional arguments for PyTorch model forward method.
  279. kwargs: keyword arguments for PyTorch model forward method.
  280. """
  281. if isinstance(args, (torch.Tensor, dict)):
  282. args = (args,)
  283. # In-place operators will update input tensor data as well.
  284. # Thus inputs are replicated before every forward call.
  285. args = copy.deepcopy(args)
  286. if kwargs:
  287. kwargs = copy.deepcopy(kwargs)
  288. else:
  289. kwargs = {}
  290. return args, kwargs
  291. @_beartype.beartype
  292. def _prepare_input_for_export(args, kwargs):
  293. """Prepare input for ONNX model export.
  294. Any future changes/formatting to the input before dispatching to the
  295. :func:`torch.onnx.export` api should be made in this function.
  296. Args:
  297. args: positional arguments for PyTorch model forward method.
  298. kwargs: keyword arguments for PyTorch model forward method.
  299. Returns:
  300. onnx_inputs: positional arguments for ONNX model export, as `args` in
  301. :func:`torch.onnx.export`.
  302. """
  303. args, kwargs = _prepare_input_for_pytorch(args, kwargs)
  304. if not kwargs and len(args) > 0 and isinstance(args[-1], dict):
  305. onnx_inputs = args + ({},)
  306. elif kwargs:
  307. onnx_inputs = args + (kwargs,)
  308. else:
  309. onnx_inputs = args
  310. return onnx_inputs
  311. @_beartype.beartype
  312. def _prepare_input_for_onnx(
  313. args, kwargs, remained_onnx_input_idx: Optional[Sequence[int]], flatten: bool
  314. ):
  315. """Prepare input for ONNX model execution in ONNX backend.
  316. Any future changes/formatting to the input before dispatching to the ONNX backend
  317. run should be made in this function.
  318. Args:
  319. args: positional arguments for PyTorch model forward method.
  320. kwargs: keyword arguments for PyTorch model forward method.
  321. remained_onnx_input_idx: indices of inputs to be used for ONNX model execution.
  322. flatten: whether to flatten the input before dispatching to the ONNX model execution.
  323. Returns:
  324. onnx_inputs: positional arguments for ONNX model execution in ONNX backend.
  325. """
  326. onnx_inputs = _prepare_input_for_export(args, kwargs)
  327. if flatten:
  328. onnx_inputs, _ = torch.jit._flatten(onnx_inputs)
  329. elif onnx_inputs and onnx_inputs[-1] == {}:
  330. # Handle empty kwargs (normally removed by flatten).
  331. onnx_inputs = onnx_inputs[:-1]
  332. if remained_onnx_input_idx is not None:
  333. return [onnx_inputs[i] for i in remained_onnx_input_idx]
  334. else:
  335. return onnx_inputs
  336. @_beartype.beartype
  337. def _try_clone_model(model):
  338. """Used for preserving original model in case forward mutates model states."""
  339. try:
  340. return copy.deepcopy(model)
  341. except Exception:
  342. warnings.warn(
  343. "Failed to clone model. Model state might be mutated during verification."
  344. )
  345. return model
  346. @_beartype.beartype
  347. def _compare_onnx_pytorch_model(
  348. pt_model: _ModelType,
  349. onnx_model_f: Union[str, io.BytesIO],
  350. input_args: _InputArgsType,
  351. input_kwargs: Optional[_InputKwargsType],
  352. additional_test_inputs: Optional[Sequence[_InputArgsType]],
  353. options: VerificationOptions,
  354. ):
  355. """Compare outputs from ONNX model runs with outputs from PyTorch model runs.
  356. Args:
  357. pt_model: PyTorch model.
  358. onnx_model_f: ONNX model file path or file-like object.
  359. input_args: positional arguments for PyTorch model forward method.
  360. input_kwargs: keyword arguments for PyTorch model forward method.
  361. additional_test_inputs: additional positional arguments for PyTorch model
  362. forward method.
  363. options: options for verification.
  364. Raises:
  365. AssertionError: if outputs from ONNX model and PyTorch model are not
  366. equal up to specified precision.
  367. """
  368. onnx_session = _onnx_backend_session(onnx_model_f, options.backend)
  369. @_beartype.beartype
  370. def compare_onnx_pytorch_model_with_input(input_args, input_kwargs):
  371. pt_args, pt_kwargs = _prepare_input_for_pytorch(input_args, input_kwargs)
  372. # TODO: remove this and treat mutating model separately. See #77679
  373. pt_model_copy = _try_clone_model(pt_model)
  374. pt_outs = pt_model_copy(*pt_args, **pt_kwargs)
  375. onnx_inputs = _prepare_input_for_onnx(
  376. input_args, input_kwargs, options.remained_onnx_input_idx, options.flatten
  377. )
  378. onnx_outs = _run_onnx(onnx_session, onnx_inputs)
  379. _compare_onnx_pytorch_outputs(
  380. onnx_outs=onnx_outs,
  381. pt_outs=pt_outs,
  382. options=options,
  383. )
  384. compare_onnx_pytorch_model_with_input(input_args, input_kwargs)
  385. if additional_test_inputs:
  386. for test_input_args in additional_test_inputs:
  387. compare_onnx_pytorch_model_with_input(test_input_args, {})
  388. class _GraphDiff:
  389. """A class to represent the difference between two graphs."""
  390. @_beartype.beartype
  391. def __init__(self, graph_a: _C.Graph, graph_b: _C.Graph):
  392. """Construct a _GraphDiff object.
  393. Args:
  394. graph_a (_C.Graph): First graph to compare.
  395. graph_b (_C.Graph): Second graph to compare.
  396. """
  397. self.graph_a = graph_a
  398. self.graph_b = graph_b
  399. @_beartype.beartype
  400. def __str__(self):
  401. """See function :func:`diff_report`."""
  402. return self.diff_report()
  403. @_beartype.beartype
  404. def _indent(self, lines: str) -> str:
  405. return "\n".join(["\t" + line for line in lines.splitlines()])
  406. @_beartype.beartype
  407. def diff_report(self) -> str:
  408. """Return a string representation of the graph difference.
  409. The report shows the first pair of nodes that diverges. It also shows the source
  410. location of the pair of nodes.
  411. Returns:
  412. graph_diff_report (str): A string representation of the graph difference.
  413. """
  414. graph_a = self.graph_a
  415. graph_b = self.graph_b
  416. graph_a_str = str(graph_a)
  417. graph_b_str = str(graph_b)
  418. if graph_a_str == graph_b_str:
  419. return ""
  420. graph_diff = difflib.ndiff(
  421. graph_a_str.splitlines(True), graph_b_str.splitlines(True)
  422. )
  423. graph_diff_report = ["Graph diff:", self._indent("".join(graph_diff))]
  424. for node_a, node_b in itertools.zip_longest(graph_a.nodes(), graph_b.nodes()):
  425. if str(node_a) != str(node_b):
  426. graph_diff_report.append("First diverging operator:")
  427. node_diff = difflib.ndiff(
  428. str(node_a).splitlines(True), str(node_b).splitlines(True)
  429. )
  430. source_printout = ["node diff:", self._indent("".join(node_diff))]
  431. stack_a = node_a.sourceRange() if node_a else None
  432. if stack_a:
  433. source_printout.extend(
  434. ["Former source location:", self._indent(str(stack_a))]
  435. )
  436. stack_b = node_b.sourceRange() if node_b else None
  437. if stack_b:
  438. source_printout.extend(
  439. ["Latter source location:", self._indent(str(stack_b))]
  440. )
  441. graph_diff_report.extend(source_printout)
  442. break
  443. return "\n".join(graph_diff_report)
  444. @_beartype.beartype
  445. def _check_graph_diff(
  446. model: Union[torch.nn.Module, torch.jit.ScriptModule],
  447. test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
  448. export_options: _experimental.ExportOptions,
  449. model_to_graph_func: Callable[
  450. [
  451. torch.nn.Module,
  452. Tuple[Any, ...],
  453. Mapping[str, Any],
  454. _experimental.ExportOptions,
  455. ],
  456. _C.Graph,
  457. ],
  458. ) -> str:
  459. """Check if graph produced by `model_to_graph_func` is the same across `test_input_groups`.
  460. Args:
  461. model: See :func:`check_export_model_diff`.
  462. test_input_groups: See :func:`check_export_model_diff`.
  463. export_options: See :func:`check_export_model_diff`.
  464. model_to_graph_func: A function to convert a PyTorch model to a JIT IR graph.
  465. Returns:
  466. graph_diff_report (str): A string representation of the graph difference.
  467. """
  468. if len(test_input_groups) < 2:
  469. raise ValueError("Need at least two groups of test inputs to compare.")
  470. ref_jit_graph = None
  471. for args, kwargs in test_input_groups:
  472. jit_graph = model_to_graph_func(model, args, kwargs, export_options)
  473. if ref_jit_graph is None:
  474. ref_jit_graph = jit_graph
  475. continue
  476. graph_diff_report = _GraphDiff(ref_jit_graph, jit_graph).diff_report()
  477. if graph_diff_report:
  478. return graph_diff_report
  479. return ""
  480. @_beartype.beartype
  481. def _traced_graph_from_model(
  482. model: Union[torch.nn.Module, torch.jit.ScriptModule],
  483. args: Tuple[Any, ...],
  484. kwargs: Mapping[str, Any],
  485. export_options: _experimental.ExportOptions,
  486. ) -> _C.Graph:
  487. """As part of the ONNX export steps, create a traced JIT graph from a PyTorch model.
  488. Args:
  489. model: See :func:`check_export_model_diff`.
  490. args: See :func:`check_export_model_diff`.
  491. kwargs: See :func:`check_export_model_diff`.
  492. export_options: See :func:`check_export_model_diff`.
  493. Returns:
  494. jit_graph (_C.Graph): A traced JIT graph.
  495. """
  496. training = export_options.training
  497. verbose = export_options.verbose
  498. with utils.exporter_context(model, training, verbose):
  499. export_inputs = _prepare_input_for_export(args, kwargs)
  500. model = utils._pre_trace_quant_model(model, export_inputs)
  501. jit_graph, _, _, _ = utils._create_jit_graph(model, export_inputs)
  502. return jit_graph
@_beartype.beartype
def _onnx_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.

    Note:
        As a side effect this sets ``GLOBALS.export_onnx_opset_version`` and
        ``GLOBALS.operator_export_type``, and it does not restore them on return.
        It also calls ``utils._setup_trace_module_map`` on ``model``.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names
    # Fall back to the exporter's default opset when none is specified.
    if opset_version is None:
        opset_version = _constants.ONNX_DEFAULT_OPSET
    utils._setup_trace_module_map(model, export_modules_as_functions)
    # With no explicit operator export type, the default depends on whether
    # _C_onnx._CAFFE2_ATEN_FALLBACK is set in this build.
    if not operator_export_type:
        if _C_onnx._CAFFE2_ATEN_FALLBACK:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX
    # Module-level exporter state; this must be set before the conversion below
    # and is left as-is afterwards.
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type
    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )
        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)
        # Same input packing as torch.onnx.export receives (trailing kwargs dict).
        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )
        return onnx_graph
@_beartype.beartype
def _onnx_graph_from_aten_graph(
    graph: torch.Graph,
    export_options: _experimental.ExportOptions,
    params_dict: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Graph, Dict[str, Any]]:
    """Convert an ATen-level TorchScript graph into an ONNX graph.

    ``graph`` is copied before any pass runs, so the caller's graph is not
    modified. As a side effect, ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type`` are set and not restored.

    Args:
        graph: graph to convert.
        export_options: exporter settings; reads operator_export_type,
            dynamic_axes, input_names, training, do_constant_folding,
            opset_version and verbose.
        params_dict: parameters associated with the graph; defaults to an empty
            dict. (Presumably keyed by graph input name — TODO confirm.)

    Returns:
        A tuple ``(onnx_graph, params_dict)`` of the converted graph and the
        parameter dict as rewritten by the passes.
    """
    if params_dict is None:
        params_dict = {}
    operator_export_type = export_options.operator_export_type
    dynamic_axes = export_options.dynamic_axes or {}
    input_names = export_options.input_names
    training = export_options.training
    do_constant_folding = export_options.do_constant_folding
    opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type
    do_constant_folding = utils._decide_constant_folding(
        do_constant_folding, operator_export_type, training
    )
    # TODO: Below is doing aten graph to onnx. It should be abstracted as a
    # function in torch/onnx/utils.py.
    # Work on a copy so the caller's graph is untouched.
    graph = graph.copy()
    graph = utils._optimize_graph(
        graph,
        operator_export_type,
        params_dict=params_dict,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
    )
    # The eval-only peephole pass runs when training is unset or EVAL.
    if training is None or training == _C_onnx.TrainingMode.EVAL:
        params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
    # Constant folding only when enabled and the opset is new enough; dead code
    # left behind by folding is removed right after.
    if (
        do_constant_folding
        and opset_version >= _constants.ONNX_CONSTANT_FOLDING_MIN_OPSET
    ):
        params_dict = _C._jit_pass_onnx_constant_fold(graph, params_dict, opset_version)
        _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    if GLOBALS.onnx_shape_inference:
        _C._jit_pass_onnx_graph_shape_type_inference(graph, params_dict, opset_version)
    params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
    # For ONNX opset < 9, constants only have three data types: float16, float, double.
    # In this pass transform constants of other data types to float/double + cast operator.
    if opset_version < 9:
        _C._jit_pass_onnx_cast_all_constant_to_floating(graph)
    params_dict = _C._jit_pass_filter_non_tensor_arguments(params_dict)
    _C._jit_decay_packed_param_input_types(graph)
    _C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
    if export_options.verbose:
        print("ONNX graph: ", graph)
    return graph, params_dict
  609. @_beartype.beartype
  610. def _onnx_proto_from_onnx_graph(
  611. onnx_graph: torch.Graph,
  612. export_options: _experimental.ExportOptions,
  613. params_dict: Dict[str, Any],
  614. ) -> Tuple[bytes, Mapping[str, bytes]]:
  615. opset_version = export_options.opset_version or _constants.ONNX_DEFAULT_OPSET
  616. dynamic_axes = export_options.dynamic_axes or {}
  617. operator_export_type = export_options.operator_export_type
  618. val_keep_init_as_ip = utils._decide_keep_init_as_input(
  619. export_options.keep_initializers_as_inputs,
  620. operator_export_type,
  621. opset_version,
  622. )
  623. val_add_node_names = utils._decide_add_node_names(True, operator_export_type)
  624. custom_opsets = export_options.custom_opsets or {}
  625. proto, export_map, _, _ = onnx_graph._export_onnx( # type: ignore[attr-defined]
  626. params_dict,
  627. opset_version,
  628. dynamic_axes,
  629. False,
  630. operator_export_type,
  631. not export_options.verbose,
  632. val_keep_init_as_ip,
  633. custom_opsets,
  634. val_add_node_names,
  635. "",
  636. {},
  637. )
  638. return proto, export_map
  639. @_beartype.beartype
  640. def check_export_model_diff(
  641. model: Union[torch.nn.Module, torch.jit.ScriptModule],
  642. test_input_groups: Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]],
  643. export_options: Optional[_experimental.ExportOptions] = None,
  644. ) -> str:
  645. """Verify exported model discrepancy between different groups of inputs.
  646. A graph is exported for each group of inputs. The exported graphs are then compared
  647. to each other, and discrepancies of first pair of nodes are reported. This function
  648. first checks the jit graph. If no discrepancies were found, it then checks the onnx
  649. graph.
  650. Unless otherwise specified, the jit/ONNX graph is expected to be the same, regardless
  651. of the inputs used for exporting. A discrepancy implies the graph exported is
  652. not accurate when run on other groups of inputs, which will typically results in
  653. runtime errors or mismatching output.
  654. Args:
  655. model (torch.nn.Module or torch.jit.ScriptModule): The model to be exported.
  656. test_input_groups (Sequence[Tuple[Tuple[Any, ...], Mapping[str, Any]]]): A sequence
  657. of input groups to be used to export the model. Each input group is a pair of
  658. (args, kwargs).
  659. export_options (_experimental.ExportOptions, optional): An _experimental.ExportOptions
  660. object that controls the export behavior.
  661. Returns:
  662. str: A string containing the diff of the exported models.
  663. """
  664. export_options = (
  665. _experimental.ExportOptions() if export_options is None else export_options
  666. )
  667. jit_diff_report = _check_graph_diff(
  668. model, test_input_groups, export_options, _traced_graph_from_model
  669. )
  670. if jit_diff_report:
  671. return jit_diff_report
  672. return _check_graph_diff(
  673. model, test_input_groups, export_options, _onnx_graph_from_model
  674. )
  675. @_beartype.beartype
  676. def verify(
  677. model: _ModelType,
  678. input_args: _InputArgsType,
  679. input_kwargs: Optional[_InputKwargsType] = None,
  680. do_constant_folding: bool = True,
  681. dynamic_axes: Optional[
  682. Mapping[str, Union[Mapping[int, str], Mapping[str, Sequence[int]]]]
  683. ] = None,
  684. input_names: Optional[Sequence[str]] = None,
  685. output_names: Optional[Sequence[str]] = None,
  686. training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
  687. opset_version: Optional[int] = None,
  688. keep_initializers_as_inputs: bool = True,
  689. verbose: bool = False,
  690. fixed_batch_size: bool = False,
  691. use_external_data: bool = False,
  692. additional_test_inputs: Optional[Sequence[_InputArgsType]] = None,
  693. options: Optional[VerificationOptions] = None,
  694. ):
  695. """Verify model export to ONNX against original PyTorch model.
  696. Args:
  697. model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
  698. input_args (tuple): See :func:`torch.onnx.export`.
  699. input_kwargs (dict): See :func:`torch.onnx.export`.
  700. do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
  701. dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
  702. input_names (list, optional): See :func:`torch.onnx.export`.
  703. output_names (list, optional): See :func:`torch.onnx.export`.
  704. training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
  705. opset_version (int, optional): See :func:`torch.onnx.export`.
  706. keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
  707. verbose (bool, optional): See :func:`torch.onnx.export`.
  708. fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
  709. use_external_data (bool, optional): Explicitly specify whether to export the
  710. model with external data.
  711. additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
  712. input arguments to test. Currently only *args are supported.
  713. options (_VerificationOptions, optional): A _VerificationOptions object that
  714. controls the verification behavior.
  715. Raises:
  716. AssertionError: if outputs from ONNX model and PyTorch model are not
  717. equal up to specified precision.
  718. ValueError: if arguments provided are invalid.
  719. """
  720. if options is None:
  721. options = VerificationOptions()
  722. if training == torch.onnx.TrainingMode.TRAINING:
  723. model.train()
  724. elif training == torch.onnx.TrainingMode.EVAL:
  725. model.eval()
  726. with torch.no_grad(), contextlib.ExitStack() as stack:
  727. model_f: Union[str, io.BytesIO] = io.BytesIO()
  728. if use_external_data:
  729. tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
  730. model_f = os.path.join(tmpdir_path, "model.onnx")
  731. inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)
  732. # TODO(#77679): remove this and treat mutating model separately.
  733. model_copy = _try_clone_model(model)
  734. utils._export(
  735. model,
  736. inputs_for_export,
  737. model_f,
  738. opset_version=opset_version,
  739. do_constant_folding=do_constant_folding,
  740. keep_initializers_as_inputs=keep_initializers_as_inputs,
  741. dynamic_axes=dynamic_axes,
  742. input_names=input_names,
  743. output_names=output_names,
  744. fixed_batch_size=fixed_batch_size,
  745. training=training,
  746. verbose=verbose,
  747. )
  748. _compare_onnx_pytorch_model(
  749. pt_model=model_copy,
  750. onnx_model_f=model_f,
  751. input_args=input_args,
  752. input_kwargs=input_kwargs,
  753. additional_test_inputs=additional_test_inputs,
  754. options=options,
  755. )
  756. @_beartype.beartype
  757. def verify_aten_graph(
  758. graph: torch.Graph,
  759. input_args: Tuple[Any, ...],
  760. export_options: _experimental.ExportOptions,
  761. params_dict: Optional[Dict[str, Any]] = None,
  762. verification_options: Optional[VerificationOptions] = None,
  763. ) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
  764. if verification_options is None:
  765. verification_options = VerificationOptions()
  766. if params_dict is None:
  767. params_dict = {}
  768. original_jit_graph = graph
  769. graph = graph.copy()
  770. # Execute aten graph and get reference torch jit outputs.
  771. graph_inputs = list(graph.inputs())
  772. jit_inputs = tuple([arg for arg in input_args if arg is not None])
  773. weights = [params_dict[v.debugName()] for v in graph_inputs[len(jit_inputs) :]]
  774. assert all([w is not None for w in weights])
  775. # TODO: Only copy the argument if mutation is detected in Graph.
  776. jit_inputs = copy.deepcopy(jit_inputs)
  777. jit_input_and_parameters = jit_inputs + tuple(weights)
  778. jit_outs = torch._C._jit_interpret_graph(graph, jit_input_and_parameters) # type: ignore[attr-defined]
  779. if not isinstance(jit_outs, (list, tuple)):
  780. jit_outs = [jit_outs]
  781. # Convert aten graph to onnx graph.
  782. graph, onnx_params_dict = _onnx_graph_from_aten_graph(
  783. graph, export_options, params_dict
  784. )
  785. proto, export_map = _onnx_proto_from_onnx_graph(
  786. graph, export_options, onnx_params_dict
  787. )
  788. model_f: Union[str, io.BytesIO] = io.BytesIO()
  789. export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
  790. onnx_proto_utils._export_file(proto, model_f, export_type, export_map)
  791. # NOTE: Verification is unstable. Try catch to emit information for debugging.
  792. try:
  793. # NOTE: Input might be dce'ed, so we need to remove those from the input args.
  794. new_input_names = {v.debugName() for v in graph.inputs()}
  795. new_input_args = []
  796. for v, arg in zip(original_jit_graph.inputs(), input_args):
  797. if v.debugName() in new_input_names:
  798. new_input_args.append(arg)
  799. input_args = tuple(new_input_args)
  800. onnx_inputs = _prepare_input_for_onnx(
  801. input_args,
  802. {},
  803. verification_options.remained_onnx_input_idx,
  804. verification_options.flatten,
  805. )
  806. onnx_session = _onnx_backend_session(model_f, verification_options.backend)
  807. onnx_outs = _run_onnx(onnx_session, onnx_inputs)
  808. del onnx_session # To free device memory
  809. try:
  810. _compare_onnx_pytorch_outputs(
  811. onnx_outs=onnx_outs,
  812. pt_outs=jit_outs,
  813. options=verification_options,
  814. )
  815. except AssertionError as e:
  816. return e, graph, jit_outs, onnx_outs
  817. return None, graph, jit_outs, onnx_outs
  818. except Exception as e:
  819. print("Unexpected error during verification.")
  820. print("jit graph: ", original_jit_graph)
  821. print("onnx graph: ", graph)
  822. raise e
  823. class GraphInfoPrettyPrinter:
  824. graph_info: Optional[GraphInfo]
  825. upper_printer: Optional[GraphInfoPrettyPrinter]
  826. lower_printer: Optional[GraphInfoPrettyPrinter]
  827. graph_str_lambdas: Mapping[int, str]
  828. connector_str_lambdas: Mapping[int, str]
  829. children_str_lambdas: Mapping[int, str]
  830. def __init__(self, graph_info: Optional[GraphInfo]):
  831. self.graph_info = graph_info
  832. if (
  833. graph_info is not None
  834. and graph_info.upper_graph_info is not None
  835. and graph_info.lower_graph_info is not None
  836. ):
  837. self.upper_printer = GraphInfoPrettyPrinter(graph_info.upper_graph_info)
  838. self.lower_printer = GraphInfoPrettyPrinter(graph_info.lower_graph_info)
  839. else:
  840. self.upper_printer = None
  841. self.lower_printer = None
  842. @_beartype.beartype
  843. def _total_rows(self) -> int:
  844. if self.graph_info is None:
  845. return 1
  846. if self.upper_printer and self.lower_printer:
  847. return (
  848. self.upper_printer._total_rows() + self.lower_printer._total_rows() + 1
  849. )
  850. return 2 # Two lines: node count + id.
  851. @_beartype.beartype
  852. def _node_count_segment_str(self) -> str:
  853. if self.graph_info is None:
  854. return "..."
  855. node_count = self.graph_info.essential_node_count()
  856. has_mismatch = self.graph_info.has_mismatch()
  857. error_node_kind = (
  858. f"({self.graph_info.essential_node_kinds().pop()})"
  859. if node_count == 1 and has_mismatch
  860. else ""
  861. )
  862. return f"{node_count} {'X' if has_mismatch else '✓'} {error_node_kind}"
  863. @_beartype.beartype
  864. def _graph_id_segment_str(self) -> str:
  865. if self.graph_info is None:
  866. return ""
  867. return f"id: {self.graph_info.id}"
  868. @_beartype.beartype
  869. def _max_segment_columns(self) -> int:
  870. return max(
  871. map(len, (self._node_count_segment_str(), self._graph_id_segment_str()))
  872. )
  873. @_beartype.beartype
  874. def _graph_segment_str_at_line(self, line: int) -> str:
  875. """Get the string representation of the graph segment at the given line."""
  876. if line == 0:
  877. result_str = self._node_count_segment_str()
  878. result_str += " " * (self._max_segment_columns() - len(result_str))
  879. return result_str
  880. if line == 1:
  881. result_str = self._graph_id_segment_str()
  882. result_str += " " * (self._max_segment_columns() - len(result_str))
  883. return result_str
  884. if 0 <= line < self._total_rows():
  885. return " " * self._max_segment_columns()
  886. return ""
  887. @_beartype.beartype
  888. def _connector_segment_str_at_line(self, line: int) -> str:
  889. """Get the connector segment string at the given line."""
  890. if self.upper_printer is None and self.lower_printer is None:
  891. return ""
  892. upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
  893. lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
  894. if line == 0:
  895. return " __"
  896. elif line < upper_total_rows + 1:
  897. return " | "
  898. elif line == upper_total_rows + 1:
  899. return " |__"
  900. elif line < upper_total_rows + lower_total_rows + 1:
  901. return " "
  902. return ""
  903. @_beartype.beartype
  904. def _children_str_at_line(self, line: int) -> str:
  905. """Get the string representation of the children at the given line.
  906. Recursively calls `_str_at_line` on children nodes.
  907. """
  908. if self.upper_printer is None and self.lower_printer is None:
  909. return ""
  910. upper_total_rows = self.upper_printer._total_rows() if self.upper_printer else 1
  911. lower_total_rows = self.lower_printer._total_rows() if self.lower_printer else 1
  912. if 0 <= line < upper_total_rows:
  913. return (
  914. self.upper_printer._str_at_line(line) if self.upper_printer else "..."
  915. )
  916. elif upper_total_rows < line < upper_total_rows + lower_total_rows + 1:
  917. return (
  918. self.lower_printer._str_at_line(line - upper_total_rows - 1)
  919. if self.lower_printer
  920. else "..."
  921. )
  922. return ""
  923. @_beartype.beartype
  924. def _str_at_line(self, line: int) -> str:
  925. """Get the string representation of the graph at the given line."""
  926. return (
  927. self._graph_segment_str_at_line(line)
  928. + self._connector_segment_str_at_line(line)
  929. + self._children_str_at_line(line)
  930. )
  931. def pretty_print(self):
  932. if self.graph_info is None:
  933. print(None)
  934. return
  935. # Print tree.
  936. print(" Tree: ".center(80, "="))
  937. total_rows = self._total_rows()
  938. for line in range(total_rows):
  939. print(self._str_at_line(line).rstrip())
  940. if self.graph_info.has_mismatch():
  941. # Summarize leaf subgraphs with mismatch.
  942. print(" Mismatch leaf subgraphs: ".center(80, "="))
  943. print(
  944. [
  945. graph_info.id
  946. for graph_info in self.graph_info.all_mismatch_leaf_graph_info()
  947. ]
  948. )
  949. # Summarize node kinds with mismatch.
  950. mismatch_node_kinds: Dict[str, int] = {}
  951. for graph_info in self.graph_info.all_mismatch_leaf_graph_info():
  952. node_kinds = graph_info.essential_node_kinds()
  953. if len(node_kinds) == 1:
  954. node_kind = node_kinds.pop()
  955. mismatch_node_kinds[node_kind] = (
  956. mismatch_node_kinds.get(node_kind, 0) + 1
  957. )
  958. print(" Mismatch node kinds: ".center(80, "="))
  959. print(mismatch_node_kinds)
  960. else:
  961. print(" No mismatch found. ".center(80, "="))
  962. class OnnxTestCaseRepro:
  963. def __init__(self, repro_dir):
  964. self.repro_dir = repro_dir
  965. self.proto, self.inputs, self.outputs = onnx_proto_utils.load_test_case(
  966. repro_dir
  967. )
  968. @classmethod
  969. @_beartype.beartype
  970. def create_test_case_repro(
  971. cls, proto: bytes, inputs, outputs, dir: str, name: Optional[str] = None
  972. ):
  973. """Create a repro under "{dir}/test_{name}" for an ONNX test case.
  974. The test case contains the model and the inputs/outputs data. The directory
  975. structure is as follows:
  976. dir
  977. ├── test_<name>
  978. │ ├── model.onnx
  979. │ └── test_data_set_0
  980. │ ├── input_0.pb
  981. │ ├── input_1.pb
  982. │ ├── output_0.pb
  983. │ └── output_1.pb
  984. Args:
  985. proto: ONNX model proto.
  986. inputs: Inputs to the model.
  987. outputs: Outputs of the model.
  988. dir: Directory to save the repro.
  989. name: Name of the test case. If not specified, a name based on current time
  990. will be generated.
  991. Returns:
  992. Path to the repro.
  993. """
  994. if name is None:
  995. name = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
  996. return onnx_proto_utils.export_as_test_case(
  997. proto,
  998. _to_numpy(inputs),
  999. _to_numpy(outputs),
  1000. name,
  1001. dir,
  1002. )
  1003. @_beartype.beartype
  1004. def validate(self, options: VerificationOptions):
  1005. """Run the ONNX test case with options.backend, and compare with the expected outputs.
  1006. Args:
  1007. options: Options for validation.
  1008. Raise:
  1009. AssertionError: if outputs from options.backend and expected outputs are not
  1010. equal up to specified precision.
  1011. """
  1012. onnx_session = _onnx_backend_session(io.BytesIO(self.proto), options.backend)
  1013. run_outputs = onnx_session.run(None, self.inputs)
  1014. if hasattr(onnx_session, "get_outputs"):
  1015. output_names = [o.name for o in onnx_session.get_outputs()]
  1016. elif hasattr(onnx_session, "output_names"):
  1017. output_names = onnx_session.output_names
  1018. else:
  1019. raise ValueError(f"Unknown onnx session type: {type(onnx_session)}")
  1020. expected_outs = [self.outputs[name] for name in output_names]
  1021. _compare_onnx_pytorch_outputs_in_np(run_outputs, expected_outs, options)
@dataclasses.dataclass
class GraphInfo:
    """GraphInfo contains validation information of a TorchScript graph and its converted ONNX graph.

    Instances form a binary tree of graph partitions. `upper_graph_info` and
    `lower_graph_info` hold the two halves of `graph`. A child's `id` extends
    its parent's id with "0" (upper) or "1" (lower); `find_partition` relies on
    this convention.
    """

    # The ATen TorchScript graph this object describes.
    graph: torch.Graph
    # Positional inputs used when running `graph` (also exported as repro inputs).
    input_args: Tuple[Any, ...]
    # Parameter values for `graph`, keyed by graph-input debug name.
    params_dict: Dict[str, Any]
    # Options used to convert `graph` to ONNX.
    export_options: _experimental.ExportOptions = dataclasses.field(
        default_factory=_experimental.ExportOptions
    )
    # Output-comparison error from the last verification; None means "no mismatch".
    mismatch_error: Optional[AssertionError] = dataclasses.field(
        default=None, init=False
    )
    # Outputs of running `graph` with TorchScript; used as expected outputs in
    # `export_repro` and mapped to output names by `_bridge_kwargs`.
    pt_outs: Optional[Sequence[_NumericType]] = dataclasses.field(
        default=None, init=False
    )
    # Child partitions of `graph` (None until partitioned, or after `clear`).
    upper_graph_info: Optional[GraphInfo] = dataclasses.field(default=None, init=False)
    lower_graph_info: Optional[GraphInfo] = dataclasses.field(default=None, init=False)
    # Partition id; see the class docstring for the "0"/"1" suffix convention.
    id: str = dataclasses.field(default="")
    # ONNX graph produced for `graph`, shown by `pretty_print_mismatch` when set.
    _onnx_graph: Optional[torch.Graph] = dataclasses.field(init=False, default=None)

    # Node kinds ignored when counting nodes and when choosing the partition pivot.
    # NOTE(review): because it is annotated (not ClassVar), dataclasses also treats
    # this as an init field with a default value.
    _EXCLUDED_NODE_KINDS: FrozenSet[str] = frozenset(
        {"prim::Constant", "prim::ListConstruct", "aten::ScalarImplicit"}
    )
  1044. def clear(self):
  1045. """Clear states and results of previous verification."""
  1046. self.mismatch_error = None
  1047. self.pt_outs = None
  1048. self._onnx_graph = None
  1049. self.upper_graph_info = None
  1050. self.lower_graph_info = None
  1051. def pretty_print_tree(self):
  1052. """Pretty print `GraphInfo` tree.
  1053. Each node represents a subgraph, showing the number of nodes in the subgraph and
  1054. a check mark if the subgraph has output mismatch between torch and ONNX.
  1055. The id of the subgraph is shown under the node. The `GraphInfo` object for any
  1056. subgraph can be retrieved by calling `graph_info.find_partition(id)`.
  1057. Example::
  1058. ==================================== Tree: =====================================
  1059. 5 X __2 X __1 ✓
  1060. id: | id: 0 | id: 00
  1061. | |
  1062. | |__1 X (aten::relu)
  1063. | id: 01
  1064. |
  1065. |__3 X __1 ✓
  1066. id: 1 | id: 10
  1067. |
  1068. |__2 X __1 X (aten::relu)
  1069. id: 11 | id: 110
  1070. |
  1071. |__1 ✓
  1072. id: 111
  1073. =========================== Mismatch leaf subgraphs: ===========================
  1074. ['01', '110']
  1075. ============================= Mismatch node kinds: =============================
  1076. {'aten::relu': 2}
  1077. """
  1078. GraphInfoPrettyPrinter(self).pretty_print()
  1079. def pretty_print_mismatch(self, graph: bool = False):
  1080. """Pretty print details of the mismatch between torch and ONNX.
  1081. Args:
  1082. graph: If True, print the ATen JIT graph and ONNX graph.
  1083. """
  1084. print(f" Mismatch info for graph partition {self.id}: ".center(80, "="))
  1085. if graph:
  1086. print(" ATen JIT graph ".center(80, "="))
  1087. # TODO: A more compact graph printer.
  1088. # * Drop stride, grad, device information.
  1089. # * Show source location on a separate line.
  1090. print(self.graph)
  1091. if self._onnx_graph is not None:
  1092. print(" ONNX graph ".center(80, "="))
  1093. print(self._onnx_graph)
  1094. if self.has_mismatch():
  1095. print(" Mismatch error ".center(80, "="))
  1096. print(self.mismatch_error)
  1097. else:
  1098. print(" No mismatch ".center(80, "="))
  1099. @_beartype.beartype
  1100. def has_mismatch(self) -> bool:
  1101. """Return True if the subgraph has output mismatch between torch and ONNX."""
  1102. return self.mismatch_error is not None
  1103. @_beartype.beartype
  1104. def essential_node_count(self) -> int:
  1105. """Return the number of nodes in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
  1106. return sum(
  1107. 1 for n in self.graph.nodes() if n.kind() not in self._EXCLUDED_NODE_KINDS
  1108. )
  1109. @_beartype.beartype
  1110. def essential_node_kinds(self) -> Set[str]:
  1111. """Return the set of node kinds in the subgraph excluding those in `_EXCLUDED_NODE_KINDS`."""
  1112. return {
  1113. n.kind()
  1114. for n in self.graph.nodes()
  1115. if n.kind() not in self._EXCLUDED_NODE_KINDS
  1116. }
  1117. @_beartype.beartype
  1118. def all_mismatch_leaf_graph_info(self) -> List["GraphInfo"]:
  1119. """Return a list of all leaf `GraphInfo` objects that have mismatch."""
  1120. if not self.has_mismatch():
  1121. return []
  1122. no_mismatch_children = (
  1123. self.upper_graph_info is None or not self.upper_graph_info.has_mismatch()
  1124. ) and (
  1125. self.lower_graph_info is None or not self.lower_graph_info.has_mismatch()
  1126. )
  1127. if no_mismatch_children:
  1128. return [self]
  1129. results = []
  1130. if self.upper_graph_info is not None:
  1131. results += self.upper_graph_info.all_mismatch_leaf_graph_info()
  1132. if self.lower_graph_info is not None:
  1133. results += self.lower_graph_info.all_mismatch_leaf_graph_info()
  1134. return results
  1135. @_beartype.beartype
  1136. def find_partition(self, id: str) -> Optional["GraphInfo"]:
  1137. """Find the `GraphInfo` object with the given id."""
  1138. if id == self.id:
  1139. return self
  1140. current_length = len(self.id)
  1141. if len(id) > current_length:
  1142. if id[current_length] == "0" and self.upper_graph_info is not None:
  1143. return self.upper_graph_info.find_partition(id)
  1144. elif id[current_length] == "1" and self.lower_graph_info is not None:
  1145. return self.lower_graph_info.find_partition(id)
  1146. return None
  1147. @_beartype.beartype
  1148. def export_repro(
  1149. self, repro_dir: Optional[str] = None, name: Optional[str] = None
  1150. ) -> str:
  1151. """Export the subgraph to ONNX along with the input/output data for repro.
  1152. The repro directory will contain the following files::
  1153. dir
  1154. ├── test_<name>
  1155. │ ├── model.onnx
  1156. │ └── test_data_set_0
  1157. │ ├── input_0.pb
  1158. │ ├── input_1.pb
  1159. │ ├── output_0.pb
  1160. │ └── output_1.pb
  1161. Args:
  1162. repro_dir: The directory to export the repro files to. Defaults to current
  1163. working directory if None.
  1164. name: An optional name for the test case folder: "test_{name}".
  1165. Returns:
  1166. The path to the exported repro directory.
  1167. """
  1168. if repro_dir is None:
  1169. repro_dir = os.getcwd()
  1170. repro_dir = os.path.join(repro_dir, "onnx_debug")
  1171. onnx_graph, onnx_params_dict = _onnx_graph_from_aten_graph(
  1172. self.graph, self.export_options, self.params_dict
  1173. )
  1174. proto, _ = _onnx_proto_from_onnx_graph(
  1175. onnx_graph, self.export_options, onnx_params_dict
  1176. )
  1177. return OnnxTestCaseRepro.create_test_case_repro(
  1178. proto, self.input_args, self.pt_outs, repro_dir, name
  1179. )
  1180. @_beartype.beartype
  1181. def _graph_partition_pivot(self) -> int:
  1182. """Find the pivot index to partition the graph.
  1183. The pivot is the node that splits the graph into two parts. Each part should
  1184. have the similar amount of nodes, excluding non essential ops, defined in
  1185. `_EXCLUDED_NODE_KINDS`, such as `prim::Constant`.
  1186. If the graph has an odd number of nodes, the upper part will have one more node.
  1187. If the graph does not have any node that can be partitioned, return -1.
  1188. Returns:
  1189. The index of the pivot node.
  1190. """
  1191. included_node_indices = [
  1192. i
  1193. for i, n in enumerate(self.graph.nodes())
  1194. if n.kind() not in self._EXCLUDED_NODE_KINDS
  1195. ]
  1196. half_idx = len(included_node_indices) // 2 - 1
  1197. if half_idx >= 0 and len(included_node_indices) > half_idx:
  1198. return included_node_indices[half_idx] + 1
  1199. return -1
    @_beartype.beartype
    def _partition_upper_graph(self) -> torch.Graph:
        """Build the upper half of `self.graph` as a standalone graph.

        Works on a copy of `self.graph`. Every value produced by an upper node
        that is used by the lower half, or that is an original graph output,
        becomes an output of the new graph. The lower nodes are destroyed, and
        graph inputs unused by the remaining upper nodes are erased.

        Returns:
            The upper partition graph, or an empty ``torch.Graph()`` when the graph
            cannot be partitioned (pivot == -1).
        """
        pivot = self._graph_partition_pivot()
        if pivot == -1:
            return torch.Graph()
        graph = self.graph.copy()  # Copy to not mutate parent graph.
        original_outputs = list(graph.outputs())

        def _process_bridge_value_for_upper(
            new_outputs: List[torch.Value], bridge_value: torch.Value
        ) -> torch.Value:
            # Add bridge values as upper graph outputs.
            new_outputs.append(bridge_value)
            return bridge_value

        new_outputs: List[torch.Value] = []
        process_bridge_value_for_upper = functools.partial(
            _process_bridge_value_for_upper, new_outputs
        )
        # Fills `new_outputs` with bridge values as a side effect.
        _, dropped_nodes, complete_upper_nodes_set, _ = self._partition_nodes(
            graph, pivot, process_bridge_value_for_upper
        )

        # Replace all original outputs with the collected bridge values.
        for _ in enumerate(original_outputs):
            graph.eraseOutput(0)
        for output in new_outputs:
            graph.registerOutput(output)

        # Destroy the lower nodes last-to-first. Presumably this ensures consumers
        # are removed before the nodes producing their inputs.
        for node in reversed(dropped_nodes):
            node.destroy()

        # Erase unused inputs from the back so the remaining indices `i` stay valid.
        # Inputs that are themselves bridge outputs are kept.
        for i, input in reversed(list(enumerate(list(graph.inputs())))):
            if (
                not _has_uses_by_nodes(input, complete_upper_nodes_set)
                and input not in new_outputs
            ):
                try:
                    graph.eraseInput(i)
                except RuntimeError as e:
                    # Dump context for debugging before propagating.
                    print(input, graph)
                    raise e
        return graph
    @_beartype.beartype
    def _partition_lower_graph(self) -> torch.Graph:
        """Build the lower half of `self.graph` as a standalone graph.

        Works on a copy of `self.graph`. Bridge values (produced by upper nodes
        and used by lower nodes) are replaced by new graph inputs. Original
        inputs still used by the lower half are re-created as fresh inputs. Upper
        nodes not copied into the lower half are destroyed. Only original outputs
        produced by lower nodes are kept.

        Returns:
            The lower partition graph, or an empty ``torch.Graph()`` when the graph
            cannot be partitioned (pivot == -1).
        """
        pivot = self._graph_partition_pivot()
        if pivot == -1:
            return torch.Graph()
        graph = self.graph.copy()  # Copy to not mutate parent graph.
        original_outputs = list(graph.outputs())
        original_inputs = list(graph.inputs())

        new_outputs = []

        def _process_bridge_value_for_lower(
            graph: torch.Graph, bridge_value: torch.Value
        ) -> torch.Value:
            # Add bridge values as lower graph inputs.
            new_input = graph.addInput()
            bridge_value.replaceAllUsesWith(new_input)
            new_input.copyMetadata(bridge_value)
            return new_input

        process_bridge_value_for_lower = functools.partial(
            _process_bridge_value_for_lower, graph
        )

        upper_nodes, lower_nodes, _, complete_lower_nodes_set = self._partition_nodes(
            graph, pivot, process_bridge_value_for_lower
        )

        # Keep only the original outputs that the lower nodes produce.
        for output in original_outputs:
            if _produced_by(output, lower_nodes):
                new_outputs.append(output)
        for _ in enumerate(original_outputs):
            graph.eraseOutput(0)
        for output in new_outputs:
            graph.registerOutput(output)

        # Re-create every original input still used by the lower half as a new
        # input, so all original inputs can be erased below.
        for input in original_inputs:
            if _has_uses_by_nodes(input, complete_lower_nodes_set):
                new_input = graph.addInput()
                input.replaceAllUsesWith(new_input)
                new_input.copyMetadata(input)

        # Destroy upper nodes last-to-first, except those copied into the lower half.
        for node in reversed(upper_nodes):
            if node not in complete_lower_nodes_set:
                try:
                    node.destroy()
                except RuntimeError as e:
                    # Dump context for debugging before propagating.
                    print(node, graph)
                    raise e

        # Erasing index 0 repeatedly assumes `addInput` appends, which leaves the
        # original inputs at the front. TODO(review): confirm against the binding.
        for _ in original_inputs:
            graph.eraseInput(0)

        return graph
    @_beartype.beartype
    def _partition_node(
        self,
        node: torch.Node,
        complete_upper_nodes_set: Set[torch.Node],
        complete_lower_nodes_set: Set[torch.Node],
        original_graph_outputs: Set[torch.Value],
        covered_bridge_values: Set[torch.Value],
        process_bridge_value: Callable[[torch.Value], torch.Value],
    ):
        """Classify one upper-half node, mutating the bookkeeping sets in place.

        - Nodes already in `complete_lower_nodes_set` are skipped.
        - An excluded-kind node (see `_EXCLUDED_NODE_KINDS`) that is used by lower
          nodes is copied into the lower half. It and its sub-block nodes (via
          `_all_nodes`) are added to `complete_lower_nodes_set`. Then the producers
          of its not-yet-covered inputs are visited recursively.
        - Otherwise each not-yet-covered output that is used by a lower node, or
          is an original graph output, is a bridge value. It is passed to
          `process_bridge_value`, and the *returned* value (which may differ from
          the output) is added to `covered_bridge_values`.

        `complete_upper_nodes_set` is only forwarded to the recursive calls.
        """
        if node in complete_lower_nodes_set:
            return

        if (
            _node_has_uses_by(node, complete_lower_nodes_set)
            and node.kind() in self._EXCLUDED_NODE_KINDS
        ):
            # Cheap node needed below: copy it into the lower half instead of bridging.
            complete_lower_nodes_set.update(_all_nodes([node]))
            for input in node.inputs():
                if input in covered_bridge_values:
                    continue
                self._partition_node(
                    input.node(),
                    complete_upper_nodes_set,
                    complete_lower_nodes_set,
                    original_graph_outputs,
                    covered_bridge_values,
                    process_bridge_value,
                )
        else:
            for output in node.outputs():
                if output in covered_bridge_values:
                    continue
                if (
                    _has_uses_by_nodes(output, complete_lower_nodes_set)
                    or output in original_graph_outputs
                ):
                    covered_bridge_values.add(process_bridge_value(output))
  1319. @_beartype.beartype
  1320. def _partition_nodes(
  1321. self,
  1322. graph: torch.Graph,
  1323. pivot: int,
  1324. process_bridge_value: Callable[[torch.Value], torch.Value],
  1325. ) -> Tuple[List[torch.Node], List[torch.Node], Set[torch.Node], Set[torch.Node]]:
  1326. nodes = list(graph.nodes())
  1327. upper_nodes = nodes[:pivot]
  1328. lower_nodes = nodes[pivot:]
  1329. # `upper_nodes` and `complete_upper_nodes_set` differs in that the latter
  1330. # recursively contains nodes in subblock of `upper_nodes`.
  1331. # The same applies for `lower_nodes` and `complete_lower_nodes_set`.
  1332. # With addition that `complete_lower_nodes_set` will include nodes that
  1333. # are determined to be copied from `upper_nodes` to `lower_nodes`.
  1334. complete_upper_nodes_set = _all_nodes(upper_nodes)
  1335. complete_lower_nodes_set = _all_nodes(lower_nodes)
  1336. original_graph_outputs = set(graph.outputs())
  1337. # Bridge values are values produced from upper graph, and consumed
  1338. # by lower graph. These values need to be become upper graph outputs
  1339. # and lower graph inputs, to bridge the interaction.
  1340. # Start with all graph inputs marked as covered. If any graph input is
  1341. # needed by lower graph, just keep it in lower graph inputs later.
  1342. covered_bridge_values = set(graph.inputs())
  1343. for node in upper_nodes:
  1344. self._partition_node(
  1345. node,
  1346. complete_upper_nodes_set,
  1347. complete_lower_nodes_set,
  1348. original_graph_outputs,
  1349. covered_bridge_values,
  1350. process_bridge_value,
  1351. )
  1352. return (
  1353. upper_nodes,
  1354. lower_nodes,
  1355. complete_upper_nodes_set,
  1356. complete_lower_nodes_set,
  1357. )
  1358. @_beartype.beartype
  1359. def _bridge_kwargs(self):
  1360. pt_outs = self.pt_outs
  1361. graph_outputs = list(self.graph.outputs())
  1362. assert pt_outs is not None
  1363. assert len(graph_outputs) == len(
  1364. pt_outs
  1365. ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}"
  1366. return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)}
  1367. @_beartype.beartype
  1368. def _args_and_params_for_partition_graph(
  1369. self,
  1370. graph: torch.Graph,
  1371. bridge_kwargs: Mapping[str, Union[_NumericType, Sequence[_NumericType]]],
  1372. full_kwargs: Mapping[str, torch.Tensor],
  1373. full_params: Mapping[str, torch.Tensor],
  1374. ):
  1375. input_names = [input.debugName() for input in graph.inputs()]
  1376. args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs)
  1377. args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs)
  1378. params = {k: full_params[k] for k in input_names if k in full_params}
  1379. assert len(args) + len(params) == len(
  1380. input_names
  1381. ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}"
  1382. return args, params
  1383. @_beartype.beartype
  1384. def verify_export(
  1385. self, options: VerificationOptions
  1386. ) -> Tuple[Optional[AssertionError], torch.Graph, _OutputsType, _OutputsType]:
  1387. """
  1388. Verify the export from TorchScript IR graph to ONNX.
  1389. Export the TorchScript IR graph to ONNX, with the inputs, parameters and export
  1390. options recorded in this object. Then verify the exported ONNX graph against
  1391. the original TorchScript IR graph under the provided verification options.
  1392. Args:
  1393. options: The verification options.
  1394. Returns:
  1395. error: The AssertionError raised during the verification. Returns None if no
  1396. error is raised.
  1397. onnx_graph: The exported ONNX graph in TorchScript IR format.
  1398. onnx_outs: The outputs from running exported ONNX model under the onnx
  1399. backend in `options`.
  1400. pt_outs: The outputs from running the TorchScript IR graph.
  1401. """
  1402. return verify_aten_graph(
  1403. self.graph,
  1404. input_args=self.input_args,
  1405. params_dict=self.params_dict,
  1406. export_options=self.export_options,
  1407. verification_options=options,
  1408. )
  1409. @_beartype.beartype
  1410. def find_mismatch(
  1411. self,
  1412. options: Optional[VerificationOptions] = None,
  1413. ):
  1414. """
  1415. Find all mismatches between the TorchScript IR graph and the exported onnx model.
  1416. Binary searches the model graph to find the minimal subgraph that exhibits the
  1417. mismatch. A `GraphInfo` object is created for each subgraph, recording the test
  1418. inputs and export options, as well as the validation results.
  1419. Args:
  1420. options: The verification options.
  1421. """
  1422. self.clear()
  1423. if options is None:
  1424. options = VerificationOptions()
  1425. if self.export_options.verbose:
  1426. print(self.graph)
  1427. if len(list(self.graph.outputs())) == 0:
  1428. return
  1429. assert len(self.input_args) + len(self.params_dict) == len(
  1430. list(self.graph.inputs())
  1431. ), (
  1432. f"Number of graph inputs({len(list(self.graph.inputs()))}) does not match "
  1433. f"the provided tensor arguments({len(self.input_args)} + {len(self.params_dict)})."
  1434. )
  1435. self.mismatch_error, self._onnx_graph, self.pt_outs, _ = self.verify_export(
  1436. options
  1437. )
  1438. if self.mismatch_error is None:
  1439. # No mismatch found in graph.
  1440. return
  1441. if self.essential_node_count() <= 1:
  1442. # Reached leaf node, no more partitioning.
  1443. return
  1444. full_kwargs = {
  1445. k.debugName(): v for k, v in zip(self.graph.inputs(), self.input_args)
  1446. }
  1447. full_params = self.params_dict
  1448. upper_graph = self._partition_upper_graph()
  1449. upper_args, upper_params = self._args_and_params_for_partition_graph(
  1450. upper_graph, {}, full_kwargs, full_params
  1451. )
  1452. self.upper_graph_info = GraphInfo(
  1453. upper_graph,
  1454. upper_args,
  1455. upper_params,
  1456. self.export_options,
  1457. id=self.id + "0",
  1458. )
  1459. self.upper_graph_info.find_mismatch(options)
  1460. bridge_kwargs = self.upper_graph_info._bridge_kwargs()
  1461. lower_graph = self._partition_lower_graph()
  1462. lower_args, lower_params = self._args_and_params_for_partition_graph(
  1463. lower_graph, bridge_kwargs, full_kwargs, full_params
  1464. )
  1465. self.lower_graph_info = GraphInfo(
  1466. lower_graph,
  1467. lower_args,
  1468. lower_params,
  1469. self.export_options,
  1470. id=self.id + "1",
  1471. )
  1472. self.lower_graph_info.find_mismatch(options)
  1473. @_beartype.beartype
  1474. def _all_nodes(nodes: Collection[torch.Node]) -> Set[torch.Node]:
  1475. all_nodes = set(nodes)
  1476. for n in nodes:
  1477. for b in n.blocks():
  1478. all_nodes.update(_all_nodes(list(b.nodes())))
  1479. return all_nodes
  1480. @_beartype.beartype
  1481. def _has_uses_by_nodes(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
  1482. if any(use.user in nodes for use in value.uses()):
  1483. return True
  1484. return False
  1485. @_beartype.beartype
  1486. def _node_has_uses_by(node: torch.Node, nodes: Collection[torch.Node]) -> bool:
  1487. for output in node.outputs():
  1488. if _has_uses_by_nodes(output, nodes):
  1489. return True
  1490. return False
  1491. @_beartype.beartype
  1492. def _produced_by(value: torch.Value, nodes: Collection[torch.Node]) -> bool:
  1493. return value.node() in nodes
  1494. @_beartype.beartype
  1495. def find_mismatch(
  1496. model: Union[torch.nn.Module, torch.jit.ScriptModule],
  1497. input_args: Tuple[Any, ...],
  1498. do_constant_folding: bool = True,
  1499. training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,
  1500. opset_version: Optional[int] = None,
  1501. keep_initializers_as_inputs: bool = True,
  1502. verbose: bool = False,
  1503. options: Optional[VerificationOptions] = None,
  1504. ) -> GraphInfo:
  1505. r"""Find all mismatches between the original model and the exported model.
  1506. Experimental. The API is subject to change.
  1507. This tool helps debug the mismatch between the original PyTorch model and exported
  1508. ONNX model. It binary searches the model graph to find the minimal subgraph that
  1509. exhibits the mismatch.
  1510. Args:
  1511. model: The model to be exported.
  1512. input_args: The input arguments to the model.
  1513. do_constant_folding: Same as `do_constant_folding` in :func:`torch.onnx.export`.
  1514. training: Same as `training` in :func:`torch.onnx.export`.
  1515. opset_version: Same as `opset_version` in :func:`torch.onnx.export`.
  1516. keep_initializers_as_inputs: Same as `keep_initializers_as_inputs` in :func:`torch.onnx.export`.
  1517. verbose: Same as `verbose` in :func:`torch.onnx.export`.
  1518. options: The options for the mismatch verification.
  1519. Returns:
  1520. A GraphInfo object that contains the mismatch information.
  1521. Example::
  1522. >>> import torch
  1523. >>> import torch.onnx.verification
  1524. >>> torch.manual_seed(0)
  1525. >>> opset_version = 15
  1526. >>> # Define a custom symbolic function for aten::relu.
  1527. >>> # The custom symbolic function is incorrect, which will result in mismatches.
  1528. >>> def incorrect_relu_symbolic_function(g, self):
  1529. ... return self
  1530. >>> torch.onnx.register_custom_op_symbolic(
  1531. ... "aten::relu",
  1532. ... incorrect_relu_symbolic_function,
  1533. ... opset_version=opset_version,
  1534. ... )
  1535. >>> class Model(torch.nn.Module):
  1536. ... def __init__(self):
  1537. ... super().__init__()
  1538. ... self.layers = torch.nn.Sequential(
  1539. ... torch.nn.Linear(3, 4),
  1540. ... torch.nn.ReLU(),
  1541. ... torch.nn.Linear(4, 5),
  1542. ... torch.nn.ReLU(),
  1543. ... torch.nn.Linear(5, 6),
  1544. ... )
  1545. ... def forward(self, x):
  1546. ... return self.layers(x)
  1547. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
  1548. >>> graph_info = torch.onnx.verification.find_mismatch(
  1549. ... Model(),
  1550. ... (torch.randn(2, 3),),
  1551. ... opset_version=opset_version,
  1552. ... )
  1553. ===================== Mismatch info for graph partition : ======================
  1554. ================================ Mismatch error ================================
  1555. Tensor-likes are not close!
  1556. Mismatched elements: 12 / 12 (100.0%)
  1557. Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
  1558. Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
  1559. ==================================== Tree: =====================================
  1560. 5 X __2 X __1 ✓
  1561. id: | id: 0 | id: 00
  1562. | |
  1563. | |__1 X (aten::relu)
  1564. | id: 01
  1565. |
  1566. |__3 X __1 ✓
  1567. id: 1 | id: 10
  1568. |
  1569. |__2 X __1 X (aten::relu)
  1570. id: 11 | id: 110
  1571. |
  1572. |__1 ✓
  1573. id: 111
  1574. =========================== Mismatch leaf subgraphs: ===========================
  1575. ['01', '110']
  1576. ============================= Mismatch node kinds: =============================
  1577. {'aten::relu': 2}
  1578. """
  1579. if options is None:
  1580. options = VerificationOptions()
  1581. if opset_version is None:
  1582. opset_version = _constants.ONNX_DEFAULT_OPSET
  1583. """From aten graph, do binary search on graph partition to find operator export discrepancy."""
  1584. # TODO: Copied from utils.py `export` until `_optimize_graph`.
  1585. if training == torch.onnx.TrainingMode.TRAINING:
  1586. model.train()
  1587. elif training == torch.onnx.TrainingMode.EVAL:
  1588. model.eval()
  1589. with torch.no_grad():
  1590. inputs_for_export = _prepare_input_for_export(input_args, {})
  1591. args = utils._decide_input_format(model, inputs_for_export)
  1592. model = utils._pre_trace_quant_model(model, args)
  1593. graph, params, torch_out, module = utils._create_jit_graph(model, args)
  1594. params_dict = utils._get_named_param_dict(graph, params)
  1595. utils._apply_friendly_debug_names(graph, params_dict)
  1596. graph_info = GraphInfo(
  1597. graph,
  1598. input_args,
  1599. params_dict,
  1600. _experimental.ExportOptions(
  1601. do_constant_folding=do_constant_folding,
  1602. training=training,
  1603. opset_version=opset_version,
  1604. keep_initializers_as_inputs=keep_initializers_as_inputs,
  1605. verbose=verbose,
  1606. ),
  1607. )
  1608. graph_info.find_mismatch(options)
  1609. graph_info.pretty_print_mismatch()
  1610. graph_info.pretty_print_tree()
  1611. return graph_info