# _cond.py
from dataclasses import dataclass
from functools import partial

import torch
from torch.multiprocessing.reductions import StorageWeakRef
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    make_fx,
    track_tensor_tree,
    unwrap_proxy,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten


@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond = PyOperator("cond")
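
# A minimal usage sketch (illustrative only; `true_fn`, `false_fn`, and `x`
# are hypothetical names, not part of this module). `cond` takes a boolean
# predicate, two callables whose outputs must have matching tensor metadata,
# and a list/tuple of tensor operands:
#
#   def true_fn(x):
#       return x.sin()
#
#   def false_fn(x):
#       return x.cos()
#
#   x = torch.randn(4)
#   out = cond(x.sum() > 0, true_fn, false_fn, [x])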


def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(
        operands, (list, tuple)
    ), "Cond operands must be a list or tuple of tensors"
    assert all(
        isinstance(o, torch.Tensor) for o in operands
    ), "Cond operands must be a list of tensors"

    with disable_proxy_modes_tracing():
        true_graph = make_fx(true_fn)(*operands)
        false_graph = make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == "output":
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == "output":
            false_outs.extend(node.args)

    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)

    for i in range(len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]
        assert true_out.meta["tensor_meta"] == false_out.meta["tensor_meta"]

    # There are probably better ways: create_arg has some self-incrementing
    # name magic to it, but since we explicitly need the name for
    # register_module, it wasn't clear how to reuse that. This roughly
    # simulates it.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)

    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="conditional"
    )

    # At this point, whether an output came from the true or the false branch
    # is guaranteed to be indistinguishable, so since this run is just for
    # tracing purposes, either branch would do.
    # TODO: it shouldn't matter which branch we pick, but changing this to
    # true_fn results in a FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
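

# For reference, tracing `cond` via `trace_cond` yields a graph of roughly
# this shape (illustrative, not an exact FX print-out): the branches are
# attached to the tracer root as `true_graph_0` / `false_graph_0` submodules,
# and the cond itself becomes a single call_function node named "conditional":
#
#   conditional = cond(pred, true_graph_0, false_graph_0, [arg0])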


@cond.py_impl(DispatchKey.CUDA)
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)
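
# Note: at the dense (CPU/CUDA) key the predicate is evaluated eagerly, so
# exactly one branch runs. This differs from tracing above, where both
# branches are traced into subgraphs regardless of the predicate.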


@cond.py_impl(DispatchKey.AutogradCUDA)
@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all(
        not f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
    )

    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)
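
# Note on the guard above: ExcludeDispatchKeyGuard removes AutogradCPU from
# dispatch while it is alive, so the recursive `cond` call re-dispatches to a
# lower key (e.g. the dense CPU impl) rather than looping back into this
# handler. (Presumably AutogradCUDA would need the same treatment for CUDA
# operands.)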


@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
    return res
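
# The mode is popped temporarily so that tensor ops issued during tracing
# (running a branch, creating proxies) dispatch below ProxyTorchDispatchMode
# instead of re-entering this handler.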


@cond.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
    true_outs = true_fn(*operands)
    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise RuntimeError(
                "Unmatched tensor metadata from cond() branches.\n"
                f"true branch: {true_meta}, false branch: {false_meta}"
            )
    return true_outs
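
# Illustrative-only example of the check above (hypothetical callables):
# branches whose outputs disagree on metadata, e.g.
#
#   cond(pred, lambda x: x.float(), lambda x: x.double(), [x])
#
# fail under FakeTensorMode with "Unmatched tensor metadata from cond()
# branches.", since dtype is part of the extracted metadata.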


# We cannot directly call fallthrough here due to issue #89037.
@cond.py_impl(DispatchKey.PythonDispatcher)
def cond_python_dispatcher(*args):
    _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
    return cond(*args)


def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph contains a mutable op applied to an input. This is
    a bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True

    return False
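
# Illustrative-only example (hypothetical callable): a branch that mutates its
# input in place, e.g.
#
#   def bad_branch(x):
#       x.add_(1)  # traces to the mutable op aten.add_.Tensor
#       return x
#
# makes _has_potential_branch_input_mutation return True, since the
# placeholder for `x` feeds a call_function node whose schema is mutable.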


def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph has an output aliasing a branch input. This is
    a bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_storages = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage()))

    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
    for out in outs:
        if (
            isinstance(out, torch.Tensor)
            and StorageWeakRef(out._typed_storage()) in input_storages
        ):
            return True

    return False
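
# Illustrative-only example (hypothetical callable): a branch whose output is
# a view of its input, e.g.
#
#   def bad_branch(x):
#       return x.view(-1)  # output shares storage with the input
#
# makes _has_potential_branch_input_alias return True, since the output's
# storage is one of the placeholder storages.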


@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
      1. We don't allow any input mutation inside the branches.
      2. Our check for the above condition is not exhaustive.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = "mutations_and_views" if reapply_views else "mutations"
    # At this point we will see functionalized tensors, so we need to unwrap them first.
    unwrapped_inputs = _unwrap_all_tensors_from_functional(
        inputs, reapply_views=reapply_views
    )
    unwrapped_pred = _unwrap_all_tensors_from_functional(
        pred, reapply_views=reapply_views
    )

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            for branch in [functional_true_fn, functional_false_fn]:

                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)

                fake_inputs = pytree.tree_map_only(
                    torch.Tensor, convert, unwrapped_inputs
                )
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException(
                        "One of torch.cond's branches might be modifying the input!"
                    )

            for branch in [true_fn, false_fn]:

                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)

                fake_inputs = pytree.tree_map_only(
                    torch.Tensor, convert, unwrapped_inputs
                )
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException(
                        "One of torch.cond's branches might be aliasing the input!"
                    )

        cond_return = cond(
            unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs
        )
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())


# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)