from dataclasses import dataclass
from functools import partial

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
    functionalize,
)
from torch._ops import PyOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    make_fx,
    track_tensor_tree,
    unwrap_proxy,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
@dataclass
class UnsupportedAliasMutationException(RuntimeError):
    reason: str


"""
We're going to define a `cond` operation.
In order to do this, we need implementations for each of the dispatch keys.
"""
cond = PyOperator("cond")
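
# Illustrative usage sketch (not part of the operator definition): both
# branches must accept the same operands and return tensors with matching
# metadata, and `pred` is a boolean scalar (or a boolean Tensor), e.g.:
#
#     def true_fn(x):
#         return x.sin()
#
#     def false_fn(x):
#         return x.cos()
#
#     x = torch.randn(4)
#     out = cond(x.sum() > 0, true_fn, false_fn, [x])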
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
    assert isinstance(operands, (list, tuple)), "Cond operands must be a list or tuple of tensors"
    assert all(isinstance(o, torch.Tensor) for o in operands), "Cond operands must be a list of tensors"

    with disable_proxy_modes_tracing():
        true_graph = make_fx(true_fn)(*operands)
        false_graph = make_fx(false_fn)(*operands)

    true_outs = []
    false_outs = []
    for node in true_graph.graph.nodes:
        if node.op == 'output':
            true_outs.extend(node.args)

    for node in false_graph.graph.nodes:
        if node.op == 'output':
            false_outs.extend(node.args)

    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_outs)
    assert len(flat_true_outs) == len(flat_false_outs)

    for i in range(0, len(flat_true_outs)):
        true_out = flat_true_outs[i]
        false_out = flat_false_outs[i]
        assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
    # There are probably better ways; create_arg has some self-incrementing
    # name magic, but since we explicitly need the name for register_module,
    # this simulates that numbering by probing for the first free slot.
    next_name = None
    i = 0
    while not next_name:
        candidate = f"true_graph_{i}"
        if hasattr(proxy_mode.tracer.root, candidate):
            i += 1
        else:
            next_name = candidate

    true_name = next_name
    false_name = f"false_graph_{i}"
    assert not hasattr(proxy_mode.tracer.root, false_name)

    proxy_mode.tracer.root.register_module(true_name, true_graph)
    proxy_mode.tracer.root.register_module(false_name, false_graph)

    args = (pred, true_graph, false_graph, operands)

    proxy_args = pytree.tree_map(partial(unwrap_proxy, proxy_mode), args)

    out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {},
                                               name="conditional")

    # At this point, the outputs of the true and false branches are guaranteed
    # to be indistinguishable, so for tracing purposes it does not matter which
    # branch we actually run.
    # TODO: It shouldn't matter, but changing this to true_fn results in a
    # FakeTensorMode error:
    # `Current active mode <class 'torch._subclasses.fake_tensor.FakeTensorMode'> not registered`
    out = false_fn(*operands)

    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
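
# Tracing sketch: after e.g. `make_fx(lambda p, x: cond(p, true_fn, false_fn,
# [x]))(pred, x)`, the resulting root module is expected to own both branch
# submodules plus a single call to this operator, roughly:
#
#     true_graph_0 = <GraphModule traced from true_fn>
#     false_graph_0 = <GraphModule traced from false_fn>
#     conditional = cond(pred, true_graph_0, false_graph_0, [x])
#
# (Names follow the `true_graph_{i}` / `false_graph_{i}` scheme chosen above.)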
@cond.py_impl(DispatchKey.CUDA)
@cond.py_impl(DispatchKey.CPU)
def cond_dense(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is None, "Mode should never be enabled for CPU/CUDA key"
    if pred:
        return true_fn(*operands)
    else:
        return false_fn(*operands)
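
# Eager sketch: with no ambient modes, dispatch bottoms out here and `cond`
# behaves like a plain Python conditional, e.g.:
#
#     x = torch.randn(3)
#     out = cond(torch.tensor(True), lambda t: t.sin(), lambda t: t.cos(), [x])
#     # out is equal to x.sin()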
@cond.py_impl(DispatchKey.AutogradCUDA)
@cond.py_impl(DispatchKey.AutogradCPU)
def cond_autograd(pred, true_fn, false_fn, *operands):
    # TODO: support autograd
    flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])
    assert all(not f.requires_grad for f in flat_operands
               if isinstance(f, torch.Tensor))

    # Keep the guard alive for the redispatch so the Autograd key stays excluded.
    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
    return cond(pred, true_fn, false_fn, *operands)
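
# Since autograd is not yet supported, operands that require grad should trip
# the assert above, e.g. (sketch):
#
#     x = torch.randn(3, requires_grad=True)
#     cond(torch.tensor(True), lambda t: t.sin(), lambda t: t.cos(), [x])
#     # expected to raise AssertionError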
@cond.py_impl(ProxyTorchDispatchMode)
def inner(pred, true_fn, false_fn, operands):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
    return res
@cond.py_impl(FakeTensorMode)
def cond_fake_tensor_mode(pred, true_fn, false_fn, operands):
    true_outs = true_fn(*operands)
    flat_true_outs, _ = pytree.tree_flatten(true_outs)
    flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands))
    if len(flat_true_outs) != len(flat_false_outs):
        raise RuntimeError("Unmatched number of outputs from cond() branches.")

    for true_out, false_out in zip(flat_true_outs, flat_false_outs):
        true_meta = _extract_tensor_metadata(true_out)
        false_meta = _extract_tensor_metadata(false_out)
        if true_meta != false_meta:
            raise RuntimeError(
                "Unmatched tensor metadata from cond() branches.\n"
                f"true branch: {true_meta}, false branch: {false_meta}")
    return true_outs
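
# Sketch of the metadata check: under FakeTensorMode, branches that disagree
# on output shape/dtype are rejected, e.g.:
#
#     def true_fn(x):
#         return x.sin()   # same shape as x
#
#     def false_fn(x):
#         return x.sum()   # scalar: different shape
#
#     # cond(pred, true_fn, false_fn, [x]) under FakeTensorMode is expected to
#     # raise RuntimeError("Unmatched tensor metadata from cond() branches. ...")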
# We cannot directly call fallthrough here due to issue #89037.
@cond.py_impl(DispatchKey.PythonDispatcher)
def cond_python_dispatcher(*args):
    _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher))
    return cond(*args)
def _has_potential_branch_input_mutation(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph contains a mutable op applied to an input. This is
    a bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_nodes = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_nodes.add(node)
        if node.op == "call_function":
            target = node.target
            if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable:
                for arg in node.args:
                    if arg in input_nodes:
                        return True

    return False
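
# Sketch: a branch that mutates its input in place should be flagged, e.g.:
#
#     def mutating_branch(x):
#         x.add_(1)   # in-place op whose argument is an input placeholder
#         return x
#
#     # _has_potential_branch_input_mutation(mutating_branch, fake_inputs)
#     # is expected to return True, since add_'s schema is mutable and its
#     # argument is a graph input.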
def _has_potential_branch_input_alias(branch, fake_inputs):
    """
    Dispatch-trace the branch with fake inputs and check whether the
    resulting graph has an output that aliases a branch input. This is
    a bit restrictive, as the branch must be traceable.
    """
    try:
        gm = make_fx(branch)(*fake_inputs)
    except UnsupportedAliasMutationException:
        # this can happen when a nested cond is functionalized
        return True
    except Exception as e:
        raise e

    input_storages = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            input_storages.add(StorageWeakRef(node.meta['val']._typed_storage()))

    outs, _ = pytree.tree_flatten(gm(*fake_inputs))
    for out in outs:
        if isinstance(out, torch.Tensor) and StorageWeakRef(out._typed_storage()) in input_storages:
            return True

    return False
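
# Sketch: a branch whose output shares storage with an input should be
# flagged, e.g.:
#
#     def aliasing_branch(x):
#         return x.view(-1)   # view shares x's storage
#
#     # _has_potential_branch_input_alias(aliasing_branch, fake_inputs)
#     # is expected to return True, because the output's storage matches an
#     # input placeholder's storage.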
@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
    """
    Functionalization implementation for torch.cond. Currently:
      1. We don't allow any input mutation inside the branches.
      2. Our check for the above condition is not exhaustive.
    """
    reapply_views = interpreter.functionalize_add_back_views()
    mode = 'mutations_and_views' if reapply_views else 'mutations'
    # At this point we will see functionalized tensors, so we need to unwrap them first
    unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
    unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)

    functional_true_fn = functionalize(true_fn, remove=mode)
    functional_false_fn = functionalize(false_fn, remove=mode)

    with interpreter.lower():
        fake_tensor_mode = FakeTensorMode()
        with fake_tensor_mode as ft_mode:
            for branch in [functional_true_fn, functional_false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_mutation(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond's branches "
                                                            "might be modifying the input!")

            for branch in [true_fn, false_fn]:
                def convert(x):
                    return ft_mode.fake_tensor_converter(ft_mode, x)
                fake_inputs = pytree.tree_map_only(torch.Tensor, convert, unwrapped_inputs)
                if _has_potential_branch_input_alias(branch, fake_inputs):
                    raise UnsupportedAliasMutationException("One of torch.cond's branches "
                                                            "might be aliasing the input!")

        cond_return = cond(unwrapped_pred, functional_true_fn, functional_false_fn, unwrapped_inputs)
        return _wrap_all_tensors_to_functional(cond_return, level=interpreter.level())
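
# Sketch: under functionalization, a mutating branch surfaces as an
# UnsupportedAliasMutationException, e.g.:
#
#     def true_fn(x):
#         x.add_(1)
#         return x
#
#     def false_fn(x):
#         return x.cos()
#
#     x = torch.randn(4)
#     # functionalize(lambda p, x: cond(p, true_fn, false_fn, [x]))(pred, x)
#     # is expected to raise UnsupportedAliasMutationException.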
# TODO(voz): Make this automatic for keys, this is very ugly atm
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)