wrappers.py

import torch

from torch._prims_common import (
    Number,
    NumberType,
    TensorLike,
    TensorLikeType,
    ShapeType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
)
import torch._prims_common as utils
from torch.utils._pytree import tree_flatten, tree_unflatten

from typing import Callable, Sequence, Tuple, NamedTuple, Optional, overload

import inspect
from functools import wraps
import warnings
from itertools import chain


@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass


@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass


# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return a.to(dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Passthrough None because some functions wrapped with type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(
        "Received type {0} that is neither a tensor nor a number!".format(type(a))
    )
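

# Illustrative sketch (not part of the original module): how _maybe_convert_to_dtype
# treats each kind of input it accepts. The function name and the example values
# below are assumptions chosen purely for demonstration.
def _example_maybe_convert_to_dtype_usage():
    t = _maybe_convert_to_dtype(torch.tensor([1, 2]), torch.float32)  # float tensor
    n = _maybe_convert_to_dtype(3, torch.float32)  # Python float 3.0
    s = _maybe_convert_to_dtype((1, 2), torch.float64)  # tuple of Python floats
    none = _maybe_convert_to_dtype(None, torch.float32)  # passed through as None
    return t, n, s, none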


def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = "Found unknown type {0} when trying to convert scalars!".format(type(a))
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = "Scalar {0} of type {1} cannot be safely cast to type {2}!".format(
            a, type(a), typ
        )
        raise ValueError(msg)

    return typ(a)


def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True

        return False

    return typ is annotation


class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = tree_flatten(type_promoting_args)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn
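

# Illustrative sketch (not part of the original module): a hypothetical reference
# decorated with elementwise_type_promotion_wrapper. The name _example_add_ref and
# its body are assumptions for demonstration only. The wrapper converts `a` and `b`
# to the computation dtype before calling the body and converts the tensor result
# to the promoted result dtype on the way out.
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def _example_add_ref(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    # Both inputs arrive here already cast to the shared computation dtype.
    return torch.add(a, b)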


# TODO: handle tuples of tensors
def _maybe_resize_out(out: TensorLikeType, shape: ShapeType):
    # If the shapes are correct there's nothing to do
    if utils.same_shape(out.shape, shape):
        return out
    else:
        if out.numel() != 0:
            msg = (
                f"An output with one or more elements was resized since it had shape {str(out.shape)} "
                f"which does not match the required output shape {str(shape)}. "
                "This behavior is deprecated, and in a future PyTorch release outputs will not "
                "be resized unless they have zero elements. "
                "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
            )
            warnings.warn(msg)
        return out.resize_(shape)


def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if copy_from.device != copy_to.device:
        msg = "Attempting to copy from device {0} to device {1}, but cross-device copies are not allowed!".format(
            copy_from.device, copy_to.device
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        utils.check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        utils.check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)


def out_wrapper(*out_names: str, exact_dtype: bool = False):
    is_tensor = len(out_names) == 0
    assert is_tensor or len(out_names) >= 2

    def _out_wrapper(fn: Callable) -> Callable:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args, out=None, **kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr

            result = fn(*args, **kwargs)
            assert (
                (isinstance(result, TensorLike) and is_tensor)
                or (isinstance(result, Tuple) and len(result) == len(out_names))  # type: ignore[arg-type]
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as the
                # out tensor, while the result will be a normal meta
                # tensor; this is perfectly harmless.
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(out, result.shape)
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    utils.check(
                        len(out) == len(result),
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",
                        TypeError,
                    )
                    for r, o in zip(result, out):
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape)
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result

            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # The wrapped function must not declare a return annotation that conflicts with out_type
        assert sig.return_annotation in (sig.empty, out_type)
        params = chain(sig.parameters.values(), (out_param,))
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = fn.__annotations__
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type
        return _fn

    return _out_wrapper
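

# Illustrative sketch (not part of the original module): attaching an out= kwarg
# to a hypothetical single-tensor reference. The name _example_abs_ref is an
# assumption for demonstration; with no out_names, the wrapper treats the result
# as a single tensor and copies it into `out` when one is provided.
@out_wrapper()
def _example_abs_ref(a: TensorLikeType) -> TensorLikeType:
    return torch.abs(a)
# A call like _example_abs_ref(torch.tensor([-1.0, 2.0]), out=torch.empty(2))
# resizes `out` if needed, copies the result into it in-place, and returns it.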


def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        g = torch._C._AutoDispatchBelowAutograd()
        try:
            old = torch._C._dispatch_tls_is_dispatch_key_excluded(
                torch._C.DispatchKey.ADInplaceOrView
            )
            return prim(*args, **kwargs)
        finally:
            del g

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it, and the
            # custom autograd function will incorrectly turn the result
            # into a view, which will fail test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it! We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl
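

# Illustrative sketch (not part of the original module): wrapping a toy callable
# so that any attempt to backpropagate through it raises. `_example_prim` stands
# in for a real prim and is an assumption for demonstration only.
def _example_prim(a):
    return a * 2


_example_prim_no_backward = backwards_not_supported(_example_prim)
# Calling _example_prim_no_backward on a tensor with requires_grad=True runs the
# forward pass normally, but calling .backward() on the result raises
# "backwards not supported on prim".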


# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
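

# Illustrative sketch (not part of the original module): a hypothetical unary
# reference made scalar-friendly. The name _example_neg_ref is an assumption for
# demonstration; when passed a Python number, the wrapper builds a scalar tensor
# of the matching dtype, runs the op, and returns a Python number via .item().
@elementwise_unary_scalar_wrapper
def _example_neg_ref(a):
    return torch.neg(a)
# _example_neg_ref(2) returns -2 as a Python int;
# _example_neg_ref(torch.tensor([2])) returns a tensor as usual.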