import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Sequence
from warnings import warn

import torch

import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides

from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport
from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)

    # Support conversions
    for s in torch._refs._conversions.__all__:
        tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
        r[tensor_attr] = torch._refs._conversions.__dict__.get(s)

    return r
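

# A minimal usage sketch for the mapping above, kept as a comment so that importing
# this module has no side effects. It shows that both torch.add and the Tensor
# method resolve to torch._refs.add; the local name `mapping` is hypothetical.
#
#   mapping = torch_to_refs_map()
#   assert mapping[torch.add] is torch._refs.add
#   assert mapping[torch.Tensor.add] is torch._refs.add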


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.ops.prims.* functions to
    use nvFuser's prims in torch.ops.nvprims.*

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode():
    ...     torch.ops.prims.add(x, y)  # calls torch.ops.nvprims.add(x, y)

    By default, this context manager falls back on torch.ops.prims.* if the
    corresponding nvprim does not exist.

    It's possible to skip certain prims by passing their names to the skip_ops
    argument. skip_ops is expected to be a sequence of strings, e.g.,
    ["prims.add.default"]. To check the expected name of a prim, use
    `torch.overrides.resolve_name`.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode(skip_ops=("prims.add.default",)):
    ...     torch.ops.prims.add.default(x, y)  # does not call torch.ops.nvprims.add.default(x, y)
    """

    def __init__(self, *, skip_ops=()):
        self.skip_ops = skip_ops

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}

        # If the function is in the skip list, then we don't want to
        # remap it to the nvprims.
        if torch.overrides.resolve_name(orig_func) in self.skip_ops:
            return orig_func(*args, **kwargs)

        if isinstance(orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
            namespace = str(orig_func).split(".")[0]
            name = str(orig_func).split(".")[1]
            if namespace == "prims":
                nvfunc = getattr(torch.ops.nvprims, name, None)
                if nvfunc is not None:
                    return nvfunc(*args, **kwargs)
        return orig_func(*args, **kwargs)
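

# An illustrative sketch of the mode above, kept as a comment so that importing this
# module stays side-effect free. The tensors `a` and `b` are hypothetical; nvFuser
# prims are CUDA-only, hence the device choice.
#
#   a = torch.randn(3, device="cuda")
#   b = torch.randn(3, device="cuda")
#   with NvfuserPrimsMode():
#       out = torch.ops.prims.add(a, b)  # re-dispatched to torch.ops.nvprims.add
#   with NvfuserPrimsMode(skip_ops=("prims.add.default",)):
#       out = torch.ops.prims.add.default(a, b)  # skipped, stays on prims.add.default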


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back on torch.* if the ref does not
    exist; set strict=True to raise an error instead.
    Even when a ref exists, it is sometimes desirable to fall back on the
    torch.* implementation; this behavior can be customized by passing a
    function to should_fallback_fn.
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception,
        # unless we are in prims_mode, in which case we want to use nvprims
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality,
        # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)

        if func is not None:
            # If the ref exists, query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)
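

# A hedged usage sketch for TorchRefsMode, kept as a comment. The mode is commonly
# combined with a tracer so that torch.* calls are recorded as their _refs
# decompositions; make_fx from torch.fx.experimental.proxy_tensor is assumed to be
# available, and `fn` is a hypothetical function.
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#
#   def fn(x, y):
#       return torch.add(x, y) * torch.sigmoid(y)
#
#   with TorchRefsMode():
#       gm = make_fx(fn)(torch.randn(3), torch.randn(3))
#   # gm.graph now contains prims-level nodes instead of aten.add / aten.sigmoid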


def _is_node_supported_nvfuser(node):
    return (
        node.op == "call_function"
        and getattr(node.target, "impl_nvfuser", None) is not None
    )


def _is_func_unsupported_nvfuser(
    torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
):
    """
    This function traces `func` under `torch_function_mode` and checks if
    any of the traced nodes are not supported by nvFuser. If so, we should
    fall back to the original function.

    The `skip_ops` argument is expected to be a list of strings of function names
    that would match with `torch.overrides.resolve_name`.

    Args:
        torch_function_mode: The torch_function_mode context manager.
        orig_func: The original function; its name is used to check whether
            it should be skipped.
        func: The function to be traced.
        args: The args to be passed to the function.
        kwargs: The kwargs to be passed to the function.

    Keyword args:
        skip_ops: A list of ops to skip when checking if the function is
            supported.
    """
    # One supported case is easy to check: if the resolved name of the original
    # function is in the skip list, skip it.
    if torch.overrides.resolve_name(orig_func) in skip_ops:
        return True

    with torch_function_mode:
        try:
            gm = get_isolated_graphmodule(func, args, kwargs)
        except Exception as e:
            warn(
                "get_isolated_graphmodule failed on decomposition: "
                + func.__name__
                + " with error message: "
                + str(e)
            )
            # Return unsupported when tracing fails.
            return True

    supported_ops = NvfuserPrimOperatorSupport()
    call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
    any_unsupported = any(
        not supported_ops.is_node_supported(None, node) for node in call_function_nodes
    )
    return any_unsupported


class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
    def __init__(self, *, skip_ops=()):
        aten_ops_to_skip = (
            "aten._log_softmax.default",
            "aten._log_softmax_backward_data.default",
            "aten.expand.default",
        )
        self.skip_ops = tuple(skip_ops) + aten_ops_to_skip
        super().__init__(
            strict=False,
            should_fallback_fn=functools.partial(
                _is_func_unsupported_nvfuser,
                skip_ops=tuple(skip_ops) + aten_ops_to_skip,
            ),
            prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
        )

    # TODO: remove this once the version from _decomp/decompositions.py is working
    # with this context manager
    # This is a workaround for AOT Autograd graphs
    def _cudnn_batch_norm(
        self,
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    ):
        a, b, c = torch.ops.nvprims.native_batch_norm(
            input,
            weight,
            bias,
            running_mean,
            running_var,
            training,
            exponential_average_factor,
            epsilon,
        )
        if training:
            return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
        return (
            a,
            weight.new_zeros((0,)),
            weight.new_zeros((0,)),
            input.new_zeros((0,), dtype=torch.uint8),
        )

    # This is a workaround for AOT Autograd graphs
    def _cudnn_batch_norm_backward(
        self,
        input,
        grad_output,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        epsilon,
        reserveSpace,
    ):
        func = torch._decomp.decomposition_table[
            torch.ops.aten.native_batch_norm_backward.default
        ]
        return func(
            grad_output,
            input,
            weight,
            running_mean,
            running_var,
            save_mean,
            save_var,
            True,
            epsilon,
            [True, True, True],
        )

    def _is_var_mean(self, func):
        return "torch.var_mean" == torch.overrides.resolve_name(func) or (
            (isinstance(func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)))
            and "aten.var_mean" in str(func)
        )

    def _is_view_or_reshape(self, func):
        allowed_ops = {
            "torch.Tensor.view",
            "torch.Tensor.reshape",
            "torch.view_copy",
            "torch.reshape",
            "aten.view.default",
            "aten._unsafe_view.default",
            "aten.view_copy.default",
        } - set(self.skip_ops)
        return torch.overrides.resolve_name(func) in allowed_ops

    def _is_native_batch_norm(self, func):
        return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.native_batch_norm.default
            or func == torch.ops.aten.native_batch_norm
        )

    def _is_rand_like(self, func):
        result = "torch.rand_like" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.rand_like
            or func == torch.ops.aten.rand_like.default
        )
        return result

    def _is_full(self, func):
        result = "torch.full" == torch.overrides.resolve_name(func) or (
            func
            in [
                torch.ops.aten.full,
                torch.ops.aten.full.names,
            ]
        )
        return result

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Dict = None,
    ):
        if kwargs is None:
            kwargs = {}
        # First we intercept calls for nvfuser-specific prims bypassing generic torch._refs
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)

        if (
            orig_func == torch.ops.aten.cudnn_batch_norm.default
            or orig_func == torch.ops.aten.cudnn_batch_norm
        ):
            with self:
                return self._cudnn_batch_norm(*args, **kwargs)

        # A workaround for AOT Autograd graphs
        # See https://github.com/pytorch/pytorch/pull/86115#issue-1394883782
        if (
            orig_func == torch.ops.aten.cudnn_batch_norm_backward.default
            or orig_func == torch.ops.aten.cudnn_batch_norm_backward
        ):
            with self:
                return self._cudnn_batch_norm_backward(*args, **kwargs)

        if self._is_view_or_reshape(orig_func):
            a, *shape = args
            shape = torch._prims_common.extract_shape_from_varargs(
                shape, validate=False
            )  # type: ignore[assignment]
            if len(kwargs) > 0:
                warn("view has ignored kwargs!")
            return torch.ops.nvprims.view(a, shape)

        if orig_func == torch.ops.aten._reshape_alias.default:
            a, shape, stride = args
            if len(kwargs) > 0:
                warn("view has ignored kwargs!")
            return torch.ops.nvprims.view(a, shape)

        if self._is_native_batch_norm(orig_func):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)

        if self._is_rand_like(orig_func):
            if len(kwargs) > 0:
                warn("rand_like has ignored kwargs!")
            return torch.ops.nvprims.rand_like(*args)

        if self._is_full(orig_func):
            return torch.ops.nvprims.full(*args, **kwargs)

        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
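

# A hedged end-to-end sketch, kept as a comment. This mode is typically paired
# with make_fx to produce a graph of nvprims-compatible operations that can then
# be run through the nvFuser executor; `execute` from torch._prims.executor and the
# "nvfuser" executor name are assumptions based on the surrounding package layout.
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#   from torch._prims.executor import execute
#
#   def fn(x):
#       return torch.sigmoid(x) + x
#
#   x = torch.randn(4, device="cuda")
#   with TorchRefsNvfuserCapabilityMode():
#       gm = make_fx(fn)(x)
#   out = execute(gm, x, executor="nvfuser")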