# test_arrayterator.py (1.3 KB)

  1. from operator import mul
  2. from functools import reduce
  3. import numpy as np
  4. from numpy.random import randint
  5. from numpy.lib import Arrayterator
  6. from numpy.testing import assert_
  7. def test():
  8. np.random.seed(np.arange(10))
  9. # Create a random array
  10. ndims = randint(5)+1
  11. shape = tuple(randint(10)+1 for dim in range(ndims))
  12. els = reduce(mul, shape)
  13. a = np.arange(els)
  14. a.shape = shape
  15. buf_size = randint(2*els)
  16. b = Arrayterator(a, buf_size)
  17. # Check that each block has at most ``buf_size`` elements
  18. for block in b:
  19. assert_(len(block.flat) <= (buf_size or els))
  20. # Check that all elements are iterated correctly
  21. assert_(list(b.flat) == list(a.flat))
  22. # Slice arrayterator
  23. start = [randint(dim) for dim in shape]
  24. stop = [randint(dim)+1 for dim in shape]
  25. step = [randint(dim)+1 for dim in shape]
  26. slice_ = tuple(slice(*t) for t in zip(start, stop, step))
  27. c = b[slice_]
  28. d = a[slice_]
  29. # Check that each block has at most ``buf_size`` elements
  30. for block in c:
  31. assert_(len(block.flat) <= (buf_size or els))
  32. # Check that the arrayterator is sliced correctly
  33. assert_(np.all(c.__array__() == d))
  34. # Check that all elements are iterated correctly
  35. assert_(list(c.flat) == list(d.flat))