- # Module for defining "primitive" operations executable by nvFuser. This
- # list exists to decouple the main set of primitives from the ones that
- # provide a lowering of the op to nvFuser's Python interface. For the most
- # part, torch.ops.nvprims is a subset of the primitives in torch.ops.prims,
- # but additional primitives can be added in the future for the corresponding
- # higher-level torch/aten functions.
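- #
- # After registration (see ``register_nvprims()`` below), each nvprim mirrors
- # its prims counterpart, e.g. ``torch.ops.nvprims.add`` next to
- # ``torch.ops.prims.add``, with extra attributes such as ``impl_nvfuser`` and
- # ``is_recomputable`` attached below.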
- from typing import Any, Dict, Optional, Tuple
- import torch
- import torch._prims_common as utils
- from torch._prims_common import (
- DimsSequenceType,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- getnvFuserDtype,
- make_contiguous_strides_for,
- NumberType,
- ShapeType,
- TensorLikeType,
- )
- from torch._prims_common.wrappers import (
- _maybe_convert_to_dtype,
- backwards_not_supported,
- elementwise_type_promotion_wrapper,
- )
- nvprim_namespace = "nvprims"
- nvprim = torch.library.Library(nvprim_namespace, "DEF")
- nvprim_impl = torch.library.Library(
- nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
- )
- nvprim_implicit_impl = torch.library.Library(
- nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
- )
- nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
- nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
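- # The Library handles above are used by the ``register_*`` functions below:
- # the "DEF" library declares operator schemas via ``nvprim.define(...)`` and
- # the "IMPL" libraries attach functions for a dispatch key via ``impl(...)``.
- # A minimal sketch of the pattern (hypothetical op name "my_op", not part of
- # this module):
- #
- #     nvprim.define("my_op(Tensor self) -> Tensor")
- #     nvprim_impl.impl("my_op", lambda self: self.clone())
- #     nvprim_meta_impl.impl("my_op", lambda self: torch.empty_like(self))
- #
- # after which ``torch.ops.nvprims.my_op`` becomes callable.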
- nvprim_names = [
- "abs",
- "acos",
- "asin",
- "atan",
- "atanh",
- "cos",
- "cosh",
- "clone",
- "bitwise_not",
- "ceil",
- "erf",
- "erfc",
- "exp",
- "expm1",
- "floor",
- "imag",
- "isfinite",
- "lgamma",
- "log",
- "log1p",
- "log2",
- "log10",
- "real",
- "reciprocal",
- "neg",
- "round",
- "rsqrt",
- "sign",
- "sin",
- "sinh",
- "sqrt",
- "tan",
- "tanh",
- "transpose",
- "trunc",
- "add",
- "atan2",
- "bitwise_and",
- "bitwise_or",
- "bitwise_xor",
- "div",
- "eq",
- "fmod",
- "ge",
- "gt",
- "le",
- "lt",
- "mul",
- "ne",
- "pow",
- "remainder",
- "sub",
- "squeeze",
- "view_of",
- "broadcast_in_dim",
- "where",
- "convert_element_type",
- "sum",
- "var",
- "amax",
- "amin",
- ]
- _nvfuser_impls: Dict[str, Any] = {}
- _nvfuser_unary_ops = {
- "abs",
- "acos",
- "asin",
- "atan",
- "atanh",
- "cos",
- "cosh",
- "bitwise_not",
- "ceil",
- "erf",
- "erfc",
- "exp",
- "expm1",
- "floor",
- "imag",
- "isfinite",
- "lgamma",
- "log",
- "log1p",
- "log2",
- "log10",
- "reciprocal",
- "neg",
- "real",
- "round",
- "rsqrt",
- "sign",
- "sin",
- "sinh",
- "sqrt",
- "tan",
- "tanh",
- "trunc",
- }
- def _assert_nvfuser_op_exists(fname: str):
- try:
- from nvfuser._C import FusionDefinition as fd # type: ignore[import]
- assert getattr(fd.Operators, fname)
- except ImportError:
- # Not all PyTorch builds have nvfuser
- pass
- for fname in _nvfuser_unary_ops:
- exec(
- f"""
- # Ensure that the nvfuser implementation exists
- _assert_nvfuser_op_exists("{fname}")
- def _{fname}_nvfuser(fd, a):
- return fd.ops.{fname}(a) # type: ignore[attr-defined]
- _nvfuser_impls["{fname}"] = _{fname}_nvfuser
- """
- )
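- # For fname == "abs", for example, the exec() above expands to:
- #
- #     _assert_nvfuser_op_exists("abs")
- #     def _abs_nvfuser(fd, a):
- #         return fd.ops.abs(a)
- #     _nvfuser_impls["abs"] = _abs_nvfuser
- #
- # i.e. every unary prim maps one-to-one onto the nvFuser operator of the same
- # name on the FusionDefinition handle ``fd``.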
- _nvfuser_binary_ops = {
- "add",
- "atan2",
- "bitwise_and",
- "bitwise_or",
- "bitwise_xor",
- "div",
- "eq",
- "fmod",
- "ge",
- "gt",
- "le",
- "lt",
- "mul",
- "ne",
- "pow",
- "remainder",
- "sub",
- }
- for fname in _nvfuser_binary_ops:
- exec(
- f"""
- # Ensure that the nvfuser implementation exists
- _assert_nvfuser_op_exists("{fname}")
- def _{fname}_nvfuser(fd, a, b):
- return fd.ops.{fname}(a, b) # type: ignore[attr-defined]
- _nvfuser_impls["{fname}"] = _{fname}_nvfuser
- """
- )
- _nvfuser_ternary_ops = {
- "where",
- }
- for fname in _nvfuser_ternary_ops:
- exec(
- f"""
- # Ensure that the nvfuser implementation exists
- _assert_nvfuser_op_exists("{fname}")
- def _{fname}_nvfuser(fd, a, b, c):
- return fd.ops.{fname}(a, b, c) # type: ignore[attr-defined]
- _nvfuser_impls["{fname}"] = _{fname}_nvfuser
- """
- )
- def _native_batch_norm_nvfuser(
- fd, input, weight, bias, running_mean, running_var, training, momentum, eps
- ):
- """
- if weight is None:
- weight = fd.define_null_tensor()
- if bias is None:
- bias = fd.define_null_tensor()
- if running_mean is None:
- running_mean = fd.define_null_tensor()
- if running_var is None:
- running_var = fd.define_null_tensor()
- """
- return fd.ops.batch_norm(
- input,
- weight,
- bias,
- running_mean,
- running_var,
- momentum,
- eps,
- training,
- )
- def _broadcast_in_dim_nvfuser(
- fd: Any,
- a: TensorLikeType,
- shape: ShapeType,
- broadcast_dimensions: ShapeType,
- ):
- return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions) # type: ignore[attr-defined]
- def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
- nvfuser_dtype = getnvFuserDtype(dtype)
- return fd.ops.cast(a, nvfuser_dtype) # type: ignore[attr-defined]
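- # ``getnvFuserDtype`` maps a ``torch.dtype`` to the corresponding nvFuser
- # ``DataType`` member, which ``fd.ops.cast`` expects; for example,
- # ``torch.float32`` maps to the 32-bit float DataType.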
- def _transpose_nvfuser(fd, a, dims):
- return fd.ops.permute(a, dims) # type: ignore[attr-defined]
- def _squeeze_nvfuser(fd, a, a_shape, dimensions):
- for idx in sorted(dimensions, reverse=True):
- a = fd.ops.squeeze(a, a_shape, idx)
- a_shape = a_shape[:idx] + a_shape[idx + 1 :]
- return a
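- # Worked example of the loop above: with ``a_shape = (2, 1, 3, 1)`` and
- # ``dimensions = (1, 3)``, idx 3 is squeezed first (shape becomes (2, 1, 3))
- # and then idx 1 (shape becomes (2, 3)); iterating in reverse order keeps the
- # remaining indices valid as dimensions are removed one at a time.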
- def _view_of_nvfuser(fd, a):
- return fd.ops.set(a)
- def _view_nvfuser(
- fd,
- a,
- a_shape,
- new_shape,
- ):
- return fd.ops.view(a, a_shape, new_shape)
- def _sum_nvfuser(
- fd: Any,
- a: TensorLikeType,
- dims: DimsSequenceType,
- ):
- keep_dims = False
- from nvfuser._C import DataType # type: ignore[import]
- output_dtype = DataType.Null
- return fd.ops.sum(a, dims, keep_dims, output_dtype)
- def _var_nvfuser(
- fd: Any,
- a: TensorLikeType,
- dims: DimsSequenceType,
- *,
- correction: int,
- ):
- keep_dims = False
- return fd.ops.var(a, dims, correction, keep_dims)
- def _var_mean_nvfuser(
- fd: Any,
- a: TensorLikeType,
- dims: DimsSequenceType,
- unbiased: Optional[bool] = None,
- keepdim: bool = False,
- *,
- correction: int,
- ):
- # The unbiased arg should not be set when this function is called
- assert unbiased is None
- # The keepdim arg is ignored here because it would otherwise be converted
- # into an nvFuser symbolic scalar; keepdim is handled by the reference
- # implementation instead (see _var_mean_ref below).
- keepdim = False
- return fd.ops.var_mean(a, dims, correction, keepdim)
- def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
- return fd.ops.rand_like(a)
- def _amax_nvfuser(
- fd: Any,
- a: TensorLikeType,
- dims: DimsSequenceType,
- ):
- keep_dims = False
- return fd.ops.max(a, dims, keep_dims)
- def _amin_nvfuser(
- fd: Any,
- a: TensorLikeType,
- dims: DimsSequenceType,
- ):
- keep_dims = False
- return fd.ops.min(a, dims, keep_dims)
- def _clone_nvfuser(fd: Any, input: TensorLikeType, *, memory_format=None):
- return fd.ops.set(input)
- def _full_nvfuser(
- fd: Any,
- shape: ShapeType,
- fill_value: NumberType,
- *,
- dtype: Optional[torch.dtype] = None,
- layout: Optional[torch.layout] = None,
- device: Optional[torch.device] = None,
- pin_memory: bool = False,
- requires_grad: bool = False,
- ):
- assert device != torch.device("cpu")
- assert layout is None or layout is torch.strided
- assert pin_memory is False
- assert requires_grad is False
- dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
- nvfuser_dtype = getnvFuserDtype(dtype)
- return fd.ops.full(shape, fill_value, nvfuser_dtype)
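- # Illustrative dtype inference in ``_full_nvfuser``: with ``fill_value=1.0``
- # and ``dtype=None``, ``utils.type_to_dtype(float)`` selects the current
- # default floating point dtype (usually ``torch.float32``), which is then
- # translated to an nvFuser DataType before ``fd.ops.full`` is called.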
- _nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
- _nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
- _nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
- _nvfuser_impls["clone"] = _clone_nvfuser
- _nvfuser_impls["transpose"] = _transpose_nvfuser
- _nvfuser_impls["squeeze"] = _squeeze_nvfuser
- _nvfuser_impls["view_of"] = _view_of_nvfuser
- _nvfuser_impls["view"] = _view_nvfuser
- _nvfuser_impls["rand_like"] = _rand_like_nvfuser
- _nvfuser_impls["sum"] = _sum_nvfuser
- _nvfuser_impls["var"] = _var_nvfuser
- _nvfuser_impls["var_mean"] = _var_mean_nvfuser
- _nvfuser_impls["amax"] = _amax_nvfuser
- _nvfuser_impls["amin"] = _amin_nvfuser
- _nvfuser_impls["full"] = _full_nvfuser
- def register_full():
- name = "full"
- nvprim.define(
- "full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, "
- + "bool? pin_memory=None, bool? requires_grad=None) -> Tensor"
- )
- def _meta_impl(
- size,
- fill_value,
- *,
- out=None,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=False,
- requires_grad=False,
- ):
- strides = make_contiguous_strides_for(size)
- return torch._prims.TensorMeta(
- None,
- shape=size,
- strides=strides,
- dtype=dtype,
- device=device,
- )
- def _prim_impl(
- size,
- fill_value,
- *,
- out=None,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=False,
- requires_grad=False,
- ):
- return torch.full(
- size,
- fill_value,
- out=out,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- requires_grad=requires_grad,
- )
- nvprim_impl.impl(name, _prim_impl)
- nvprim_meta_impl.impl(name, _meta_impl)
- prim_packet = getattr(torch._ops.ops.nvprims, name)
- prim = prim_packet.default
- nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
- for p in (prim_packet, prim):
- p.__doc__ = "Create a tensor with given size and filled with value"
- p.impl_nvfuser = _nvfuser_impls["full"]
- p.is_recomputable = _nvfuser_is_recomputable["full"]
- p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
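- # Once registered (via ``register_nvprims()`` below), the op can be called
- # like any other; an illustrative example, assuming a CUDA build:
- #
- #     t = torch.ops.nvprims.full([2, 3], 1.5, dtype=torch.float32, device="cuda")
- #
- # Note that the nvFuser lowering asserts a non-CPU device (see
- # ``_full_nvfuser`` above), while the CompositeExplicitAutograd impl simply
- # falls back to ``torch.full``.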
- # functorch.compile.min_cut_rematerialization_partition accepts a list of
- # operators that can be recomputed in the backward pass. The table below
- # controls which nvprims end up in that list: operators mapped to False
- # (or missing from the table entirely) are never recomputed.
- _nvfuser_is_recomputable: Dict[str, bool] = {
- # Reductions are not allowed to be recomputed
- "amax": False,
- "amin": False,
- "sum": False,
- "var": False,
- "var_mean": False,
- # Normalizations are not allowed to be recomputed
- "native_batch_norm": False,
- # Random ops are not allowed to be recomputed
- "rand_like": False,
- # Everything else is allowed to be recomputed
- "abs": True,
- "acos": True,
- "add": True,
- "asin": True,
- "atan": True,
- "atan2": True,
- "atanh": True,
- "bitwise_and": True,
- "bitwise_not": True,
- "bitwise_or": True,
- "bitwise_xor": True,
- "broadcast_in_dim": True,
- "ceil": True,
- "clone": True,
- "convert_element_type": True,
- "cos": True,
- "cosh": True,
- "div": True,
- "eq": True,
- "erf": True,
- "erfc": True,
- "exp": True,
- "expm1": True,
- "floor": True,
- "fmod": True,
- "full": True,
- "ge": True,
- "gt": True,
- "imag": True,
- "isfinite": True,
- "le": True,
- "lgamma": True,
- "log": True,
- "log10": True,
- "log1p": True,
- "log2": True,
- "lt": True,
- "mul": True,
- "ne": True,
- "neg": True,
- "pow": True,
- "real": True,
- "reciprocal": True,
- "remainder": True,
- "round": True,
- "rsqrt": True,
- "sign": True,
- "sin": True,
- "sinh": True,
- "sqrt": True,
- "squeeze": True,
- "sub": True,
- "tan": True,
- "tanh": True,
- "transpose": True,
- "trunc": True,
- "view": True,
- "view_of": True,
- "where": True,
- }
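- # A sketch of how this table can be consumed (hypothetical helper, not part
- # of this module): collect the names mapped to True and hand the
- # corresponding nvprims to the partitioner's recomputable-op configuration.
- #
- #     def _recomputable_nvprim_names():
- #         return sorted(name for name, ok in _nvfuser_is_recomputable.items() if ok)
- #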
- def register_native_batch_norm():
- """This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
- name = "native_batch_norm"
- nvprim.define(
- f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
- + "bool training, float momentum, float eps)"
- + " -> (Tensor, Tensor, Tensor)"
- )
- def _prim_impl(
- input, weight, bias, running_mean, running_var, training, momentum, eps
- ):
- return torch.native_batch_norm(
- input, weight, bias, running_mean, running_var, training, momentum, eps
- )
- nvprim_impl.impl(name, _prim_impl)
- prim_packet = torch._ops.ops.nvprims.native_batch_norm
- prim = prim_packet.default
- def _native_batch_norm_ref(
- input: torch.Tensor,
- weight: Optional[torch.Tensor],
- bias: Optional[torch.Tensor],
- running_mean: Optional[torch.Tensor],
- running_var: Optional[torch.Tensor],
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- if torch._prims_common.is_complex_dtype(input.dtype):
- raise NotImplementedError("Complex tensors are not supported")
- # note: BN only promotes input to dtype of weight/bias, but keeps the same output dtype
- result_dtype = input.dtype
- computation_dtype, _ = elementwise_dtypes(
- input,
- weight,
- bias,
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
- )
- input_ = _maybe_convert_to_dtype(input, computation_dtype)
- output, mean, rstd = prim(
- input_, weight, bias, running_mean, running_var, training, momentum, eps
- )
- output_ = _maybe_convert_to_dtype(output, result_dtype) # type: ignore[arg-type]
- return (output_, mean, rstd) # type: ignore[return-value]
- def _native_batch_norm_autograd(
- input: torch.Tensor,
- weight: Optional[torch.Tensor],
- bias: Optional[torch.Tensor],
- running_mean: Optional[torch.Tensor],
- running_var: Optional[torch.Tensor],
- training: bool,
- momentum: float,
- eps: float,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- # This wrapper is needed to convert prims calls inside
- # _native_batch_norm_ref to nvprims calls
- from torch._prims.context import NvfuserPrimsMode
- with NvfuserPrimsMode():
- return backwards_not_supported(_native_batch_norm_ref)(
- input, weight, bias, running_mean, running_var, training, momentum, eps
- )
- nvprim_autograd_impl.impl(name, _native_batch_norm_autograd)
- for p in (prim_packet, prim):
- p.__doc__ = "Computes batch normalization."
- p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
- p.is_recomputable = _nvfuser_is_recomputable["native_batch_norm"]
- p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
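- # Illustrative call once registered (tensor shapes and device are assumptions):
- #
- #     out, mean, rstd = torch.ops.nvprims.native_batch_norm(
- #         x, weight, bias, running_mean, running_var, True, 0.1, 1e-5
- #     )
- #
- # Per the reference above, only the computation dtype is promoted; the output
- # keeps the input's dtype.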
- def register_rand_like():
- name = "rand_like"
- nvprim.define(
- "rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
- + "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
- )
- def _meta_rand_like(
- self,
- *,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=None,
- memory_format=None,
- ):
- strides = make_contiguous_strides_for(self.shape)
- return torch._prims.TensorMeta(
- self,
- shape=self.shape,
- strides=strides,
- dtype=dtype,
- device=device,
- )
- def _prim_impl(
- self,
- *,
- dtype=None,
- layout=None,
- device=None,
- pin_memory=None,
- memory_format=None,
- ):
- return torch.rand_like(
- self,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- memory_format=memory_format,
- )
- nvprim_impl.impl(name, _prim_impl)
- nvprim_meta_impl.impl(name, _meta_rand_like)
- prim_packet = getattr(torch._ops.ops.nvprims, name)
- prim = prim_packet.default
- nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
- for p in (prim_packet, prim):
- p.__doc__ = "Computes rand_like"
- p.impl_nvfuser = _nvfuser_impls["rand_like"]
- p.is_recomputable = _nvfuser_is_recomputable["rand_like"]
- p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
- def register_var_mean():
- """This function is used to register the var_mean function in torch.ops.nvprims module."""
- name = "var_mean.main"
- # This overload must be the default overload so that var_mean(Tensor, bool) dispatches correctly
- nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
- # This signature tries to combine several overloads of the torch.var_mean function into one overload.
- nvprim.define(
- f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
- + " -> (Tensor, Tensor)"
- )
- # This function is used for device="meta" Tensors.
- def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
- if torch._prims_common.is_complex_dtype(inp.dtype):
- output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
- else:
- output_dtype = inp.dtype
- var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
- mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
- if keepdim:
- output_shape = [
- inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
- ]
- broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
- var = torch._ops.ops.nvprims.broadcast_in_dim(
- var, output_shape, broadcast_dims
- )
- mean = torch._ops.ops.nvprims.broadcast_in_dim(
- mean, output_shape, broadcast_dims
- )
- return (var, mean)
- # This function is used under _AutoDispatchBelowAutograd context
- def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
- correction = torch._prims_common.set_correction(unbiased, correction)
- return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
- nvprim_impl.impl(name, _prim_impl)
- nvprim_meta_impl.impl(name, _meta_var_mean)
- prim_packet = torch._ops.ops.nvprims.var_mean
- prim = prim_packet.main
- def _unbiased_overload_impl(inp, unbiased):
- return prim(inp, dim=None, unbiased=unbiased)
- nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
- @elementwise_type_promotion_wrapper(
- type_promoting_args=("a",),
- type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
- )
- def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
- correction = torch._prims_common.set_correction(unbiased, correction)
- # reduces over all dimensions if dim=() is passed
- if dim == () or dim == []:
- dim = None
- dim = torch._prims_common.reduction_dims(a.shape, dim)
- # For complex tensors eager computes the variance as the sum of variances of
- # the real and imaginary parts
- # TODO: Creating a complex tensor from real and imaginary parts is not supported
- if torch._prims_common.is_complex_dtype(a.dtype):
- raise NotImplementedError("Complex tensors are not supported")
- var_mean = prim(a, dim, correction=correction)
- if keepdim:
- output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
- broadcast_dims = [i for i in range(a.ndim) if i not in dim]
- var, mean = var_mean
- var = torch._ops.ops.nvprims.broadcast_in_dim(
- var, output_shape, broadcast_dims
- )
- mean = torch._ops.ops.nvprims.broadcast_in_dim(
- mean, output_shape, broadcast_dims
- )
- var_mean = (var, mean)
- return var_mean
- def _var_mean_autograd(
- a, dim=None, unbiased=None, keepdim=False, *, correction=None
- ):
- # This wrapper is needed to convert prims calls inside
- # elementwise_type_promotion_wrapper to nvprims calls
- from torch._prims.context import NvfuserPrimsMode
- with NvfuserPrimsMode():
- return backwards_not_supported(_var_mean_ref)(
- a, dim, unbiased, keepdim, correction=correction
- )
- nvprim_autograd_impl.impl(name, _var_mean_autograd)
- for p in (prim_packet, prim):
- p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
- p.impl_nvfuser = _nvfuser_impls["var_mean"]
- p.is_recomputable = _nvfuser_is_recomputable["var_mean"]
- p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
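- # The two schemas registered above support both call patterns (illustrative):
- #
- #     torch.ops.nvprims.var_mean(x, True)                        # default overload
- #     torch.ops.nvprims.var_mean.main(x, dim=[0], correction=1)  # main overload
- #
- # The default overload forwards to ``var_mean.main`` through
- # ``_unbiased_overload_impl``.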
- def _nvprims_view_impl_aten(a, original_shape, new_shape):
- return a.reshape(new_shape)
- def register_view():
- """This function is used to register the view function in torch.ops.view module."""
- # View is implemented as a decomposition into prims.split_dim,
- # prims.collapse_dim, and prims.reshape, but we would like to intercept
- # non-decomposed view for now
- name = "view"
- nvprim.define("view(Tensor inp, SymInt[] original_shape, SymInt[] shape) -> Tensor")
- nvprim.define("view.shape(Tensor inp, SymInt[] shape) -> Tensor")
- # This function is used under _AutoDispatchBelowAutograd context
- def _prim_impl(a, original_shape, new_shape):
- return a.reshape(new_shape)
- nvprim_impl.impl(name, _prim_impl)
- prim_packet = torch._ops.ops.nvprims.view
- prim = prim_packet.default
- def _view_no_original_shape_overload_impl(a, shape):
- if list(a.shape) == list(shape):
- return torch.ops.nvprims.view_of(a)
- return torch.ops.nvprims.view.default(a, a.shape, shape)
- nvprim_implicit_impl.impl("view.shape", _view_no_original_shape_overload_impl)
- nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
- for p in (prim_packet, prim):
- p.__doc__ = "Creates a tensor with the specified shape containing a copy of the data in a."
- p.impl_nvfuser = _nvfuser_impls["view"]
- p.is_recomputable = _nvfuser_is_recomputable["view"]
- p.return_type = torch._prims_common.RETURN_TYPE.VIEW # type: ignore[attr-defined]
- p.impl_aten = _nvprims_view_impl_aten
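- # Dispatch sketch for the two view overloads (illustrative):
- #
- #     torch.ops.nvprims.view(a, a.shape, new_shape)  # explicit original shape
- #     torch.ops.nvprims.view.shape(a, new_shape)     # original shape resolved here;
- #                                                    # degenerates to view_of when
- #                                                    # the shape is unchanged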
- def register_nvprims():
- """Registers all nvFuser primitives in the torch.ops.nvprims module."""
- register_var_mean()
- register_view()
- register_native_batch_norm()
- register_rand_like()
- register_full()
- for name in nvprim_names:
- main_prim = getattr(torch._ops.ops.prims, name)
- nvprim.define(main_prim.schema)
- nvprim_impl.impl(name, main_prim.prim_impl)
- nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)
- prim_packet = getattr(torch._ops.ops.nvprims, name)
- prim = prim_packet.default
- nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
- for p in (prim_packet, prim):
- p.__doc__ = main_prim.__doc__
- p.impl_nvfuser = _nvfuser_impls[name]
- p.is_recomputable = _nvfuser_is_recomputable.get(name, False)
- p.return_type = main_prim.return_type # type: ignore[attr-defined]
- p.impl_aten = main_prim.impl_aten
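- # Typical usage (illustrative): calling ``register_nvprims()`` populates the
- # ``torch.ops.nvprims`` namespace, after which, for example:
- #
- #     register_nvprims()
- #     add = torch.ops.nvprims.add.default
- #     assert add.impl_nvfuser is _nvfuser_impls["add"]
- #     assert add.is_recomputable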