test_elementwise_functions.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from inspect import getfullargspec
  2. from numpy.testing import assert_raises
  3. from .. import asarray, _elementwise_functions
  4. from .._elementwise_functions import bitwise_left_shift, bitwise_right_shift
  5. from .._dtypes import (
  6. _dtype_categories,
  7. _boolean_dtypes,
  8. _floating_dtypes,
  9. _integer_dtypes,
  10. )
  11. def nargs(func):
  12. return len(getfullargspec(func).args)
  13. def test_function_types():
  14. # Test that every function accepts only the required input types. We only
  15. # test the negative cases here (error). The positive cases are tested in
  16. # the array API test suite.
  17. elementwise_function_input_types = {
  18. "abs": "numeric",
  19. "acos": "floating-point",
  20. "acosh": "floating-point",
  21. "add": "numeric",
  22. "asin": "floating-point",
  23. "asinh": "floating-point",
  24. "atan": "floating-point",
  25. "atan2": "floating-point",
  26. "atanh": "floating-point",
  27. "bitwise_and": "integer or boolean",
  28. "bitwise_invert": "integer or boolean",
  29. "bitwise_left_shift": "integer",
  30. "bitwise_or": "integer or boolean",
  31. "bitwise_right_shift": "integer",
  32. "bitwise_xor": "integer or boolean",
  33. "ceil": "numeric",
  34. "cos": "floating-point",
  35. "cosh": "floating-point",
  36. "divide": "floating-point",
  37. "equal": "all",
  38. "exp": "floating-point",
  39. "expm1": "floating-point",
  40. "floor": "numeric",
  41. "floor_divide": "numeric",
  42. "greater": "numeric",
  43. "greater_equal": "numeric",
  44. "isfinite": "numeric",
  45. "isinf": "numeric",
  46. "isnan": "numeric",
  47. "less": "numeric",
  48. "less_equal": "numeric",
  49. "log": "floating-point",
  50. "logaddexp": "floating-point",
  51. "log10": "floating-point",
  52. "log1p": "floating-point",
  53. "log2": "floating-point",
  54. "logical_and": "boolean",
  55. "logical_not": "boolean",
  56. "logical_or": "boolean",
  57. "logical_xor": "boolean",
  58. "multiply": "numeric",
  59. "negative": "numeric",
  60. "not_equal": "all",
  61. "positive": "numeric",
  62. "pow": "numeric",
  63. "remainder": "numeric",
  64. "round": "numeric",
  65. "sign": "numeric",
  66. "sin": "floating-point",
  67. "sinh": "floating-point",
  68. "sqrt": "floating-point",
  69. "square": "numeric",
  70. "subtract": "numeric",
  71. "tan": "floating-point",
  72. "tanh": "floating-point",
  73. "trunc": "numeric",
  74. }
  75. def _array_vals():
  76. for d in _integer_dtypes:
  77. yield asarray(1, dtype=d)
  78. for d in _boolean_dtypes:
  79. yield asarray(False, dtype=d)
  80. for d in _floating_dtypes:
  81. yield asarray(1.0, dtype=d)
  82. for x in _array_vals():
  83. for func_name, types in elementwise_function_input_types.items():
  84. dtypes = _dtype_categories[types]
  85. func = getattr(_elementwise_functions, func_name)
  86. if nargs(func) == 2:
  87. for y in _array_vals():
  88. if x.dtype not in dtypes or y.dtype not in dtypes:
  89. assert_raises(TypeError, lambda: func(x, y))
  90. else:
  91. if x.dtype not in dtypes:
  92. assert_raises(TypeError, lambda: func(x))
  93. def test_bitwise_shift_error():
  94. # bitwise shift functions should raise when the second argument is negative
  95. assert_raises(
  96. ValueError, lambda: bitwise_left_shift(asarray([1, 1]), asarray([1, -1]))
  97. )
  98. assert_raises(
  99. ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
  100. )