import torch

from torch._prims_common import (
    Number,
    NumberType,
    TensorLike,
    TensorLikeType,
    ShapeType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten

from typing import Callable, NamedTuple, Optional, Sequence, Tuple, overload
import inspect
from functools import wraps
import warnings
from itertools import chain


@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass


@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass


# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return a.to(dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Pass through None because some functions wrapped with the type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor nor a number!".format(type(a))
    )
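

# Usage sketch (illustrative only, not part of the module): _maybe_convert_to_dtype
# handles tensors, Python numbers, sequences, and None uniformly.
#
#   t = torch.tensor([1, 2, 3])                        # int64 tensor
#   _maybe_convert_to_dtype(t, torch.float64)          # float64 tensor
#   _maybe_convert_to_dtype(2, torch.float64)          # 2.0 (a Python float)
#   _maybe_convert_to_dtype((1, 2), torch.float64)     # (1.0, 2.0)
#   _maybe_convert_to_dtype(None, torch.float64)       # None (passthrough)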


def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = "Found unknown type {0} when trying to convert scalars!".format(type(a))
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
            a, type(a), typ
        )
        raise ValueError(msg)

    return typ(a)
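

# Usage sketch (illustrative only): conversion follows the weak scalar type
# ordering bool -> int -> float -> complex and refuses narrowing casts.
#
#   _maybe_convert_to_type(2, float)    # 2.0
#   _maybe_convert_to_type(True, int)   # 1
#   _maybe_convert_to_type(2.5, int)    # raises ValueError (unsafe narrowing)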


def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True
        return False

    return typ is annotation
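

# Usage sketch (illustrative only): the check recurses through subscripted
# typing constructs such as Optional and Union via their __args__.
#
#   _annotation_has_type(typ=TensorLikeType, annotation=Optional[TensorLikeType])  # True
#   _annotation_has_type(typ=TensorLikeType, annotation=int)                       # False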


class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a Sequence of strings specifying the names of all
    arguments that participate in type promotion (and should be type promoted). If an
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
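

# Usage sketch (illustrative only; my_add is a hypothetical reference): a binary
# elementwise reference opts into type promotion by naming the participating
# arguments. Inputs arrive converted to the compute dtype, and the result is
# converted to the result dtype.
#
#   @elementwise_type_promotion_wrapper(
#       type_promoting_args=("a", "b"),
#       type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#   )
#   def my_add(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
#       return torch.add(a, b)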


# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    # If the shapes are correct there's nothing to do
    if utils.same_shape(out.shape, shape):
        return out
    else:
        if out.numel() != 0:
            msg = (
                f"An output with one or more elements was resized since it had shape {str(out.shape)} "
                f"which does not match the required output shape {str(shape)}. "
                "This behavior is deprecated, and in a future PyTorch release outputs will not "
                "be resized unless they have zero elements. "
                "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
            )
            warnings.warn(msg)
        return out.resize_(shape)


def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)
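

# Usage sketch (illustrative only): together these helpers implement the
# standard out= contract: resize a zero-element (or mis-shaped) out tensor,
# then safely copy the computed result into it.
#
#   result = torch.add(a, b)
#   _maybe_resize_out(out, result.shape)
#   _safe_copy_out(copy_from=result, copy_to=out)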


def out_wrapper(*out_names: str, exact_dtype: bool = False):
    is_tensor = len(out_names) == 0
    assert is_tensor or len(out_names) >= 2

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            result = fn(*args, **kwargs)
            assert (
                (isinstance(result, TensorLike) and is_tensor)
                or (
                    isinstance(result, Tuple)  # type: ignore[arg-type]
                    and len(result) == len(out_names)
                )
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as the out
                # tensor, but the result will be a normal meta tensor;
                # this is perfectly harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    utils.check(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tensor or a named tuple of tensors
        assert sig.return_annotation in (sig.empty, out_type)
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )
        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type
        return _fn

    return _out_wrapper
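

# Usage sketch (illustrative only; my_sin is a hypothetical reference): the
# wrapper adds a keyword-only out= parameter and rewrites the signature.
#
#   @out_wrapper()
#   def my_sin(a: TensorLikeType) -> TensorLikeType:
#       return torch.sin(a)
#
#   my_sin(t)             # allocates and returns a new tensor
#   my_sin(t, out=buf)    # resizes buf if needed and copies the result into it
#
# Multi-output ops pass the output names, e.g. @out_wrapper("values", "indices").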


def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        g = torch._C._AutoDispatchBelowAutograd()
        try:
            return prim(*args, **kwargs)
        finally:
            del g

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it, and the custom
            # autograd function will incorrectly turn the result into
            # a view, which will fail the test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it! We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl
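

# Usage sketch (illustrative only; my_prim_impl is a hypothetical primitive):
# the wrapper gives a primitive a forward that redispatches below autograd and
# a backward that always raises.
#
#   my_prim = backwards_not_supported(my_prim_impl)
#   y = my_prim(x)   # forward runs even if x.requires_grad
#   y.backward()     # RuntimeError: backwards not supported on prim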


# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
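

# Usage sketch (illustrative only; my_abs_ref is a hypothetical tensor-only
# reference): a wrapped unary reference accepts a Python scalar, computes
# through a 0-dim tensor, and returns a Python scalar.
#
#   abs_ref = elementwise_unary_scalar_wrapper(my_abs_ref)
#   abs_ref(-2.0)                 # 2.0 (a Python float)
#   abs_ref(torch.tensor(-2.0))   # tensor(2.0); the tensor path is unchanged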