test_mixins.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import numbers
  2. import operator
  3. import numpy as np
  4. from numpy.testing import assert_, assert_equal, assert_raises
  5. # NOTE: This class should be kept as an exact copy of the example from the
  6. # docstring for NDArrayOperatorsMixin.
class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
    """Minimal array-like wrapper whose payload lives in ``self.value``.

    Only ``__array_ufunc__`` is implemented here; the Python operator
    methods (``+``, ``-``, comparisons, ...) are inherited from
    ``NDArrayOperatorsMixin``. The executable code is meant to track the
    ``NDArrayOperatorsMixin`` docstring example, so limit edits here to
    comments.
    """
    def __init__(self, value):
        # Normalize the payload to an ndarray.
        self.value = np.asarray(value)
    # One might also consider adding the built-in list type to this
    # list, to support operations like np.add(array_like, list)
    _HANDLED_TYPES = (np.ndarray, numbers.Number)
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply ``getattr(ufunc, method)`` to the unwrapped ``.value`` payloads.

        Returns ``NotImplemented`` if any positional input or ``out``
        argument is not an ndarray, a ``numbers.Number`` or an ``ArrayLike``.
        Otherwise the ufunc result is re-wrapped in ``type(self)``
        (element-wise when the ufunc returns a tuple), except for
        ``method == 'at'``, which returns None.
        """
        # ``out`` is concatenated with the ``inputs`` tuple below, so it is
        # expected to arrive as a tuple (numpy normalizes it before calling).
        out = kwargs.get('out', ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use ArrayLike instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle ArrayLike objects.
            if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)):
                return NotImplemented
        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(x.value if isinstance(x, ArrayLike) else x
                       for x in inputs)
        if out:
            kwargs['out'] = tuple(
                x.value if isinstance(x, ArrayLike) else x
                for x in out)
        result = getattr(ufunc, method)(*inputs, **kwargs)
        if type(result) is tuple:
            # multiple return values
            return tuple(type(self)(x) for x in result)
        elif method == 'at':
            # no return value
            return None
        else:
            # one return value
            return type(self)(result)
    def __repr__(self):
        # Renders as ``<ClassName>(<repr of the ndarray payload>)``.
        return '%s(%r)' % (type(self).__name__, self.value)
  41. def wrap_array_like(result):
  42. if type(result) is tuple:
  43. return tuple(ArrayLike(r) for r in result)
  44. else:
  45. return ArrayLike(result)
  46. def _assert_equal_type_and_value(result, expected, err_msg=None):
  47. assert_equal(type(result), type(expected), err_msg=err_msg)
  48. if isinstance(result, tuple):
  49. assert_equal(len(result), len(expected), err_msg=err_msg)
  50. for result_item, expected_item in zip(result, expected):
  51. _assert_equal_type_and_value(result_item, expected_item, err_msg)
  52. else:
  53. assert_equal(result.value, expected.value, err_msg=err_msg)
  54. assert_equal(getattr(result.value, 'dtype', None),
  55. getattr(expected.value, 'dtype', None), err_msg=err_msg)
  56. _ALL_BINARY_OPERATORS = [
  57. operator.lt,
  58. operator.le,
  59. operator.eq,
  60. operator.ne,
  61. operator.gt,
  62. operator.ge,
  63. operator.add,
  64. operator.sub,
  65. operator.mul,
  66. operator.truediv,
  67. operator.floordiv,
  68. operator.mod,
  69. divmod,
  70. pow,
  71. operator.lshift,
  72. operator.rshift,
  73. operator.and_,
  74. operator.xor,
  75. operator.or_,
  76. ]
  77. class TestNDArrayOperatorsMixin:
  78. def test_array_like_add(self):
  79. def check(result):
  80. _assert_equal_type_and_value(result, ArrayLike(0))
  81. check(ArrayLike(0) + 0)
  82. check(0 + ArrayLike(0))
  83. check(ArrayLike(0) + np.array(0))
  84. check(np.array(0) + ArrayLike(0))
  85. check(ArrayLike(np.array(0)) + 0)
  86. check(0 + ArrayLike(np.array(0)))
  87. check(ArrayLike(np.array(0)) + np.array(0))
  88. check(np.array(0) + ArrayLike(np.array(0)))
  89. def test_inplace(self):
  90. array_like = ArrayLike(np.array([0]))
  91. array_like += 1
  92. _assert_equal_type_and_value(array_like, ArrayLike(np.array([1])))
  93. array = np.array([0])
  94. array += ArrayLike(1)
  95. _assert_equal_type_and_value(array, ArrayLike(np.array([1])))
  96. def test_opt_out(self):
  97. class OptOut:
  98. """Object that opts out of __array_ufunc__."""
  99. __array_ufunc__ = None
  100. def __add__(self, other):
  101. return self
  102. def __radd__(self, other):
  103. return self
  104. array_like = ArrayLike(1)
  105. opt_out = OptOut()
  106. # supported operations
  107. assert_(array_like + opt_out is opt_out)
  108. assert_(opt_out + array_like is opt_out)
  109. # not supported
  110. with assert_raises(TypeError):
  111. # don't use the Python default, array_like = array_like + opt_out
  112. array_like += opt_out
  113. with assert_raises(TypeError):
  114. array_like - opt_out
  115. with assert_raises(TypeError):
  116. opt_out - array_like
  117. def test_subclass(self):
  118. class SubArrayLike(ArrayLike):
  119. """Should take precedence over ArrayLike."""
  120. x = ArrayLike(0)
  121. y = SubArrayLike(1)
  122. _assert_equal_type_and_value(x + y, y)
  123. _assert_equal_type_and_value(y + x, y)
  124. def test_object(self):
  125. x = ArrayLike(0)
  126. obj = object()
  127. with assert_raises(TypeError):
  128. x + obj
  129. with assert_raises(TypeError):
  130. obj + x
  131. with assert_raises(TypeError):
  132. x += obj
  133. def test_unary_methods(self):
  134. array = np.array([-1, 0, 1, 2])
  135. array_like = ArrayLike(array)
  136. for op in [operator.neg,
  137. operator.pos,
  138. abs,
  139. operator.invert]:
  140. _assert_equal_type_and_value(op(array_like), ArrayLike(op(array)))
  141. def test_forward_binary_methods(self):
  142. array = np.array([-1, 0, 1, 2])
  143. array_like = ArrayLike(array)
  144. for op in _ALL_BINARY_OPERATORS:
  145. expected = wrap_array_like(op(array, 1))
  146. actual = op(array_like, 1)
  147. err_msg = 'failed for operator {}'.format(op)
  148. _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
  149. def test_reflected_binary_methods(self):
  150. for op in _ALL_BINARY_OPERATORS:
  151. expected = wrap_array_like(op(2, 1))
  152. actual = op(2, ArrayLike(1))
  153. err_msg = 'failed for operator {}'.format(op)
  154. _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
  155. def test_matmul(self):
  156. array = np.array([1, 2], dtype=np.float64)
  157. array_like = ArrayLike(array)
  158. expected = ArrayLike(np.float64(5))
  159. _assert_equal_type_and_value(expected, np.matmul(array_like, array))
  160. _assert_equal_type_and_value(
  161. expected, operator.matmul(array_like, array))
  162. _assert_equal_type_and_value(
  163. expected, operator.matmul(array, array_like))
  164. def test_ufunc_at(self):
  165. array = ArrayLike(np.array([1, 2, 3, 4]))
  166. assert_(np.negative.at(array, np.array([0, 1])) is None)
  167. _assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4]))
  168. def test_ufunc_two_outputs(self):
  169. mantissa, exponent = np.frexp(2 ** -3)
  170. expected = (ArrayLike(mantissa), ArrayLike(exponent))
  171. _assert_equal_type_and_value(
  172. np.frexp(ArrayLike(2 ** -3)), expected)
  173. _assert_equal_type_and_value(
  174. np.frexp(ArrayLike(np.array(2 ** -3))), expected)