test_overrides.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. import inspect
  2. import sys
  3. import os
  4. import tempfile
  5. from io import StringIO
  6. from unittest import mock
  7. import numpy as np
  8. from numpy.testing import (
  9. assert_, assert_equal, assert_raises, assert_raises_regex)
  10. from numpy.core.overrides import (
  11. _get_implementing_args, array_function_dispatch,
  12. verify_matching_signatures, ARRAY_FUNCTION_ENABLED)
  13. from numpy.compat import pickle
  14. import pytest
  15. requires_array_function = pytest.mark.skipif(
  16. not ARRAY_FUNCTION_ENABLED,
  17. reason="__array_function__ dispatch not enabled.")
  18. def _return_not_implemented(self, *args, **kwargs):
  19. return NotImplemented
  20. # need to define this at the top level to test pickling
  21. @array_function_dispatch(lambda array: (array,))
  22. def dispatched_one_arg(array):
  23. """Docstring."""
  24. return 'original'
  25. @array_function_dispatch(lambda array1, array2: (array1, array2))
  26. def dispatched_two_arg(array1, array2):
  27. """Docstring."""
  28. return 'original'
  29. class TestGetImplementingArgs:
  30. def test_ndarray(self):
  31. array = np.array(1)
  32. args = _get_implementing_args([array])
  33. assert_equal(list(args), [array])
  34. args = _get_implementing_args([array, array])
  35. assert_equal(list(args), [array])
  36. args = _get_implementing_args([array, 1])
  37. assert_equal(list(args), [array])
  38. args = _get_implementing_args([1, array])
  39. assert_equal(list(args), [array])
  40. def test_ndarray_subclasses(self):
  41. class OverrideSub(np.ndarray):
  42. __array_function__ = _return_not_implemented
  43. class NoOverrideSub(np.ndarray):
  44. pass
  45. array = np.array(1).view(np.ndarray)
  46. override_sub = np.array(1).view(OverrideSub)
  47. no_override_sub = np.array(1).view(NoOverrideSub)
  48. args = _get_implementing_args([array, override_sub])
  49. assert_equal(list(args), [override_sub, array])
  50. args = _get_implementing_args([array, no_override_sub])
  51. assert_equal(list(args), [no_override_sub, array])
  52. args = _get_implementing_args(
  53. [override_sub, no_override_sub])
  54. assert_equal(list(args), [override_sub, no_override_sub])
  55. def test_ndarray_and_duck_array(self):
  56. class Other:
  57. __array_function__ = _return_not_implemented
  58. array = np.array(1)
  59. other = Other()
  60. args = _get_implementing_args([other, array])
  61. assert_equal(list(args), [other, array])
  62. args = _get_implementing_args([array, other])
  63. assert_equal(list(args), [array, other])
  64. def test_ndarray_subclass_and_duck_array(self):
  65. class OverrideSub(np.ndarray):
  66. __array_function__ = _return_not_implemented
  67. class Other:
  68. __array_function__ = _return_not_implemented
  69. array = np.array(1)
  70. subarray = np.array(1).view(OverrideSub)
  71. other = Other()
  72. assert_equal(_get_implementing_args([array, subarray, other]),
  73. [subarray, array, other])
  74. assert_equal(_get_implementing_args([array, other, subarray]),
  75. [subarray, array, other])
  76. def test_many_duck_arrays(self):
  77. class A:
  78. __array_function__ = _return_not_implemented
  79. class B(A):
  80. __array_function__ = _return_not_implemented
  81. class C(A):
  82. __array_function__ = _return_not_implemented
  83. class D:
  84. __array_function__ = _return_not_implemented
  85. a = A()
  86. b = B()
  87. c = C()
  88. d = D()
  89. assert_equal(_get_implementing_args([1]), [])
  90. assert_equal(_get_implementing_args([a]), [a])
  91. assert_equal(_get_implementing_args([a, 1]), [a])
  92. assert_equal(_get_implementing_args([a, a, a]), [a])
  93. assert_equal(_get_implementing_args([a, d, a]), [a, d])
  94. assert_equal(_get_implementing_args([a, b]), [b, a])
  95. assert_equal(_get_implementing_args([b, a]), [b, a])
  96. assert_equal(_get_implementing_args([a, b, c]), [b, c, a])
  97. assert_equal(_get_implementing_args([a, c, b]), [c, b, a])
  98. def test_too_many_duck_arrays(self):
  99. namespace = dict(__array_function__=_return_not_implemented)
  100. types = [type('A' + str(i), (object,), namespace) for i in range(33)]
  101. relevant_args = [t() for t in types]
  102. actual = _get_implementing_args(relevant_args[:32])
  103. assert_equal(actual, relevant_args[:32])
  104. with assert_raises_regex(TypeError, 'distinct argument types'):
  105. _get_implementing_args(relevant_args)
  106. class TestNDArrayArrayFunction:
  107. @requires_array_function
  108. def test_method(self):
  109. class Other:
  110. __array_function__ = _return_not_implemented
  111. class NoOverrideSub(np.ndarray):
  112. pass
  113. class OverrideSub(np.ndarray):
  114. __array_function__ = _return_not_implemented
  115. array = np.array([1])
  116. other = Other()
  117. no_override_sub = array.view(NoOverrideSub)
  118. override_sub = array.view(OverrideSub)
  119. result = array.__array_function__(func=dispatched_two_arg,
  120. types=(np.ndarray,),
  121. args=(array, 1.), kwargs={})
  122. assert_equal(result, 'original')
  123. result = array.__array_function__(func=dispatched_two_arg,
  124. types=(np.ndarray, Other),
  125. args=(array, other), kwargs={})
  126. assert_(result is NotImplemented)
  127. result = array.__array_function__(func=dispatched_two_arg,
  128. types=(np.ndarray, NoOverrideSub),
  129. args=(array, no_override_sub),
  130. kwargs={})
  131. assert_equal(result, 'original')
  132. result = array.__array_function__(func=dispatched_two_arg,
  133. types=(np.ndarray, OverrideSub),
  134. args=(array, override_sub),
  135. kwargs={})
  136. assert_equal(result, 'original')
  137. with assert_raises_regex(TypeError, 'no implementation found'):
  138. np.concatenate((array, other))
  139. expected = np.concatenate((array, array))
  140. result = np.concatenate((array, no_override_sub))
  141. assert_equal(result, expected.view(NoOverrideSub))
  142. result = np.concatenate((array, override_sub))
  143. assert_equal(result, expected.view(OverrideSub))
  144. def test_no_wrapper(self):
  145. # This shouldn't happen unless a user intentionally calls
  146. # __array_function__ with invalid arguments, but check that we raise
  147. # an appropriate error all the same.
  148. array = np.array(1)
  149. func = lambda x: x
  150. with assert_raises_regex(AttributeError, '_implementation'):
  151. array.__array_function__(func=func, types=(np.ndarray,),
  152. args=(array,), kwargs={})
  153. @requires_array_function
  154. class TestArrayFunctionDispatch:
  155. def test_pickle(self):
  156. for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
  157. roundtripped = pickle.loads(
  158. pickle.dumps(dispatched_one_arg, protocol=proto))
  159. assert_(roundtripped is dispatched_one_arg)
  160. def test_name_and_docstring(self):
  161. assert_equal(dispatched_one_arg.__name__, 'dispatched_one_arg')
  162. if sys.flags.optimize < 2:
  163. assert_equal(dispatched_one_arg.__doc__, 'Docstring.')
  164. def test_interface(self):
  165. class MyArray:
  166. def __array_function__(self, func, types, args, kwargs):
  167. return (self, func, types, args, kwargs)
  168. original = MyArray()
  169. (obj, func, types, args, kwargs) = dispatched_one_arg(original)
  170. assert_(obj is original)
  171. assert_(func is dispatched_one_arg)
  172. assert_equal(set(types), {MyArray})
  173. # assert_equal uses the overloaded np.iscomplexobj() internally
  174. assert_(args == (original,))
  175. assert_equal(kwargs, {})
  176. def test_not_implemented(self):
  177. class MyArray:
  178. def __array_function__(self, func, types, args, kwargs):
  179. return NotImplemented
  180. array = MyArray()
  181. with assert_raises_regex(TypeError, 'no implementation found'):
  182. dispatched_one_arg(array)
  183. @requires_array_function
  184. class TestVerifyMatchingSignatures:
  185. def test_verify_matching_signatures(self):
  186. verify_matching_signatures(lambda x: 0, lambda x: 0)
  187. verify_matching_signatures(lambda x=None: 0, lambda x=None: 0)
  188. verify_matching_signatures(lambda x=1: 0, lambda x=None: 0)
  189. with assert_raises(RuntimeError):
  190. verify_matching_signatures(lambda a: 0, lambda b: 0)
  191. with assert_raises(RuntimeError):
  192. verify_matching_signatures(lambda x: 0, lambda x=None: 0)
  193. with assert_raises(RuntimeError):
  194. verify_matching_signatures(lambda x=None: 0, lambda y=None: 0)
  195. with assert_raises(RuntimeError):
  196. verify_matching_signatures(lambda x=1: 0, lambda y=1: 0)
  197. def test_array_function_dispatch(self):
  198. with assert_raises(RuntimeError):
  199. @array_function_dispatch(lambda x: (x,))
  200. def f(y):
  201. pass
  202. # should not raise
  203. @array_function_dispatch(lambda x: (x,), verify=False)
  204. def f(y):
  205. pass
  206. def _new_duck_type_and_implements():
  207. """Create a duck array type and implements functions."""
  208. HANDLED_FUNCTIONS = {}
  209. class MyArray:
  210. def __array_function__(self, func, types, args, kwargs):
  211. if func not in HANDLED_FUNCTIONS:
  212. return NotImplemented
  213. if not all(issubclass(t, MyArray) for t in types):
  214. return NotImplemented
  215. return HANDLED_FUNCTIONS[func](*args, **kwargs)
  216. def implements(numpy_function):
  217. """Register an __array_function__ implementations."""
  218. def decorator(func):
  219. HANDLED_FUNCTIONS[numpy_function] = func
  220. return func
  221. return decorator
  222. return (MyArray, implements)
  223. @requires_array_function
  224. class TestArrayFunctionImplementation:
  225. def test_one_arg(self):
  226. MyArray, implements = _new_duck_type_and_implements()
  227. @implements(dispatched_one_arg)
  228. def _(array):
  229. return 'myarray'
  230. assert_equal(dispatched_one_arg(1), 'original')
  231. assert_equal(dispatched_one_arg(MyArray()), 'myarray')
  232. def test_optional_args(self):
  233. MyArray, implements = _new_duck_type_and_implements()
  234. @array_function_dispatch(lambda array, option=None: (array,))
  235. def func_with_option(array, option='default'):
  236. return option
  237. @implements(func_with_option)
  238. def my_array_func_with_option(array, new_option='myarray'):
  239. return new_option
  240. # we don't need to implement every option on __array_function__
  241. # implementations
  242. assert_equal(func_with_option(1), 'default')
  243. assert_equal(func_with_option(1, option='extra'), 'extra')
  244. assert_equal(func_with_option(MyArray()), 'myarray')
  245. with assert_raises(TypeError):
  246. func_with_option(MyArray(), option='extra')
  247. # but new options on implementations can't be used
  248. result = my_array_func_with_option(MyArray(), new_option='yes')
  249. assert_equal(result, 'yes')
  250. with assert_raises(TypeError):
  251. func_with_option(MyArray(), new_option='no')
  252. def test_not_implemented(self):
  253. MyArray, implements = _new_duck_type_and_implements()
  254. @array_function_dispatch(lambda array: (array,), module='my')
  255. def func(array):
  256. return array
  257. array = np.array(1)
  258. assert_(func(array) is array)
  259. assert_equal(func.__module__, 'my')
  260. with assert_raises_regex(
  261. TypeError, "no implementation found for 'my.func'"):
  262. func(MyArray())
  263. def test_signature_error_message(self):
  264. # The lambda function will be named "<lambda>", but the TypeError
  265. # should show the name as "func"
  266. def _dispatcher():
  267. return ()
  268. @array_function_dispatch(_dispatcher)
  269. def func():
  270. pass
  271. try:
  272. func(bad_arg=3)
  273. except TypeError as e:
  274. expected_exception = e
  275. try:
  276. func(bad_arg=3)
  277. raise AssertionError("must fail")
  278. except TypeError as exc:
  279. assert exc.args == expected_exception.args
  280. @pytest.mark.parametrize("value", [234, "this func is not replaced"])
  281. def test_dispatcher_error(self, value):
  282. # If the dispatcher raises an error, we must not attempt to mutate it
  283. error = TypeError(value)
  284. def dispatcher():
  285. raise error
  286. @array_function_dispatch(dispatcher)
  287. def func():
  288. return 3
  289. try:
  290. func()
  291. raise AssertionError("must fail")
  292. except TypeError as exc:
  293. assert exc is error # unmodified exception
  294. class TestNDArrayMethods:
  295. def test_repr(self):
  296. # gh-12162: should still be defined even if __array_function__ doesn't
  297. # implement np.array_repr()
  298. class MyArray(np.ndarray):
  299. def __array_function__(*args, **kwargs):
  300. return NotImplemented
  301. array = np.array(1).view(MyArray)
  302. assert_equal(repr(array), 'MyArray(1)')
  303. assert_equal(str(array), '1')
  304. class TestNumPyFunctions:
  305. def test_set_module(self):
  306. assert_equal(np.sum.__module__, 'numpy')
  307. assert_equal(np.char.equal.__module__, 'numpy.char')
  308. assert_equal(np.fft.fft.__module__, 'numpy.fft')
  309. assert_equal(np.linalg.solve.__module__, 'numpy.linalg')
  310. def test_inspect_sum(self):
  311. signature = inspect.signature(np.sum)
  312. assert_('axis' in signature.parameters)
  313. @requires_array_function
  314. def test_override_sum(self):
  315. MyArray, implements = _new_duck_type_and_implements()
  316. @implements(np.sum)
  317. def _(array):
  318. return 'yes'
  319. assert_equal(np.sum(MyArray()), 'yes')
  320. @requires_array_function
  321. def test_sum_on_mock_array(self):
  322. # We need a proxy for mocks because __array_function__ is only looked
  323. # up in the class dict
  324. class ArrayProxy:
  325. def __init__(self, value):
  326. self.value = value
  327. def __array_function__(self, *args, **kwargs):
  328. return self.value.__array_function__(*args, **kwargs)
  329. def __array__(self, *args, **kwargs):
  330. return self.value.__array__(*args, **kwargs)
  331. proxy = ArrayProxy(mock.Mock(spec=ArrayProxy))
  332. proxy.value.__array_function__.return_value = 1
  333. result = np.sum(proxy)
  334. assert_equal(result, 1)
  335. proxy.value.__array_function__.assert_called_once_with(
  336. np.sum, (ArrayProxy,), (proxy,), {})
  337. proxy.value.__array__.assert_not_called()
  338. @requires_array_function
  339. def test_sum_forwarding_implementation(self):
  340. class MyArray(np.ndarray):
  341. def sum(self, axis, out):
  342. return 'summed'
  343. def __array_function__(self, func, types, args, kwargs):
  344. return super().__array_function__(func, types, args, kwargs)
  345. # note: the internal implementation of np.sum() calls the .sum() method
  346. array = np.array(1).view(MyArray)
  347. assert_equal(np.sum(array), 'summed')
  348. class TestArrayLike:
  349. def setup_method(self):
  350. class MyArray():
  351. def __init__(self, function=None):
  352. self.function = function
  353. def __array_function__(self, func, types, args, kwargs):
  354. assert func is getattr(np, func.__name__)
  355. try:
  356. my_func = getattr(self, func.__name__)
  357. except AttributeError:
  358. return NotImplemented
  359. return my_func(*args, **kwargs)
  360. self.MyArray = MyArray
  361. class MyNoArrayFunctionArray():
  362. def __init__(self, function=None):
  363. self.function = function
  364. self.MyNoArrayFunctionArray = MyNoArrayFunctionArray
  365. def add_method(self, name, arr_class, enable_value_error=False):
  366. def _definition(*args, **kwargs):
  367. # Check that `like=` isn't propagated downstream
  368. assert 'like' not in kwargs
  369. if enable_value_error and 'value_error' in kwargs:
  370. raise ValueError
  371. return arr_class(getattr(arr_class, name))
  372. setattr(arr_class, name, _definition)
  373. def func_args(*args, **kwargs):
  374. return args, kwargs
  375. @requires_array_function
  376. def test_array_like_not_implemented(self):
  377. self.add_method('array', self.MyArray)
  378. ref = self.MyArray.array()
  379. with assert_raises_regex(TypeError, 'no implementation found'):
  380. array_like = np.asarray(1, like=ref)
  381. _array_tests = [
  382. ('array', *func_args((1,))),
  383. ('asarray', *func_args((1,))),
  384. ('asanyarray', *func_args((1,))),
  385. ('ascontiguousarray', *func_args((2, 3))),
  386. ('asfortranarray', *func_args((2, 3))),
  387. ('require', *func_args((np.arange(6).reshape(2, 3),),
  388. requirements=['A', 'F'])),
  389. ('empty', *func_args((1,))),
  390. ('full', *func_args((1,), 2)),
  391. ('ones', *func_args((1,))),
  392. ('zeros', *func_args((1,))),
  393. ('arange', *func_args(3)),
  394. ('frombuffer', *func_args(b'\x00' * 8, dtype=int)),
  395. ('fromiter', *func_args(range(3), dtype=int)),
  396. ('fromstring', *func_args('1,2', dtype=int, sep=',')),
  397. ('loadtxt', *func_args(lambda: StringIO('0 1\n2 3'))),
  398. ('genfromtxt', *func_args(lambda: StringIO('1,2.1'),
  399. dtype=[('int', 'i8'), ('float', 'f8')],
  400. delimiter=',')),
  401. ]
  402. @pytest.mark.parametrize('function, args, kwargs', _array_tests)
  403. @pytest.mark.parametrize('numpy_ref', [True, False])
  404. @requires_array_function
  405. def test_array_like(self, function, args, kwargs, numpy_ref):
  406. self.add_method('array', self.MyArray)
  407. self.add_method(function, self.MyArray)
  408. np_func = getattr(np, function)
  409. my_func = getattr(self.MyArray, function)
  410. if numpy_ref is True:
  411. ref = np.array(1)
  412. else:
  413. ref = self.MyArray.array()
  414. like_args = tuple(a() if callable(a) else a for a in args)
  415. array_like = np_func(*like_args, **kwargs, like=ref)
  416. if numpy_ref is True:
  417. assert type(array_like) is np.ndarray
  418. np_args = tuple(a() if callable(a) else a for a in args)
  419. np_arr = np_func(*np_args, **kwargs)
  420. # Special-case np.empty to ensure values match
  421. if function == "empty":
  422. np_arr.fill(1)
  423. array_like.fill(1)
  424. assert_equal(array_like, np_arr)
  425. else:
  426. assert type(array_like) is self.MyArray
  427. assert array_like.function is my_func
  428. @pytest.mark.parametrize('function, args, kwargs', _array_tests)
  429. @pytest.mark.parametrize('ref', [1, [1], "MyNoArrayFunctionArray"])
  430. @requires_array_function
  431. def test_no_array_function_like(self, function, args, kwargs, ref):
  432. self.add_method('array', self.MyNoArrayFunctionArray)
  433. self.add_method(function, self.MyNoArrayFunctionArray)
  434. np_func = getattr(np, function)
  435. # Instantiate ref if it's the MyNoArrayFunctionArray class
  436. if ref == "MyNoArrayFunctionArray":
  437. ref = self.MyNoArrayFunctionArray.array()
  438. like_args = tuple(a() if callable(a) else a for a in args)
  439. with assert_raises_regex(TypeError,
  440. 'The `like` argument must be an array-like that implements'):
  441. np_func(*like_args, **kwargs, like=ref)
  442. @pytest.mark.parametrize('numpy_ref', [True, False])
  443. def test_array_like_fromfile(self, numpy_ref):
  444. self.add_method('array', self.MyArray)
  445. self.add_method("fromfile", self.MyArray)
  446. if numpy_ref is True:
  447. ref = np.array(1)
  448. else:
  449. ref = self.MyArray.array()
  450. data = np.random.random(5)
  451. with tempfile.TemporaryDirectory() as tmpdir:
  452. fname = os.path.join(tmpdir, "testfile")
  453. data.tofile(fname)
  454. array_like = np.fromfile(fname, like=ref)
  455. if numpy_ref is True:
  456. assert type(array_like) is np.ndarray
  457. np_res = np.fromfile(fname, like=ref)
  458. assert_equal(np_res, data)
  459. assert_equal(array_like, np_res)
  460. else:
  461. assert type(array_like) is self.MyArray
  462. assert array_like.function is self.MyArray.fromfile
  463. @requires_array_function
  464. def test_exception_handling(self):
  465. self.add_method('array', self.MyArray, enable_value_error=True)
  466. ref = self.MyArray.array()
  467. with assert_raises(TypeError):
  468. # Raises the error about `value_error` being invalid first
  469. np.array(1, value_error=True, like=ref)
  470. @pytest.mark.parametrize('function, args, kwargs', _array_tests)
  471. def test_like_as_none(self, function, args, kwargs):
  472. self.add_method('array', self.MyArray)
  473. self.add_method(function, self.MyArray)
  474. np_func = getattr(np, function)
  475. like_args = tuple(a() if callable(a) else a for a in args)
  476. # required for loadtxt and genfromtxt to init w/o error.
  477. like_args_exp = tuple(a() if callable(a) else a for a in args)
  478. array_like = np_func(*like_args, **kwargs, like=None)
  479. expected = np_func(*like_args_exp, **kwargs)
  480. # Special-case np.empty to ensure values match
  481. if function == "empty":
  482. array_like.fill(1)
  483. expected.fill(1)
  484. assert_equal(array_like, expected)