# onnxrt.py
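# TorchDynamo backend that exports the captured FX graph to ONNX and runs it
# with ONNX Runtime, binding input/output tensors directly via IO binding.
#
# Minimal usage sketch (assuming `register_backend` below exposes this function
# under the name "onnxrt"):
#
#   import torch
#   model = torch.nn.Linear(8, 4)
#   compiled = torch.compile(model, backend="onnxrt")
#   out = compiled(torch.randn(2, 8))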

import importlib
import os
import tempfile

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

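# Mapping from torch dtypes to the numpy dtypes expected by ONNX Runtime's
# IO binding API; left as None when numpy is not installed (the backend then
# fails with "requires numpy" below).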
try:
    import numpy as np

    _np_dtype = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int8: np.int8,
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.longlong,
        torch.bool: np.bool_,
    }
except ImportError:
    _np_dtype = None

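# Choose an execution provider for the input device type; the ONNXRT_PROVIDER
# environment variable, if set, overrides the default.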
def default_provider(device_type):
    if "ONNXRT_PROVIDER" in os.environ:
        return os.environ["ONNXRT_PROVIDER"]
    return {
        "cpu": "CPUExecutionProvider",
        "cuda": "CUDAExecutionProvider",
        # "TensorrtExecutionProvider" is another option
    }[device_type]

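# Cheap check for whether onnxruntime is importable at all.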
def has_onnxruntime():
    try:
        importlib.import_module("onnxruntime")
        return True
    except ImportError:
        return False

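# The backend entry point: export `gm` to an ONNX file, create an ONNX Runtime
# InferenceSession for it, and return a callable that TorchDynamo will invoke
# in place of the original graph.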
@register_backend
@fake_tensor_unsupported
def onnxrt(gm, example_inputs, *, filename=None, provider=None):
    if filename is None:
        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp:
            return onnxrt(gm, example_inputs, filename=tmp.name)
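    # Run the graph once on the example inputs to record each output's shape,
    # dtype, layout, device, and requires_grad; ONNX inputs/outputs are named
    # positionally (i0, i1, ... / o0, o1, ...).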
    import onnxruntime  # type: ignore[import]

    assert _np_dtype, "requires numpy"

    device_type = device_from_inputs(example_inputs).type
    example_outputs = gm(*example_inputs)

    output_spec = [
        (o.shape, o.dtype, o.layout, o.device, o.requires_grad) for o in example_outputs
    ]
    input_names = [f"i{i}" for i in range(len(example_inputs))]
    output_names = [f"o{x}" for x in range(len(example_outputs))]
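    # Script the graph module and export it to ONNX at `filename`.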
    torch.onnx.export(
        torch.jit.script(gm),
        example_inputs,
        filename,
        input_names=input_names,
        output_names=output_names,
    )
    del example_inputs, example_outputs
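    # Pick an execution provider (unless one was passed in) and build the session.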
    if provider is None:
        provider = default_provider(device_type)
    assert provider in onnxruntime.get_available_providers()
    session = onnxruntime.InferenceSession(filename, providers=[provider])
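    # Per-call wrapper: binds torch tensor memory directly to ONNX Runtime
    # inputs and outputs, then runs the exported model.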
    def _call(*initial_args):
        binding = session.io_binding()
        args = [a.contiguous() for a in initial_args]
        for name, value in zip(input_names, args):
            dev = value.device
            binding.bind_input(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
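        # Allocate output tensors matching the metadata recorded at export time.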
        outputs = [
            torch.empty(
                shape,
                dtype=dtype,
                layout=layout,
                device=device,
                requires_grad=requires_grad,
            )
            for shape, dtype, layout, device, requires_grad in output_spec
        ]
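        # Bind each preallocated tensor as an output buffer so ONNX Runtime
        # writes results into it.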
        for name, value in zip(output_names, outputs):
            dev = value.device
            binding.bind_output(
                name,
                dev.type,
                dev.index or 0,
                _np_dtype[value.dtype],
                value.size(),
                value.data_ptr(),
            )
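        # Execute the exported model with the bound inputs and outputs.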
        session.run_with_iobinding(binding)
        if device_type == "cpu":
            binding.copy_outputs_to_cpu()
        return outputs

    return _call