utils.py

import collections
import warnings
from functools import partial, wraps
from typing import Sequence

import numpy as np
import torch

from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes,
    all_types,
    all_types_and,
    all_types_and_complex,
    all_types_and_complex_and,
    all_types_and_half,
    complex_types,
    floating_and_complex_types,
    floating_and_complex_types_and,
    floating_types,
    floating_types_and,
    floating_types_and_half,
    integral_types,
    integral_types_and,
)
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict

COMPLETE_DTYPES_DISPATCH = (
    all_types,
    all_types_and_complex,
    all_types_and_half,
    floating_types,
    floating_and_complex_types,
    floating_types_and_half,
    integral_types,
    complex_types,
)

EXTENSIBLE_DTYPE_DISPATCH = (
    all_types_and_complex_and,
    floating_types_and,
    floating_and_complex_types_and,
    integral_types_and,
    all_types_and,
)

# Better way to acquire devices?
DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])


class _dynamic_dispatch_dtypes(_dispatch_dtypes):
    # Class to tag the dynamically generated types.
    pass


def get_supported_dtypes(op, sample_inputs_fn, device_type):
    # Returns the supported dtypes for the given operator and device_type pair.
    assert device_type in ["cpu", "cuda"]
    if not TEST_CUDA and device_type == "cuda":
        warnings.warn(
            "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
        )
        return _dynamic_dispatch_dtypes(())

    supported_dtypes = set()
    for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
        try:
            samples = sample_inputs_fn(op, device_type, dtype, False)
        except RuntimeError:
            # If `sample_inputs_fn` doesn't support sampling for a given
            # `dtype`, we assume that the `dtype` is not supported.
            # We raise a warning, so that the user knows this was the case
            # and can investigate if there was an issue with the `sample_inputs_fn`.
            warnings.warn(
                f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
            )
            continue

        # We assume the dtype is supported
        # only if all samples pass for the given dtype.
        supported = True
        for sample in samples:
            try:
                op(sample.input, *sample.args, **sample.kwargs)
            except RuntimeError:
                # dtype is not supported
                supported = False
                break

        if supported:
            supported_dtypes.add(dtype)

    return _dynamic_dispatch_dtypes(supported_dtypes)
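

# Illustrative usage sketch (not part of the original module). `_ToySample` and
# `_toy_sample_inputs` are stand-ins invented for this example; real OpInfo
# entries provide their own sample_inputs_func with the same calling convention.
_ToySample = collections.namedtuple("_ToySample", "input args kwargs")


def _toy_sample_inputs(op, device, dtype, requires_grad):
    # Yield a single small tensor sample for the requested device/dtype.
    yield _ToySample(
        input=torch.ones(3, device=device, dtype=dtype), args=(), kwargs={}
    )


def _example_get_supported_dtypes():
    # Returns the set of dtypes for which torch.sin ran successfully on the toy sample.
    return get_supported_dtypes(torch.sin, _toy_sample_inputs, "cpu")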


def dtypes_dispatch_hint(dtypes):
    # Returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH
    # and EXTENSIBLE_DTYPE_DISPATCH) and its string representation for the
    # passed `dtypes`.
    return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")

    # CUDA is not available, dtypes will be empty.
    if len(dtypes) == 0:
        return return_type((), str(tuple()))

    set_dtypes = set(dtypes)
    for dispatch in COMPLETE_DTYPES_DISPATCH:
        # Short circuit if we get an exact match.
        if set(dispatch()) == set_dtypes:
            return return_type(dispatch, dispatch.__name__ + "()")

    chosen_dispatch = None
    chosen_dispatch_score = 0.0
    for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
        dispatch_dtypes = set(dispatch())
        if not dispatch_dtypes.issubset(set_dtypes):
            continue

        score = len(dispatch_dtypes)
        if score > chosen_dispatch_score:
            chosen_dispatch_score = score
            chosen_dispatch = dispatch

    # If user passed dtypes which are lower than the lowest
    # dispatch type available (not likely but possible in code path).
    if chosen_dispatch is None:
        return return_type((), str(dtypes))

    extra_dtypes = tuple(set(dtypes) - set(chosen_dispatch()))
    return return_type(
        partial(chosen_dispatch, *extra_dtypes),
        chosen_dispatch.__name__ + str(extra_dtypes),
    )
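

# Illustrative sketch (an assumption, not part of the original file): asking for
# the dispatch helper that reproduces a concrete dtype collection.
def _example_dtypes_dispatch_hint():
    hint = dtypes_dispatch_hint(floating_types_and(torch.half, torch.bfloat16))
    # An exact match against COMPLETE_DTYPES_DISPATCH would print e.g.
    # "floating_types()"; here the closest extensible dispatch is extended, so
    # the string looks like "floating_types_and((torch.bfloat16, torch.float16))"
    # (the extra dtypes come from a set, so their order is not guaranteed).
    return hint.dispatch_fn_str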


def is_dynamic_dtype_set(op):
    # Detect if the OpInfo entry acquired dtypes dynamically
    # using `get_supported_dtypes`.
    return op.dynamic_dtypes


def str_format_dynamic_dtype(op):
    fmt_str = """
        OpInfo({name},
               dtypes={dtypes},
               dtypesIfCUDA={dtypesIfCUDA},
        )
        """.format(
        name=op.name,
        dtypes=dtypes_dispatch_hint(op.dtypes).dispatch_fn_str,
        dtypesIfCUDA=dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str,
    )

    return fmt_str
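

# Illustrative sketch (assumption): formatting the hint for a minimal stand-in
# for an OpInfo entry. `_FakeOp` is invented for this example only.
def _example_str_format_dynamic_dtype():
    _FakeOp = collections.namedtuple("_FakeOp", "name dtypes dtypesIfCUDA")
    op = _FakeOp(name="sin", dtypes=floating_types(), dtypesIfCUDA=floating_types())
    # Produces an OpInfo(...) snippet whose dtype lines read "dtypes=floating_types()".
    return str_format_dynamic_dtype(op)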


def np_unary_ufunc_integer_promotion_wrapper(fn):
    # Wrapper that passes PyTorch's default scalar
    # type as an argument to the wrapped NumPy
    # unary ufunc when given an integer input.
    # This mimics PyTorch's integer->floating point
    # type promotion.
    #
    # This is necessary when NumPy promotes
    # integer types to double, since PyTorch promotes
    # integer types to the default scalar type.

    # Helper to determine if promotion is needed
    def is_integral(dtype):
        return dtype in [
            np.bool_,
            bool,
            np.uint8,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
        ]

    @wraps(fn)
    def wrapped_fn(x):
        # As the default dtype can change, acquire it when the function is called.
        # NOTE: Promotion in PyTorch is from integer types to the default dtype.
        np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]

        if is_integral(x.dtype):
            return fn(x.astype(np_dtype))
        return fn(x)

    return wrapped_fn
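

# Illustrative sketch (assumption): np.sin on int64 input normally yields
# float64, whereas the wrapper first casts to the NumPy counterpart of
# torch.get_default_dtype() (float32 unless the default was changed).
def _example_integer_promotion_wrapper():
    np_sin = np_unary_ufunc_integer_promotion_wrapper(np.sin)
    x = np.arange(4, dtype=np.int64)
    return np_sin(x).dtype  # dtype('float32') with the default PyTorch dtype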


def reference_reduction_numpy(f, supports_keepdims=True):
    """Wraps a NumPy reduction operator.

    The wrapper function will forward dim, keepdim, mask, and identity
    kwargs to the wrapped function as the NumPy equivalent axis,
    keepdims, where, and initial kwargs, respectively.

    Args:
        f: NumPy reduction operator to wrap
        supports_keepdims (bool, optional): Whether the NumPy operator accepts
            keepdims parameter. If it does not, the wrapper will manually unsqueeze
            the reduced dimensions if it was called with keepdim=True. Defaults to True.

    Returns:
        Wrapped function
    """

    @wraps(f)
    def wrapper(x: np.ndarray, *args, **kwargs):
        # Copy keys into a set
        keys = set(kwargs.keys())

        dim = kwargs.pop("dim", None)
        keepdim = kwargs.pop("keepdim", False)

        if "dim" in keys:
            dim = tuple(dim) if isinstance(dim, Sequence) else dim

            # NumPy reductions don't accept dim=0 for scalar inputs,
            # so we convert it to None if and only if dim is equivalent.
            if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
                kwargs["axis"] = None
            else:
                kwargs["axis"] = dim

        if "keepdim" in keys and supports_keepdims:
            kwargs["keepdims"] = keepdim

        if "mask" in keys:
            mask = kwargs.pop("mask")
            if mask is not None:
                assert mask.layout == torch.strided
                kwargs["where"] = mask.cpu().numpy()

        if "identity" in keys:
            identity = kwargs.pop("identity")
            if identity is not None:
                if identity.dtype is torch.bfloat16:
                    identity = identity.cpu().to(torch.float32)
                else:
                    identity = identity.cpu()
                kwargs["initial"] = identity.numpy()

        result = f(x, *args, **kwargs)

        # Unsqueeze reduced dimensions if NumPy does not support keepdims
        if keepdim and not supports_keepdims and x.ndim > 0:
            dim = list(range(x.ndim)) if dim is None else dim
            result = np.expand_dims(result, dim)

        return result

    return wrapper
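

# Illustrative sketch (assumption): the dim/keepdim kwargs are translated to
# NumPy's axis/keepdims by the wrapper above.
def _example_reference_reduction_numpy():
    np_sum = reference_reduction_numpy(np.sum)
    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    return np_sum(x, dim=1, keepdim=True)  # shape (2, 1)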


def prod_numpy(a, *args, **kwargs):
    """
    Calls np.prod with dtype np.int64 if the input type is a signed integer,
    or np.uint64 if it is unsigned. This is necessary because on Windows
    np.prod defaults to int32, while on Linux it defaults to int64.

    This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320

    Returns:
        np.prod of input
    """
    if "dtype" not in kwargs:
        if np.issubdtype(a.dtype, np.signedinteger):
            a = a.astype(np.int64)
        elif np.issubdtype(a.dtype, np.unsignedinteger):
            a = a.astype(np.uint64)

    fn = reference_reduction_numpy(np.prod)
    return fn(a, *args, **kwargs)
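

# Illustrative sketch (assumption): with a plain int32 input, np.prod on Windows
# would accumulate in int32 and overflow for 2 ** 40; prod_numpy upcasts first.
def _example_prod_numpy():
    x = np.full(40, 2, dtype=np.int32)
    return prod_numpy(x)  # 1099511627776 (2 ** 40), accumulated in int64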