import copy
from typing import List

import torch


class FxToOnnxContext:
    """Context manager that makes PyTorch friendly to the FX-to-ONNX exporter.

    This class is meant to collect all "patches" required by the FX-to-ONNX
    exporter. If PyTorch needs to be patched, please use this class to
    manage the patch.

    This context overrides several torch functions to support symbolic
    export of large-scale models.

    torch.load:
        This function is patched to record the files in which PyTorch stores
        model parameters and buffers. The downstream FX-to-ONNX exporter can
        create initializers from these files.
    torch._utils._rebuild_tensor:
        This function is patched to avoid creating real tensors during
        model loading; FakeTensors are created instead, because real tensors
        at the targeted model scale cannot fit into a single machine's
        memory.
    torch.fx._symbolic_trace._wrapped_methods_to_patch:
        This list is extended with (torch.Tensor, "__getitem__") so that
        weight[x, :, y] becomes exportable with torch.fx.symbolic_trace.

    Search for FxToOnnxContext in test_fx_to_onnx_with_onnxruntime.py for
    example usage.
    """

    def __init__(self):
        # List of file paths processed by torch.load.
        self.paths: List[str] = []

        def torch_load_wrapper(f, *args, **kwargs):
            # Record the path.
            self.paths.append(f)
            # Then, call the original torch.load.
            return self.torch_load(f, *args, **kwargs)

        def torch__util__rebuild_tensor_wrapper(storage, storage_offset, size, stride):
            from torch._subclasses.fake_tensor import FakeTensorMode
            from torch.utils._mode_utils import no_dispatch
            from torch.utils._python_dispatch import _get_current_dispatch_mode

            def _rebuild_real_tensor(storage, storage_offset, size, stride):
                t = torch.tensor(
                    [], dtype=storage.dtype, device=storage._untyped_storage.device
                )
                return t.set_(storage._untyped_storage, storage_offset, size, stride)

            mode = _get_current_dispatch_mode()
            if isinstance(mode, FakeTensorMode):
                # Create a real tensor and then convert it to FakeTensor.
                # We cannot directly create a FakeTensor because tensor.set_(...)
                # is not supported by the FakeTensorMode dispatcher.
                with no_dispatch():
                    t = _rebuild_real_tensor(storage, storage_offset, size, stride)
                return mode.from_tensor(t)

            return _rebuild_real_tensor(storage, storage_offset, size, stride)
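
        # Sketch of the wrapper's effect (a hypothetical example, not in the
        # original file): inside this context and under FakeTensorMode,
        # rebuilding from a typed storage yields a FakeTensor.
        #
        #     from torch._subclasses.fake_tensor import FakeTensorMode
        #
        #     typed_storage = torch.randn(6).storage()  # stand-in for a file-backed storage
        #     with FxToOnnxContext(), FakeTensorMode():
        #         t = torch._utils._rebuild_tensor(typed_storage, 0, (2, 3), (3, 1))
        #     # t is a FakeTensor of shape (2, 3); no new real data is allocated.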

        # Original versions of the patched torch functions.
        self.torch_load = torch.load
        self.torch__util_rebuild_tensor = torch._utils._rebuild_tensor

        # Wrapped or modified versions of the torch functions above.
        self.torch_load_wrapper = torch_load_wrapper
        self.torch__util_rebuild_tensor_wrapper = torch__util__rebuild_tensor_wrapper

    def __enter__(self):
        torch.load = self.torch_load_wrapper
        torch._utils._rebuild_tensor = self.torch__util_rebuild_tensor_wrapper

        self.torch_fx__symbolic_trace__wrapped_methods_to_patch = (
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )

        desired_wrapped_methods = copy.deepcopy(
            torch.fx._symbolic_trace._wrapped_methods_to_patch
        )
        if (torch.Tensor, "__getitem__") not in desired_wrapped_methods:
            # Adding `__getitem__` to the patching list makes tensor indexing
            # traceable via torch.fx.symbolic_trace. Otherwise,
            # `tensor[x, :, y]` cannot be traced, because `__getitem__` is
            # neither in the torch domain nor an aten operator, so the usual
            # Proxy-generating mechanism does not kick in automatically.
            # Note that torch.fx.symbolic_trace reads the FX_PATCH_GETITEM
            # environment variable to enable this same patching.
            # See the illustration after this method.
            desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
        torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
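
    # Illustration (a hypothetical sketch, not part of the original class):
    # with `__getitem__` in the patching list, indexing a concrete tensor
    # with traced proxies is recorded as a call_method node instead of
    # raising. `weight` and `f` below are made-up names.
    #
    #     weight = torch.randn(4, 3, 4)
    #
    #     def f(x, y):
    #         return weight[x, :, y]
    #
    #     with FxToOnnxContext():
    #         graph_module = torch.fx.symbolic_trace(f)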

    def __exit__(self, exc_type, exc_value, traceback):
        torch.load = self.torch_load
        torch._utils._rebuild_tensor = self.torch__util_rebuild_tensor
        torch.fx._symbolic_trace._wrapped_methods_to_patch = (
            self.torch_fx__symbolic_trace__wrapped_methods_to_patch
        )
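

# Usage sketch (a minimal, hypothetical example; the real usage lives in
# test_fx_to_onnx_with_onnxruntime.py, as the docstring notes). The helper
# name `_load_checkpoint_as_fake` and the checkpoint path are assumptions.
def _load_checkpoint_as_fake(checkpoint_path: str):
    from torch._subclasses.fake_tensor import FakeTensorMode

    context = FxToOnnxContext()
    # Note: __enter__ returns None, so keep a reference to the context
    # instead of writing `with FxToOnnxContext() as context`.
    with context, FakeTensorMode():
        # Parameters and buffers are rebuilt as FakeTensors, so even a
        # checkpoint larger than host memory can be "loaded".
        state_dict = torch.load(checkpoint_path)
    # context.paths lists the files torch.load touched; a downstream
    # exporter can stream the real weights from disk as initializers.
    return state_dict, context.paths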