test_array_object.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. import operator
  2. from numpy.testing import assert_raises
  3. import numpy as np
  4. import pytest
  5. from .. import ones, asarray, reshape, result_type, all, equal
  6. from .._array_object import Array
  7. from .._dtypes import (
  8. _all_dtypes,
  9. _boolean_dtypes,
  10. _floating_dtypes,
  11. _integer_dtypes,
  12. _integer_or_boolean_dtypes,
  13. _numeric_dtypes,
  14. int8,
  15. int16,
  16. int32,
  17. int64,
  18. uint64,
  19. bool as bool_,
  20. )
  21. def test_validate_index():
  22. # The indexing tests in the official array API test suite test that the
  23. # array object correctly handles the subset of indices that are required
  24. # by the spec. But the NumPy array API implementation specifically
  25. # disallows any index not required by the spec, via Array._validate_index.
  26. # This test focuses on testing that non-valid indices are correctly
  27. # rejected. See
  28. # https://data-apis.org/array-api/latest/API_specification/indexing.html
  29. # and the docstring of Array._validate_index for the exact indexing
  30. # behavior that should be allowed. This does not test indices that are
  31. # already invalid in NumPy itself because Array will generally just pass
  32. # such indices directly to the underlying np.ndarray.
  33. a = ones((3, 4))
  34. # Out of bounds slices are not allowed
  35. assert_raises(IndexError, lambda: a[:4])
  36. assert_raises(IndexError, lambda: a[:-4])
  37. assert_raises(IndexError, lambda: a[:3:-1])
  38. assert_raises(IndexError, lambda: a[:-5:-1])
  39. assert_raises(IndexError, lambda: a[4:])
  40. assert_raises(IndexError, lambda: a[-4:])
  41. assert_raises(IndexError, lambda: a[4::-1])
  42. assert_raises(IndexError, lambda: a[-4::-1])
  43. assert_raises(IndexError, lambda: a[...,:5])
  44. assert_raises(IndexError, lambda: a[...,:-5])
  45. assert_raises(IndexError, lambda: a[...,:5:-1])
  46. assert_raises(IndexError, lambda: a[...,:-6:-1])
  47. assert_raises(IndexError, lambda: a[...,5:])
  48. assert_raises(IndexError, lambda: a[...,-5:])
  49. assert_raises(IndexError, lambda: a[...,5::-1])
  50. assert_raises(IndexError, lambda: a[...,-5::-1])
  51. # Boolean indices cannot be part of a larger tuple index
  52. assert_raises(IndexError, lambda: a[a[:,0]==1,0])
  53. assert_raises(IndexError, lambda: a[a[:,0]==1,...])
  54. assert_raises(IndexError, lambda: a[..., a[0]==1])
  55. assert_raises(IndexError, lambda: a[[True, True, True]])
  56. assert_raises(IndexError, lambda: a[(True, True, True),])
  57. # Integer array indices are not allowed (except for 0-D)
  58. idx = asarray([[0, 1]])
  59. assert_raises(IndexError, lambda: a[idx])
  60. assert_raises(IndexError, lambda: a[idx,])
  61. assert_raises(IndexError, lambda: a[[0, 1]])
  62. assert_raises(IndexError, lambda: a[(0, 1), (0, 1)])
  63. assert_raises(IndexError, lambda: a[[0, 1]])
  64. assert_raises(IndexError, lambda: a[np.array([[0, 1]])])
  65. # Multiaxis indices must contain exactly as many indices as dimensions
  66. assert_raises(IndexError, lambda: a[()])
  67. assert_raises(IndexError, lambda: a[0,])
  68. assert_raises(IndexError, lambda: a[0])
  69. assert_raises(IndexError, lambda: a[:])
def test_operators():
    """Exercise the operator dunder methods of Array on scalars and arrays.

    Combinations that the checks below treat as supported are only executed
    (the result is not inspected); every other combination must raise
    TypeError. The matmul family is covered in its own section at the end,
    where some invalid inputs are expected to raise ValueError instead.
    """
    # For every operator, we test that it works for the required type
    # combinations and raises TypeError otherwise
    binary_op_dtypes = {
        "__add__": "numeric",
        "__and__": "integer_or_boolean",
        "__eq__": "all",
        "__floordiv__": "numeric",
        "__ge__": "numeric",
        "__gt__": "numeric",
        "__le__": "numeric",
        "__lshift__": "integer",
        "__lt__": "numeric",
        "__mod__": "numeric",
        "__mul__": "numeric",
        "__ne__": "all",
        "__or__": "integer_or_boolean",
        "__pow__": "numeric",
        "__rshift__": "integer",
        "__sub__": "numeric",
        "__truediv__": "floating",
        "__xor__": "integer_or_boolean",
    }

    # Recompute each time because of in-place ops
    def _array_vals():
        # Yields one array built from a single scalar for every integer,
        # boolean and floating dtype, in that order.
        for d in _integer_dtypes:
            yield asarray(1, dtype=d)
        for d in _boolean_dtypes:
            yield asarray(False, dtype=d)
        for d in _floating_dtypes:
            yield asarray(1.0, dtype=d)

    for op, dtypes in binary_op_dtypes.items():
        ops = [op]
        # Every operator except the comparisons is also tested through its
        # reflected (__r*__) and in-place (__i*__) variants.
        if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
            rop = "__r" + op[2:]
            iop = "__i" + op[2:]
            ops += [rop, iop]
        for s in [1, 1.0, False]:
            for _op in ops:
                for a in _array_vals():
                    # Test array op scalar. From the spec, the following combinations
                    # are supported:
                    # - Python bool for a bool array dtype,
                    # - a Python int within the bounds of the given dtype for integer array dtypes,
                    # - a Python int or float for floating-point array dtypes
                    # We do not do bounds checking for int scalars, but rather use the default
                    # NumPy behavior for casting in that case.
                    if ((dtypes == "all"
                         or dtypes == "numeric" and a.dtype in _numeric_dtypes
                         or dtypes == "integer" and a.dtype in _integer_dtypes
                         or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
                         or dtypes == "boolean" and a.dtype in _boolean_dtypes
                         or dtypes == "floating" and a.dtype in _floating_dtypes
                         )
                        # bool is a subtype of int, which is why we avoid
                        # isinstance here.
                        and (a.dtype in _boolean_dtypes and type(s) == bool
                             or a.dtype in _integer_dtypes and type(s) == int
                             or a.dtype in _floating_dtypes and type(s) in [float, int]
                             )):
                        # Only test for no error
                        getattr(a, _op)(s)
                    else:
                        assert_raises(TypeError, lambda: getattr(a, _op)(s))

        # Test array op array.
        for _op in ops:
            for x in _array_vals():
                for y in _array_vals():
                    # See the promotion table in NEP 47 or the array
                    # API spec page on type promotion. Mixed kind
                    # promotion is not defined.
                    if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
                        or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
                        or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
                        or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
                        or x.dtype in _boolean_dtypes and y.dtype not in _boolean_dtypes
                        or y.dtype in _boolean_dtypes and x.dtype not in _boolean_dtypes
                        or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
                        or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
                        ):
                        assert_raises(TypeError, lambda: getattr(x, _op)(y))
                    # Ensure in-place operators only promote to the same dtype as the left operand.
                    elif (
                        _op.startswith("__i")
                        and result_type(x.dtype, y.dtype) != x.dtype
                    ):
                        assert_raises(TypeError, lambda: getattr(x, _op)(y))
                    # Ensure only those dtypes that are required for every operator are allowed.
                    # NOTE(review): the "integer" case checks y against
                    # _numeric_dtypes rather than _integer_dtypes; the first
                    # branch above already rejects integer/non-integer pairs,
                    # so this looks harmless - confirm it is intentional.
                    elif (dtypes == "all" and (x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
                                               or x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
                          or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
                          or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _numeric_dtypes
                          or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
                                                                 or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
                          or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
                          or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
                          ):
                        getattr(x, _op)(y)
                    else:
                        assert_raises(TypeError, lambda: getattr(x, _op)(y))

    unary_op_dtypes = {
        "__abs__": "numeric",
        "__invert__": "integer_or_boolean",
        "__neg__": "numeric",
        "__pos__": "numeric",
    }
    for op, dtypes in unary_op_dtypes.items():
        for a in _array_vals():
            if (
                dtypes == "numeric"
                and a.dtype in _numeric_dtypes
                or dtypes == "integer_or_boolean"
                and a.dtype in _integer_or_boolean_dtypes
            ):
                # Only test for no error
                getattr(a, op)()
            else:
                assert_raises(TypeError, lambda: getattr(a, op)())

    # Finally, matmul() must be tested separately, because it works a bit
    # differently from the other operations.
    def _matmul_array_vals():
        # The scalar-built arrays from _array_vals() plus, for every dtype,
        # arrays of ones with shapes (3, 4), (4, 2) and (4, 4).
        for a in _array_vals():
            yield a
        for d in _all_dtypes:
            yield ones((3, 4), dtype=d)
            yield ones((4, 2), dtype=d)
            yield ones((4, 4), dtype=d)

    # Scalars always error
    for _op in ["__matmul__", "__rmatmul__", "__imatmul__"]:
        for s in [1, 1.0, False]:
            for a in _matmul_array_vals():
                if (type(s) in [float, int] and a.dtype in _floating_dtypes
                    or type(s) == int and a.dtype in _integer_dtypes):
                    # Type promotion is valid, but @ is not allowed on 0-D
                    # inputs, so the error is a ValueError
                    assert_raises(ValueError, lambda: getattr(a, _op)(s))
                else:
                    assert_raises(TypeError, lambda: getattr(a, _op)(s))

    for x in _matmul_array_vals():
        for y in _matmul_array_vals():
            # Same promotion rules as for the binary operators above, except
            # that any boolean operand is rejected outright.
            if (x.dtype == uint64 and y.dtype in [int8, int16, int32, int64]
                or y.dtype == uint64 and x.dtype in [int8, int16, int32, int64]
                or x.dtype in _integer_dtypes and y.dtype not in _integer_dtypes
                or y.dtype in _integer_dtypes and x.dtype not in _integer_dtypes
                or x.dtype in _floating_dtypes and y.dtype not in _floating_dtypes
                or y.dtype in _floating_dtypes and x.dtype not in _floating_dtypes
                or x.dtype in _boolean_dtypes
                or y.dtype in _boolean_dtypes
                ):
                assert_raises(TypeError, lambda: x.__matmul__(y))
                assert_raises(TypeError, lambda: y.__rmatmul__(x))
                assert_raises(TypeError, lambda: x.__imatmul__(y))
            # Compatible dtypes but an empty shape on either side or
            # mismatched inner dimensions.
            elif x.shape == () or y.shape == () or x.shape[1] != y.shape[0]:
                assert_raises(ValueError, lambda: x.__matmul__(y))
                assert_raises(ValueError, lambda: y.__rmatmul__(x))
                # For the in-place form, a result dtype different from
                # x.dtype is expected to surface as TypeError rather than
                # the shape ValueError.
                if result_type(x.dtype, y.dtype) != x.dtype:
                    assert_raises(TypeError, lambda: x.__imatmul__(y))
                else:
                    assert_raises(ValueError, lambda: x.__imatmul__(y))
            else:
                x.__matmul__(y)
                y.__rmatmul__(x)
                if result_type(x.dtype, y.dtype) != x.dtype:
                    assert_raises(TypeError, lambda: x.__imatmul__(y))
                elif y.shape[0] != y.shape[1]:
                    # This one fails because x @ y has a different shape from x
                    assert_raises(ValueError, lambda: x.__imatmul__(y))
                else:
                    x.__imatmul__(y)
  239. def test_python_scalar_construtors():
  240. b = asarray(False)
  241. i = asarray(0)
  242. f = asarray(0.0)
  243. assert bool(b) == False
  244. assert int(i) == 0
  245. assert float(f) == 0.0
  246. assert operator.index(i) == 0
  247. # bool/int/float should only be allowed on 0-D arrays.
  248. assert_raises(TypeError, lambda: bool(asarray([False])))
  249. assert_raises(TypeError, lambda: int(asarray([0])))
  250. assert_raises(TypeError, lambda: float(asarray([0.0])))
  251. assert_raises(TypeError, lambda: operator.index(asarray([0])))
  252. # bool/int/float should only be allowed on arrays of the corresponding
  253. # dtype
  254. assert_raises(ValueError, lambda: bool(i))
  255. assert_raises(ValueError, lambda: bool(f))
  256. assert_raises(ValueError, lambda: int(b))
  257. assert_raises(ValueError, lambda: int(f))
  258. assert_raises(ValueError, lambda: float(b))
  259. assert_raises(ValueError, lambda: float(i))
  260. assert_raises(TypeError, lambda: operator.index(b))
  261. assert_raises(TypeError, lambda: operator.index(f))
  262. def test_device_property():
  263. a = ones((3, 4))
  264. assert a.device == 'cpu'
  265. assert all(equal(a.to_device('cpu'), a))
  266. assert_raises(ValueError, lambda: a.to_device('gpu'))
  267. assert all(equal(asarray(a, device='cpu'), a))
  268. assert_raises(ValueError, lambda: asarray(a, device='gpu'))
  269. def test_array_properties():
  270. a = ones((1, 2, 3))
  271. b = ones((2, 3))
  272. assert_raises(ValueError, lambda: a.T)
  273. assert isinstance(b.T, Array)
  274. assert b.T.shape == (3, 2)
  275. assert isinstance(a.mT, Array)
  276. assert a.mT.shape == (1, 3, 2)
  277. assert isinstance(b.mT, Array)
  278. assert b.mT.shape == (3, 2)
  279. def test___array__():
  280. a = ones((2, 3), dtype=int16)
  281. assert np.asarray(a) is a._array
  282. b = np.asarray(a, dtype=np.float64)
  283. assert np.all(np.equal(b, np.ones((2, 3), dtype=np.float64)))
  284. assert b.dtype == np.float64
  285. def test_allow_newaxis():
  286. a = ones(5)
  287. indexed_a = a[None, :]
  288. assert indexed_a.shape == (1, 5)
  289. def test_disallow_flat_indexing_with_newaxis():
  290. a = ones((3, 3, 3))
  291. with pytest.raises(IndexError):
  292. a[None, 0, 0]
  293. def test_disallow_mask_with_newaxis():
  294. a = ones((3, 3, 3))
  295. with pytest.raises(IndexError):
  296. a[None, asarray(True)]
  297. @pytest.mark.parametrize("shape", [(), (5,), (3, 3, 3)])
  298. @pytest.mark.parametrize("index", ["string", False, True])
  299. def test_error_on_invalid_index(shape, index):
  300. a = ones(shape)
  301. with pytest.raises(IndexError):
  302. a[index]
  303. def test_mask_0d_array_without_errors():
  304. a = ones(())
  305. a[asarray(True)]
  306. @pytest.mark.parametrize(
  307. "i", [slice(5), slice(5, 0), asarray(True), asarray([0, 1])]
  308. )
  309. def test_error_on_invalid_index_with_ellipsis(i):
  310. a = ones((3, 3, 3))
  311. with pytest.raises(IndexError):
  312. a[..., i]
  313. with pytest.raises(IndexError):
  314. a[i, ...]
  315. def test_array_keys_use_private_array():
  316. """
  317. Indexing operations convert array keys before indexing the internal array
  318. Fails when array_api array keys are not converted into NumPy-proper arrays
  319. in __getitem__(). This is achieved by passing array_api arrays with 0-sized
  320. dimensions, which NumPy-proper treats erroneously - not sure why!
  321. TODO: Find and use appropriate __setitem__() case.
  322. """
  323. a = ones((0, 0), dtype=bool_)
  324. assert a[a].shape == (0,)
  325. a = ones((0,), dtype=bool_)
  326. key = ones((0, 0), dtype=bool_)
  327. with pytest.raises(IndexError):
  328. a[key]