from functools import partial

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
    unwrap_proxy,
)
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten

from ._cond import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    UnsupportedAliasMutationException,
)
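
# `map` is a higher-order operator: map(f, xs, *args) applies `f` to each slice
# xs[i] along dim 0 of `xs` and stacks the per-slice results (see map_cpu below
# for the eager semantics).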
map = PyOperator("map")
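

# trace_map records a call to `map` into the proxy tracer's graph: it traces the
# body `f` once on a single slice xs[0] with make_fx, registers the resulting
# GraphModule on the tracer root, and emits one call_function node whose output
# has the stacked shape [xs.shape[0], *per_slice_shape].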
def trace_map(proxy_mode, func_overload, f, xs, *args):
    if not isinstance(xs, torch.Tensor):
        raise ValueError("map() must loop over a tensor")
    if len(xs.shape) == 0 or xs.shape[0] == 0:
        raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors")

    if not all(isinstance(o, torch.Tensor) for o in args):
        raise ValueError("map() operands must be a list of tensors or modules")

    with disable_proxy_modes_tracing():
        body_graph = make_fx(f)(xs[0], *args)

    next_name = None
    i = 0
    while not next_name:
        candidate = f"body_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    proxy_mode.tracer.root.register_module(next_name, body_graph)
    node_args = (body_graph, xs, *args)
    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), node_args)
    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="map")
    outs = [body_graph(x, *args) for x in xs]

    # Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly
    # because stack([...]) takes a fixed size list which will specialize dynamic shape here.
    # Meanwhile we want to preserve the looped over dimension as symbolic shape, such that:
    # ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args)
    out = outs[0].new_empty([xs.shape[0], *outs[0].shape])
    out.copy_(torch.stack(outs))
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
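

# Dense CPU/CUDA implementation: eagerly apply `f` to each slice along dim 0
# and stack the results.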
@map.py_impl(DispatchKey.CUDA)
@map.py_impl(DispatchKey.CPU)
def map_cpu(f, xs, *args):
    mode = _get_current_dispatch_mode()
    assert (mode is None), "Mode should never be enabled for CPU/CUDA key"
    return torch.stack([f(x, *args) for x in xs])
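

# Autograd is not supported yet (see the TODO below): operands must not require
# grad; we exclude the AutogradCPU key and re-dispatch to the dense implementations.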
@map.py_impl(DispatchKey.AutogradCUDA)
@map.py_impl(DispatchKey.AutogradCPU)
def map_autograd(f, xs, *args):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([f, xs, args])
    assert all([not f.requires_grad for f in flat_operands
                if isinstance(f, torch.Tensor)])

    _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return map(f, xs, *args)
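

# Tracing entry point: with a ProxyTorchDispatchMode active, temporarily pop the
# mode and record the call into its tracer via trace_map.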
@map.py_impl(ProxyTorchDispatchMode)
def map_proxy_torch_dispatch_mode(f, xs, *args):
    mode = _get_current_dispatch_mode()
    assert (mode is not None), "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        res = trace_map(mode, map, f, xs, *args)
    return res
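

# Under FakeTensorMode we only need output metadata: run the body on each fake
# slice and allocate an empty result with the stacked shape.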
@map.py_impl(FakeTensorMode)
def map_fake_tensor_mode(f, xs, *args):
    outs = [f(x, *args) for x in xs]
    return outs[0].new_empty([xs.shape[0], *outs[0].shape])


# We cannot directly call fallthrough here due to issue #89037.
@map.py_impl(DispatchKey.PythonDispatcher)
def map_python_dispatcher(*args):
    _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
    return map(*args)
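

# Functorch functionalization: unwrap the functional tensor inputs, check the
# body for input mutation/aliasing on fake inputs, then re-dispatch `map` on the
# unwrapped tensors and wrap the result back into functional tensors.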
@map.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, xs, *args):
    """
    Functionalization implementation for torch.map. Currently:
      1. We don't allow any input mutation inside the map function
      2. Our check for the above condition is not exhaustive
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = 'mutations_and_views' if reapply_views else 'mutations'
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_xs = _unwrap_all_tensors_from_functional(xs, reapply_views=reapply_views)
    unwrapped_args = _unwrap_all_tensors_from_functional(args, reapply_views=reapply_views)

    functional_map_fn = functionalize(f, remove=mode)

    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            # Returns fake inputs for a single map function call
            def get_fake_inputs(unwrapped_xs, unwrapped_args):
                fake_xs = ft_mode.fake_tensor_converter(ft_mode, unwrapped_xs)
                fake_args = pytree.tree_map_only(
                    torch.Tensor,
                    lambda x: ft_mode.fake_tensor_converter(ft_mode, x),
                    unwrapped_args,
                )
                return (fake_xs[0],) + fake_args

            fake_inputs = get_fake_inputs(unwrapped_xs, unwrapped_args)
            if _has_potential_branch_input_mutation(functional_map_fn, fake_inputs):
                raise UnsupportedAliasMutationException(
                    "torch.map is mutating the input!"
                )

            if _has_potential_branch_input_alias(functional_map_fn, fake_inputs):
                raise UnsupportedAliasMutationException(
                    "torch.map is aliasing the input!"
                )

        map_return = map(functional_map_fn, unwrapped_xs, *unwrapped_args)
        return _wrap_all_tensors_to_functional(map_return, level=interpreter.level())


# TODO(voz) Make this automatic for keys, this is very ugly atm
map.fallthrough(DispatchKey.PythonTLSSnapshot)
map.fallthrough(DispatchKey.ADInplaceOrView)
map.fallthrough(DispatchKey.BackendSelect)
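

# Example usage (illustrative sketch, not part of the original module): under
# make_fx, calls to `map` are traced through map_proxy_torch_dispatch_mode; in
# plain eager mode they reach map_cpu / the CUDA impl.
#
#   def body(x, y):
#       return x + y
#
#   xs = torch.ones(3, 4)
#   y = torch.ones(4)
#   ys = map(body, xs, y)                                  # eager: stack of body(xs[i], y)
#   gm = make_fx(lambda xs, y: map(body, xs, y))(xs, y)    # traced graph module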