context.py

import copy
from typing import List

import torch


class FxToOnnxContext:
    """Context manager to make PyTorch friendly to the FX-to-ONNX exporter.

    This class is meant to collect all "patches" required by the FX-to-ONNX
    exporter. If PyTorch needs to be patched, please use this class to
    manage the patch.

    This context overrides several torch functions to support symbolic
    export of large-scale models.

    torch.load:
        This function is patched to record the files in which PyTorch stores
        model parameters and buffers. The downstream FX-to-ONNX exporter can
        create initializers from these files.
    torch._utils._rebuild_tensor:
        This function is patched to avoid creating real tensors during
        model loading; FakeTensors are created instead. Real tensors
        cannot fit into a single machine's memory at the targeted
        model scale.
    torch.fx._symbolic_trace._wrapped_methods_to_patch:
        This list is extended with (torch.Tensor, "__getitem__") so that
        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.

    Search for FxToOnnxContext in test_fx_to_onnx_with_onnxruntime.py for
    example usage.
    """
    def __init__(self):
        # List of file paths processed by torch.load.
        self.paths: List[str] = []

        def torch_load_wrapper(f, *args, **kwargs):
            # Record path.
            self.paths.append(f)
            # Then, call the original torch.load.
            return self.torch_load(f, *args, **kwargs)
        def torch__util__rebuild_tensor_wrapper(storage, storage_offset, size, stride):
            from torch._subclasses.fake_tensor import FakeTensorMode
            from torch.utils._mode_utils import no_dispatch
            from torch.utils._python_dispatch import _get_current_dispatch_mode

            def _rebuild_real_tensor(storage, storage_offset, size, stride):
                t = torch.tensor(
                    [], dtype=storage.dtype, device=storage._untyped_storage.device
                )
                return t.set_(storage._untyped_storage, storage_offset, size, stride)

            mode = _get_current_dispatch_mode()
            if isinstance(mode, FakeTensorMode):
                # Create a real tensor and then convert it to FakeTensor.
                # We cannot directly create a FakeTensor because tensor.set_(...)
                # is not supported in the FakeTensorMode dispatcher.
                with no_dispatch():
                    t = _rebuild_real_tensor(storage, storage_offset, size, stride)
                return mode.from_tensor(t)

            return _rebuild_real_tensor(storage, storage_offset, size, stride)
        # Original versions of the patched torch functions.
        self.torch_load = torch.load
        self.torch__util_rebuild_tensor = torch._utils._rebuild_tensor

        # Wrapped or modified versions of the torch functions above.
        self.torch_load_wrapper = torch_load_wrapper
        self.torch__util_rebuild_tensor_wrapper = torch__util__rebuild_tensor_wrapper
    def __enter__(self):
        torch.load = self.torch_load_wrapper
        torch._utils._rebuild_tensor = self.torch__util_rebuild_tensor_wrapper

        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        desired_wrapped_methods = copy.deepcopy(
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
            # Adding `__getitem__` to the patching list makes tensor indexing
            # traceable via torch.fx.symbolic_trace; otherwise `tensor[x, :, y]`
            # cannot be traced. This happens because `__getitem__` is neither
            # in the torch domain nor an aten operator, so the patching (or a
            # similar Proxy-generating mechanism) doesn't happen automatically.
            # Note that torch.fx._symbolic_trace gates the same patch behind
            # the FX_PATCH_GETITEM environment variable.
            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original torch functions and the original patching list.
        torch.load = self.torch_load
        torch._utils._rebuild_tensor = self.torch__util_rebuild_tensor
        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
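

# A minimal usage sketch, following the example usage the docstring points to
# (test_fx_to_onnx_with_onnxruntime.py). The module, checkpoint path, and
# variable names below are hypothetical placeholders for illustration, not
# code taken from that test.
if __name__ == "__main__":
    import os
    import tempfile

    import torch.fx
    from torch._subclasses.fake_tensor import FakeTensorMode

    class _IndexModel(torch.nn.Module):  # hypothetical stand-in model
        def __init__(self):
            super().__init__()
            # Plain tensor attribute; it stays a concrete tensor (not a Proxy)
            # while torch.fx.symbolic_trace runs.
            self.table = torch.arange(16.0).reshape(4, 4)

        def forward(self, idx):
            # Concrete-tensor[Proxy] indexing dispatches to
            # torch.Tensor.__getitem__, which is traceable here only because
            # FxToOnnxContext adds (torch.Tensor, "__getitem__") to
            # torch.fx._symbolic_trace._wrapped_methods_to_patch.
            return self.table[idx, :]

    # __enter__ returns None, so keep a reference instead of `with ... as ctx`.
    context = FxToOnnxContext()
    with context:
        graph_module = torch.fx.symbolic_trace(_IndexModel())
    print(graph_module.graph)

    # FakeTensor path: save a real checkpoint, then reload it inside the
    # context under FakeTensorMode so the patched torch._utils._rebuild_tensor
    # yields FakeTensors instead of allocating real storage.
    checkpoint = os.path.join(tempfile.mkdtemp(), "tiny.pt")
    torch.save({"w": torch.randn(4, 4)}, checkpoint)
    with context, FakeTensorMode():
        state_dict = torch.load(checkpoint)
    print(type(state_dict["w"]))  # FakeTensor
    print(context.paths)  # checkpoint path recorded by the patched torch.load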