onnx_proto_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. """Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
  2. import glob
  3. import io
  4. import os
  5. import shutil
  6. import zipfile
  7. from typing import Any, List, Mapping, Set, Tuple, Union
  8. import torch
  9. import torch.jit._trace
  10. import torch.serialization
  11. from torch.onnx import _constants, _exporter_states, errors
  12. from torch.onnx._internal import _beartype, jit_utils, registration
  13. @_beartype.beartype
  14. def export_as_test_case(
  15. model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
  16. ) -> str:
  17. """Export an ONNX model as a self contained ONNX test case.
  18. The test case contains the model and the inputs/outputs data. The directory structure
  19. is as follows:
  20. dir
  21. ├── test_<name>
  22. │ ├── model.onnx
  23. │ └── test_data_set_0
  24. │ ├── input_0.pb
  25. │ ├── input_1.pb
  26. │ ├── output_0.pb
  27. │ └── output_1.pb
  28. Args:
  29. model_bytes: The ONNX model in bytes.
  30. inputs_data: The inputs data, nested data structure of numpy.ndarray.
  31. outputs_data: The outputs data, nested data structure of numpy.ndarray.
  32. Returns:
  33. The path to the test case directory.
  34. """
  35. try:
  36. import onnx
  37. except ImportError:
  38. raise ImportError(
  39. "Export test case to ONNX format failed: Please install ONNX."
  40. )
  41. test_case_dir = os.path.join(dir, "test_" + name)
  42. os.makedirs(test_case_dir, exist_ok=True)
  43. _export_file(
  44. model_bytes,
  45. os.path.join(test_case_dir, "model.onnx"),
  46. _exporter_states.ExportTypes.PROTOBUF_FILE,
  47. {},
  48. )
  49. data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
  50. if os.path.exists(data_set_dir):
  51. shutil.rmtree(data_set_dir)
  52. os.makedirs(data_set_dir)
  53. proto = onnx.load_from_string(model_bytes)
  54. for i, (input_proto, input) in enumerate(zip(proto.graph.input, inputs_data)):
  55. export_data(input, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
  56. for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)):
  57. export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb"))
  58. return test_case_dir
  59. @_beartype.beartype
  60. def load_test_case(dir: str) -> Tuple[bytes, Any, Any]:
  61. """Load a self contained ONNX test case from a directory.
  62. The test case must contain the model and the inputs/outputs data. The directory structure
  63. should be as follows:
  64. dir
  65. ├── test_<name>
  66. │ ├── model.onnx
  67. │ └── test_data_set_0
  68. │ ├── input_0.pb
  69. │ ├── input_1.pb
  70. │ ├── output_0.pb
  71. │ └── output_1.pb
  72. Args:
  73. dir: The directory containing the test case.
  74. Returns:
  75. model_bytes: The ONNX model in bytes.
  76. inputs: the inputs data, mapping from input name to numpy.ndarray.
  77. outputs: the outputs data, mapping from output name to numpy.ndarray.
  78. """
  79. try:
  80. import onnx
  81. from onnx import numpy_helper
  82. except ImportError:
  83. raise ImportError(
  84. "Load test case from ONNX format failed: Please install ONNX."
  85. )
  86. with open(os.path.join(dir, "model.onnx"), "rb") as f:
  87. model_bytes = f.read()
  88. test_data_dir = os.path.join(dir, "test_data_set_0")
  89. inputs = {}
  90. input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
  91. for input_file in input_files:
  92. tensor = onnx.load_tensor(input_file)
  93. inputs[tensor.name] = numpy_helper.to_array(tensor)
  94. outputs = {}
  95. output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
  96. for output_file in output_files:
  97. tensor = onnx.load_tensor(output_file)
  98. outputs[tensor.name] = numpy_helper.to_array(tensor)
  99. return model_bytes, inputs, outputs
  100. @_beartype.beartype
  101. def export_data(data, value_info_proto, f: str) -> None:
  102. """Export data to ONNX protobuf format.
  103. Args:
  104. data: The data to export, nested data structure of numpy.ndarray.
  105. value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
  106. determines how the data is stored.
  107. f: The file to write the data to.
  108. """
  109. try:
  110. from onnx import numpy_helper
  111. except ImportError:
  112. raise ImportError("Export data to ONNX format failed: Please install ONNX.")
  113. with open(f, "wb") as opened_file:
  114. if value_info_proto.type.HasField("map_type"):
  115. opened_file.write(
  116. numpy_helper.from_dict(data, value_info_proto.name).SerializeToString()
  117. )
  118. elif value_info_proto.type.HasField("sequence_type"):
  119. opened_file.write(
  120. numpy_helper.from_list(data, value_info_proto.name).SerializeToString()
  121. )
  122. elif value_info_proto.type.HasField("optional_type"):
  123. opened_file.write(
  124. numpy_helper.from_optional(
  125. data, value_info_proto.name
  126. ).SerializeToString()
  127. )
  128. else:
  129. assert value_info_proto.type.HasField("tensor_type")
  130. opened_file.write(
  131. numpy_helper.from_array(data, value_info_proto.name).SerializeToString()
  132. )
  133. @_beartype.beartype
  134. def _export_file(
  135. model_bytes: bytes,
  136. f: Union[io.BytesIO, str],
  137. export_type: str,
  138. export_map: Mapping[str, bytes],
  139. ) -> None:
  140. """export/write model bytes into directory/protobuf/zip"""
  141. # TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f,
  142. # but beartype raises beartype.roar.BeartypeDecorHintNonpepException,
  143. # as os.PathLike[str] uncheckable at runtime
  144. if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE:
  145. assert len(export_map) == 0
  146. with torch.serialization._open_file_like(f, "wb") as opened_file:
  147. opened_file.write(model_bytes)
  148. elif export_type in {
  149. _exporter_states.ExportTypes.ZIP_ARCHIVE,
  150. _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE,
  151. }:
  152. compression = (
  153. zipfile.ZIP_DEFLATED
  154. if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE
  155. else zipfile.ZIP_STORED
  156. )
  157. with zipfile.ZipFile(f, "w", compression=compression) as z:
  158. z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes)
  159. for k, v in export_map.items():
  160. z.writestr(k, v)
  161. elif export_type == _exporter_states.ExportTypes.DIRECTORY:
  162. if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type]
  163. raise ValueError(
  164. f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}"
  165. )
  166. if not os.path.exists(f): # type: ignore[arg-type]
  167. os.makedirs(f) # type: ignore[arg-type]
  168. model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type]
  169. with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file:
  170. opened_file.write(model_bytes)
  171. for k, v in export_map.items():
  172. weight_proto_file = os.path.join(f, k) # type: ignore[arg-type]
  173. with torch.serialization._open_file_like(
  174. weight_proto_file, "wb"
  175. ) as opened_file:
  176. opened_file.write(v)
  177. else:
  178. raise ValueError("Unknown export type")
  179. @_beartype.beartype
  180. def _add_onnxscript_fn(
  181. model_bytes: bytes,
  182. custom_opsets: Mapping[str, int],
  183. ) -> bytes:
  184. """Insert model-included custom onnx-script function into ModelProto"""
  185. # TODO(titaiwang): remove this when onnx becomes dependency
  186. try:
  187. import onnx
  188. except ImportError as e:
  189. raise errors.OnnxExporterError("Module onnx is not installed!") from e
  190. # For > 2GB model, onnx.load_fromstring would fail. However, because
  191. # in _export_onnx, the tensors should be saved separately if the proto
  192. # size > 2GB, and if it for some reason did not, the model would fail on
  193. # serialization anyway in terms of the protobuf limitation. So we don't
  194. # need to worry about > 2GB model getting here.
  195. model_proto = onnx.load_from_string(model_bytes)
  196. # Iterate graph nodes to insert only the included custom
  197. # function_proto into model_proto
  198. # TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction
  199. # calling other ONNXFunction scenario, neither does it here
  200. onnx_function_list = list() # type: ignore[var-annotated]
  201. included_node_func = set() # type: Set[str]
  202. # onnx_function_list and included_node_func are expanded in-place
  203. _find_onnxscript_op(
  204. model_proto.graph, included_node_func, custom_opsets, onnx_function_list
  205. )
  206. if onnx_function_list:
  207. model_proto.functions.extend(onnx_function_list)
  208. model_bytes = model_proto.SerializeToString()
  209. return model_bytes
  210. @_beartype.beartype
  211. def _find_onnxscript_op(
  212. graph_proto,
  213. included_node_func: Set[str],
  214. custom_opsets: Mapping[str, int],
  215. onnx_function_list: List,
  216. ):
  217. """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op."""
  218. for node in graph_proto.node:
  219. node_kind = node.domain + "::" + node.op_type
  220. # Recursive needed for control flow nodes: IF/Loop which has inner graph_proto
  221. for attr in node.attribute:
  222. if attr.g is not None:
  223. _find_onnxscript_op(
  224. attr.g, included_node_func, custom_opsets, onnx_function_list
  225. )
  226. # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry
  227. onnx_function_group = registration.registry.get_function_group(node_kind)
  228. # Ruled out corner cases: onnx/prim in registry
  229. if (
  230. node.domain
  231. and not jit_utils.is_aten(node.domain)
  232. and not jit_utils.is_prim(node.domain)
  233. and not jit_utils.is_onnx(node.domain)
  234. and onnx_function_group is not None
  235. and node_kind not in included_node_func
  236. ):
  237. specified_version = custom_opsets.get(node.domain, 1)
  238. onnx_fn = onnx_function_group.get(specified_version)
  239. if onnx_fn is not None:
  240. # TODO(titaiwang): to_function_proto is onnx-script API and can be annotated
  241. # after onnx-script is dependency
  242. onnx_function_list.append(onnx_fn.to_function_proto()) # type: ignore[attr-defined]
  243. included_node_func.add(node_kind)
  244. continue
  245. raise errors.UnsupportedOperatorError(
  246. node_kind,
  247. specified_version,
  248. onnx_function_group.get_min_supported()
  249. if onnx_function_group
  250. else None,
  251. )
  252. return onnx_function_list, included_node_func