- import collections
- import warnings
- from functools import partial, wraps
- from typing import Sequence
- import numpy as np
- import torch
- from torch.testing._internal.common_cuda import TEST_CUDA
- from torch.testing._internal.common_dtype import (
- _dispatch_dtypes,
- all_types,
- all_types_and,
- all_types_and_complex,
- all_types_and_complex_and,
- all_types_and_half,
- complex_types,
- floating_and_complex_types,
- floating_and_complex_types_and,
- floating_types,
- floating_types_and,
- floating_types_and_half,
- integral_types,
- integral_types_and,
- )
- from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict
- COMPLETE_DTYPES_DISPATCH = (
- all_types,
- all_types_and_complex,
- all_types_and_half,
- floating_types,
- floating_and_complex_types,
- floating_types_and_half,
- integral_types,
- complex_types,
- )
- EXTENSIBLE_DTYPE_DISPATCH = (
- all_types_and_complex_and,
- floating_types_and,
- floating_and_complex_types_and,
- integral_types_and,
- all_types_and,
- )
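- # Illustrative note (not part of the original module): the "complete" helpers
- # above take no arguments and return a fixed dtype tuple, e.g. floating_types(),
- # while the "extensible" helpers accept extra dtypes to append, e.g.
- # floating_types_and(torch.half, torch.bfloat16). dtypes_dispatch_hint below
- # relies on this distinction when reconstructing a dispatch expression.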
- # Better way to acquire devices?
- DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])
- class _dynamic_dispatch_dtypes(_dispatch_dtypes):
- # Class to tag the dynamically generated types.
- pass
- def get_supported_dtypes(op, sample_inputs_fn, device_type):
- # Returns the supported dtypes for the given operator and device_type pair.
- assert device_type in ["cpu", "cuda"]
- if not TEST_CUDA and device_type == "cuda":
- warnings.warn(
- "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
- )
- return _dynamic_dispatch_dtypes(())
- supported_dtypes = set()
- for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
- try:
- samples = sample_inputs_fn(op, device_type, dtype, False)
- except RuntimeError:
- # If `sample_inputs_fn` doesn't support sampling for a given
- # `dtype`, we assume that the `dtype` is not supported.
- # We raise a warning so that the user knows this was the case
- # and can investigate whether there was an issue with the `sample_inputs_fn`.
- warnings.warn(
- f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
- )
- continue
- # We assume the dtype is supported
- # only if all samples pass for the given dtype.
- supported = True
- for sample in samples:
- try:
- op(sample.input, *sample.args, **sample.kwargs)
- except RuntimeError:
- # dtype is not supported
- supported = False
- break
- if supported:
- supported_dtypes.add(dtype)
- return _dynamic_dispatch_dtypes(supported_dtypes)
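- # Hedged usage sketch (the `op_info` entry here is hypothetical):
- #   supported = get_supported_dtypes(op_info.op, op_info.sample_inputs_func, "cpu")
- #   print(dtypes_dispatch_hint(supported).dispatch_fn_str)
- # Every dtype is probed by running all samples through the op, so this can be
- # slow for operators with many sample inputs.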
- def dtypes_dispatch_hint(dtypes):
- # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH)
- # and its string representation for the passed `dtypes`.
- return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")
- # `dtypes` can be empty when CUDA is not available (see get_supported_dtypes).
- if len(dtypes) == 0:
- return return_type((), str(tuple()))
- set_dtypes = set(dtypes)
- for dispatch in COMPLETE_DTYPES_DISPATCH:
- # Short circuit if we get an exact match.
- if set(dispatch()) == set_dtypes:
- return return_type(dispatch, dispatch.__name__ + "()")
- chosen_dispatch = None
- chosen_dispatch_score = 0.0
- for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
- dispatch_dtypes = set(dispatch())
- if not dispatch_dtypes.issubset(set_dtypes):
- continue
- score = len(dispatch_dtypes)
- if score > chosen_dispatch_score:
- chosen_dispatch_score = score
- chosen_dispatch = dispatch
- # If the passed dtypes don't cover even the smallest extensible
- # dispatch set (unlikely, but possible), fall back to the raw tuple.
- if chosen_dispatch is None:
- return return_type((), str(dtypes))
- return return_type(
- partial(chosen_dispatch, *tuple(set(dtypes) - set(chosen_dispatch()))),
- chosen_dispatch.__name__ + str(tuple(set(dtypes) - set(chosen_dispatch()))),
- )
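- # Hedged examples (assuming the dtype helpers imported above behave as usual):
- #   dtypes_dispatch_hint(floating_types())
- #       -> (floating_types, "floating_types()")          # exact match
- #   dtypes_dispatch_hint(floating_types_and(torch.bfloat16))
- #       -> (partial(floating_types_and, torch.bfloat16),
- #           "floating_types_and((torch.bfloat16,))")     # nearest extensible match
- # The extra dtypes come from a set difference, so their order in the string is
- # not guaranteed.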
- def is_dynamic_dtype_set(op):
- # Detect if the OpInfo entry acquired dtypes dynamically
- # using `get_supported_dtypes`.
- return op.dynamic_dtypes
- def str_format_dynamic_dtype(op):
- fmt_str = """
- OpInfo({name},
- dtypes={dtypes},
- dtypesIfCUDA={dtypesIfCUDA},
- )
- """.format(
- name=op.name,
- dtypes=dtypes_dispatch_hint(op.dtypes).dispatch_fn_str,
- dtypesIfCUDA=dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str,
- )
- return fmt_str
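- # Hedged sketch of the resulting string (the op name is hypothetical); it is
- # presumably meant to be pasted into the corresponding OpInfo definition:
- #   OpInfo(foo,
- #       dtypes=floating_types_and(torch.bfloat16),
- #       dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
- #   )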
- def np_unary_ufunc_integer_promotion_wrapper(fn):
- # Wrapper that casts integer inputs to PyTorch's default
- # scalar type before calling the wrapped NumPy unary ufunc.
- # This mimics PyTorch's integer->floating point
- # type promotion.
- #
- # This is necessary when NumPy promotes
- # integer types to double, since PyTorch promotes
- # integer types to the default scalar type.
- # Helper to determine if promotion is needed
- def is_integral(dtype):
- return dtype in [
- np.bool_,
- bool,
- np.uint8,
- np.int8,
- np.int16,
- np.int32,
- np.int64,
- ]
- @wraps(fn)
- def wrapped_fn(x):
- # As the default dtype can change, acquire it when the function is called.
- # NOTE: Promotion in PyTorch is from integer types to the default dtype
- np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
- if is_integral(x.dtype):
- return fn(x.astype(np_dtype))
- return fn(x)
- return wrapped_fn
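- # Hedged usage sketch: wrapping a NumPy ufunc as a reference for a PyTorch op.
- #   reference_sin = np_unary_ufunc_integer_promotion_wrapper(np.sin)
- #   reference_sin(np.array([1, 2, 3]))  # the int64 input is cast to float32
- #                                       # (PyTorch's default dtype) before np.sin,
- #                                       # instead of NumPy's promotion to float64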
- def reference_reduction_numpy(f, supports_keepdims=True):
- """Wraps a NumPy reduction operator.
- The wrapper function will forward dim, keepdim, mask, and identity
- kwargs to the wrapped function as the NumPy equivalent axis,
- keepdims, where, and initial kwargs, respectively.
- Args:
- f: NumPy reduction operator to wrap
- supports_keepdims (bool, optional): Whether the NumPy operator accepts
- keepdims parameter. If it does not, the wrapper will manually unsqueeze
- the reduced dimensions if it was called with keepdim=True. Defaults to True.
- Returns:
- Wrapped function
- """
- @wraps(f)
- def wrapper(x: np.ndarray, *args, **kwargs):
- # Copy keys into a set
- keys = set(kwargs.keys())
- dim = kwargs.pop("dim", None)
- keepdim = kwargs.pop("keepdim", False)
- if "dim" in keys:
- dim = tuple(dim) if isinstance(dim, Sequence) else dim
- # NumPy reductions don't accept dim=0 for scalar inputs, so we
- # convert dim to None when it is an equivalent value (0, -1, (0,), or (-1,))
- if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
- kwargs["axis"] = None
- else:
- kwargs["axis"] = dim
- if "keepdim" in keys and supports_keepdims:
- kwargs["keepdims"] = keepdim
- if "mask" in keys:
- mask = kwargs.pop("mask")
- if mask is not None:
- assert mask.layout == torch.strided
- kwargs["where"] = mask.cpu().numpy()
- if "identity" in keys:
- identity = kwargs.pop("identity")
- if identity is not None:
- if identity.dtype is torch.bfloat16:
- identity = identity.cpu().to(torch.float32)
- else:
- identity = identity.cpu()
- kwargs["initial"] = identity.numpy()
- result = f(x, *args, **kwargs)
- # Unsqueeze reduced dimensions if NumPy does not support keepdims
- if keepdim and not supports_keepdims and x.ndim > 0:
- dim = list(range(x.ndim)) if dim is None else dim
- result = np.expand_dims(result, dim)
- return result
- return wrapper
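- # Hedged usage sketch (torch-style kwargs are translated to NumPy equivalents):
- #   ref_sum = reference_reduction_numpy(np.sum)
- #   ref_sum(np.ones((2, 3)), dim=1, keepdim=True)     # np.sum(..., axis=1, keepdims=True)
- #   ref_amax = reference_reduction_numpy(np.amax)
- #   ref_amax(np.arange(6).reshape(2, 3), dim=(0, 1))  # dim tuple maps to an axis tuple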
- def prod_numpy(a, *args, **kwargs):
- """
- Calls np.prod after casting the input to np.int64 if it has a signed integer
- type, or to np.uint64 if it has an unsigned integer type. This is necessary
- because np.prod accumulates in int32 by default on Windows but in int64 on
- Linux, which can cause integer overflow
- (see https://github.com/pytorch/pytorch/issues/77320).
- Returns:
- np.prod of the input
- """
- if "dtype" not in kwargs:
- if np.issubdtype(a.dtype, np.signedinteger):
- a = a.astype(np.int64)
- elif np.issubdtype(a.dtype, np.unsignedinteger):
- a = a.astype(np.uint64)
- fn = reference_reduction_numpy(np.prod)
- return fn(a, *args, **kwargs)
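- # Hedged illustration of the overflow this guards against (assuming a platform
- # whose default integer is 32-bit, e.g. Windows):
- #   np.prod(np.full(4, 1000, dtype=np.int32))     # may wrap around in int32
- #   prod_numpy(np.full(4, 1000, dtype=np.int32))  # computed in int64 -> 10**12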