test_stride_tricks.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645
  1. import numpy as np
  2. from numpy.core._rational_tests import rational
  3. from numpy.testing import (
  4. assert_equal, assert_array_equal, assert_raises, assert_,
  5. assert_raises_regex, assert_warns,
  6. )
  7. from numpy.lib.stride_tricks import (
  8. as_strided, broadcast_arrays, _broadcast_shape, broadcast_to,
  9. broadcast_shapes, sliding_window_view,
  10. )
  11. import pytest
  12. def assert_shapes_correct(input_shapes, expected_shape):
  13. # Broadcast a list of arrays with the given input shapes and check the
  14. # common output shape.
  15. inarrays = [np.zeros(s) for s in input_shapes]
  16. outarrays = broadcast_arrays(*inarrays)
  17. outshapes = [a.shape for a in outarrays]
  18. expected = [expected_shape] * len(inarrays)
  19. assert_equal(outshapes, expected)
  20. def assert_incompatible_shapes_raise(input_shapes):
  21. # Broadcast a list of arrays with the given (incompatible) input shapes
  22. # and check that they raise a ValueError.
  23. inarrays = [np.zeros(s) for s in input_shapes]
  24. assert_raises(ValueError, broadcast_arrays, *inarrays)
  25. def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
  26. # Broadcast two shapes against each other and check that the data layout
  27. # is the same as if a ufunc did the broadcasting.
  28. x0 = np.zeros(shape0, dtype=int)
  29. # Note that multiply.reduce's identity element is 1.0, so when shape1==(),
  30. # this gives the desired n==1.
  31. n = int(np.multiply.reduce(shape1))
  32. x1 = np.arange(n).reshape(shape1)
  33. if transposed:
  34. x0 = x0.T
  35. x1 = x1.T
  36. if flipped:
  37. x0 = x0[::-1]
  38. x1 = x1[::-1]
  39. # Use the add ufunc to do the broadcasting. Since we're adding 0s to x1, the
  40. # result should be exactly the same as the broadcasted view of x1.
  41. y = x0 + x1
  42. b0, b1 = broadcast_arrays(x0, x1)
  43. assert_array_equal(y, b1)
  44. def test_same():
  45. x = np.arange(10)
  46. y = np.arange(10)
  47. bx, by = broadcast_arrays(x, y)
  48. assert_array_equal(x, bx)
  49. assert_array_equal(y, by)
  50. def test_broadcast_kwargs():
  51. # ensure that a TypeError is appropriately raised when
  52. # np.broadcast_arrays() is called with any keyword
  53. # argument other than 'subok'
  54. x = np.arange(10)
  55. y = np.arange(10)
  56. with assert_raises_regex(TypeError, 'got an unexpected keyword'):
  57. broadcast_arrays(x, y, dtype='float64')
  58. def test_one_off():
  59. x = np.array([[1, 2, 3]])
  60. y = np.array([[1], [2], [3]])
  61. bx, by = broadcast_arrays(x, y)
  62. bx0 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
  63. by0 = bx0.T
  64. assert_array_equal(bx0, bx)
  65. assert_array_equal(by0, by)
  66. def test_same_input_shapes():
  67. # Check that the final shape is just the input shape.
  68. data = [
  69. (),
  70. (1,),
  71. (3,),
  72. (0, 1),
  73. (0, 3),
  74. (1, 0),
  75. (3, 0),
  76. (1, 3),
  77. (3, 1),
  78. (3, 3),
  79. ]
  80. for shape in data:
  81. input_shapes = [shape]
  82. # Single input.
  83. assert_shapes_correct(input_shapes, shape)
  84. # Double input.
  85. input_shapes2 = [shape, shape]
  86. assert_shapes_correct(input_shapes2, shape)
  87. # Triple input.
  88. input_shapes3 = [shape, shape, shape]
  89. assert_shapes_correct(input_shapes3, shape)
  90. def test_two_compatible_by_ones_input_shapes():
  91. # Check that two different input shapes of the same length, but some have
  92. # ones, broadcast to the correct shape.
  93. data = [
  94. [[(1,), (3,)], (3,)],
  95. [[(1, 3), (3, 3)], (3, 3)],
  96. [[(3, 1), (3, 3)], (3, 3)],
  97. [[(1, 3), (3, 1)], (3, 3)],
  98. [[(1, 1), (3, 3)], (3, 3)],
  99. [[(1, 1), (1, 3)], (1, 3)],
  100. [[(1, 1), (3, 1)], (3, 1)],
  101. [[(1, 0), (0, 0)], (0, 0)],
  102. [[(0, 1), (0, 0)], (0, 0)],
  103. [[(1, 0), (0, 1)], (0, 0)],
  104. [[(1, 1), (0, 0)], (0, 0)],
  105. [[(1, 1), (1, 0)], (1, 0)],
  106. [[(1, 1), (0, 1)], (0, 1)],
  107. ]
  108. for input_shapes, expected_shape in data:
  109. assert_shapes_correct(input_shapes, expected_shape)
  110. # Reverse the input shapes since broadcasting should be symmetric.
  111. assert_shapes_correct(input_shapes[::-1], expected_shape)
  112. def test_two_compatible_by_prepending_ones_input_shapes():
  113. # Check that two different input shapes (of different lengths) broadcast
  114. # to the correct shape.
  115. data = [
  116. [[(), (3,)], (3,)],
  117. [[(3,), (3, 3)], (3, 3)],
  118. [[(3,), (3, 1)], (3, 3)],
  119. [[(1,), (3, 3)], (3, 3)],
  120. [[(), (3, 3)], (3, 3)],
  121. [[(1, 1), (3,)], (1, 3)],
  122. [[(1,), (3, 1)], (3, 1)],
  123. [[(1,), (1, 3)], (1, 3)],
  124. [[(), (1, 3)], (1, 3)],
  125. [[(), (3, 1)], (3, 1)],
  126. [[(), (0,)], (0,)],
  127. [[(0,), (0, 0)], (0, 0)],
  128. [[(0,), (0, 1)], (0, 0)],
  129. [[(1,), (0, 0)], (0, 0)],
  130. [[(), (0, 0)], (0, 0)],
  131. [[(1, 1), (0,)], (1, 0)],
  132. [[(1,), (0, 1)], (0, 1)],
  133. [[(1,), (1, 0)], (1, 0)],
  134. [[(), (1, 0)], (1, 0)],
  135. [[(), (0, 1)], (0, 1)],
  136. ]
  137. for input_shapes, expected_shape in data:
  138. assert_shapes_correct(input_shapes, expected_shape)
  139. # Reverse the input shapes since broadcasting should be symmetric.
  140. assert_shapes_correct(input_shapes[::-1], expected_shape)
  141. def test_incompatible_shapes_raise_valueerror():
  142. # Check that a ValueError is raised for incompatible shapes.
  143. data = [
  144. [(3,), (4,)],
  145. [(2, 3), (2,)],
  146. [(3,), (3,), (4,)],
  147. [(1, 3, 4), (2, 3, 3)],
  148. ]
  149. for input_shapes in data:
  150. assert_incompatible_shapes_raise(input_shapes)
  151. # Reverse the input shapes since broadcasting should be symmetric.
  152. assert_incompatible_shapes_raise(input_shapes[::-1])
  153. def test_same_as_ufunc():
  154. # Check that the data layout is the same as if a ufunc did the operation.
  155. data = [
  156. [[(1,), (3,)], (3,)],
  157. [[(1, 3), (3, 3)], (3, 3)],
  158. [[(3, 1), (3, 3)], (3, 3)],
  159. [[(1, 3), (3, 1)], (3, 3)],
  160. [[(1, 1), (3, 3)], (3, 3)],
  161. [[(1, 1), (1, 3)], (1, 3)],
  162. [[(1, 1), (3, 1)], (3, 1)],
  163. [[(1, 0), (0, 0)], (0, 0)],
  164. [[(0, 1), (0, 0)], (0, 0)],
  165. [[(1, 0), (0, 1)], (0, 0)],
  166. [[(1, 1), (0, 0)], (0, 0)],
  167. [[(1, 1), (1, 0)], (1, 0)],
  168. [[(1, 1), (0, 1)], (0, 1)],
  169. [[(), (3,)], (3,)],
  170. [[(3,), (3, 3)], (3, 3)],
  171. [[(3,), (3, 1)], (3, 3)],
  172. [[(1,), (3, 3)], (3, 3)],
  173. [[(), (3, 3)], (3, 3)],
  174. [[(1, 1), (3,)], (1, 3)],
  175. [[(1,), (3, 1)], (3, 1)],
  176. [[(1,), (1, 3)], (1, 3)],
  177. [[(), (1, 3)], (1, 3)],
  178. [[(), (3, 1)], (3, 1)],
  179. [[(), (0,)], (0,)],
  180. [[(0,), (0, 0)], (0, 0)],
  181. [[(0,), (0, 1)], (0, 0)],
  182. [[(1,), (0, 0)], (0, 0)],
  183. [[(), (0, 0)], (0, 0)],
  184. [[(1, 1), (0,)], (1, 0)],
  185. [[(1,), (0, 1)], (0, 1)],
  186. [[(1,), (1, 0)], (1, 0)],
  187. [[(), (1, 0)], (1, 0)],
  188. [[(), (0, 1)], (0, 1)],
  189. ]
  190. for input_shapes, expected_shape in data:
  191. assert_same_as_ufunc(input_shapes[0], input_shapes[1],
  192. "Shapes: %s %s" % (input_shapes[0], input_shapes[1]))
  193. # Reverse the input shapes since broadcasting should be symmetric.
  194. assert_same_as_ufunc(input_shapes[1], input_shapes[0])
  195. # Try them transposed, too.
  196. assert_same_as_ufunc(input_shapes[0], input_shapes[1], True)
  197. # ... and flipped for non-rank-0 inputs in order to test negative
  198. # strides.
  199. if () not in input_shapes:
  200. assert_same_as_ufunc(input_shapes[0], input_shapes[1], False, True)
  201. assert_same_as_ufunc(input_shapes[0], input_shapes[1], True, True)
  202. def test_broadcast_to_succeeds():
  203. data = [
  204. [np.array(0), (0,), np.array(0)],
  205. [np.array(0), (1,), np.zeros(1)],
  206. [np.array(0), (3,), np.zeros(3)],
  207. [np.ones(1), (1,), np.ones(1)],
  208. [np.ones(1), (2,), np.ones(2)],
  209. [np.ones(1), (1, 2, 3), np.ones((1, 2, 3))],
  210. [np.arange(3), (3,), np.arange(3)],
  211. [np.arange(3), (1, 3), np.arange(3).reshape(1, -1)],
  212. [np.arange(3), (2, 3), np.array([[0, 1, 2], [0, 1, 2]])],
  213. # test if shape is not a tuple
  214. [np.ones(0), 0, np.ones(0)],
  215. [np.ones(1), 1, np.ones(1)],
  216. [np.ones(1), 2, np.ones(2)],
  217. # these cases with size 0 are strange, but they reproduce the behavior
  218. # of broadcasting with ufuncs (see test_same_as_ufunc above)
  219. [np.ones(1), (0,), np.ones(0)],
  220. [np.ones((1, 2)), (0, 2), np.ones((0, 2))],
  221. [np.ones((2, 1)), (2, 0), np.ones((2, 0))],
  222. ]
  223. for input_array, shape, expected in data:
  224. actual = broadcast_to(input_array, shape)
  225. assert_array_equal(expected, actual)
  226. def test_broadcast_to_raises():
  227. data = [
  228. [(0,), ()],
  229. [(1,), ()],
  230. [(3,), ()],
  231. [(3,), (1,)],
  232. [(3,), (2,)],
  233. [(3,), (4,)],
  234. [(1, 2), (2, 1)],
  235. [(1, 1), (1,)],
  236. [(1,), -1],
  237. [(1,), (-1,)],
  238. [(1, 2), (-1, 2)],
  239. ]
  240. for orig_shape, target_shape in data:
  241. arr = np.zeros(orig_shape)
  242. assert_raises(ValueError, lambda: broadcast_to(arr, target_shape))
  243. def test_broadcast_shape():
  244. # tests internal _broadcast_shape
  245. # _broadcast_shape is already exercised indirectly by broadcast_arrays
  246. # _broadcast_shape is also exercised by the public broadcast_shapes function
  247. assert_equal(_broadcast_shape(), ())
  248. assert_equal(_broadcast_shape([1, 2]), (2,))
  249. assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))
  250. assert_equal(_broadcast_shape(np.ones((1, 1)), np.ones((3, 4))), (3, 4))
  251. assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 32)), (1, 2))
  252. assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 100)), (1, 2))
  253. # regression tests for gh-5862
  254. assert_equal(_broadcast_shape(*([np.ones(2)] * 32 + [1])), (2,))
  255. bad_args = [np.ones(2)] * 32 + [np.ones(3)] * 32
  256. assert_raises(ValueError, lambda: _broadcast_shape(*bad_args))
  257. def test_broadcast_shapes_succeeds():
  258. # tests public broadcast_shapes
  259. data = [
  260. [[], ()],
  261. [[()], ()],
  262. [[(7,)], (7,)],
  263. [[(1, 2), (2,)], (1, 2)],
  264. [[(1, 1)], (1, 1)],
  265. [[(1, 1), (3, 4)], (3, 4)],
  266. [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],
  267. [[(5, 6, 1)], (5, 6, 1)],
  268. [[(1, 3), (3, 1)], (3, 3)],
  269. [[(1, 0), (0, 0)], (0, 0)],
  270. [[(0, 1), (0, 0)], (0, 0)],
  271. [[(1, 0), (0, 1)], (0, 0)],
  272. [[(1, 1), (0, 0)], (0, 0)],
  273. [[(1, 1), (1, 0)], (1, 0)],
  274. [[(1, 1), (0, 1)], (0, 1)],
  275. [[(), (0,)], (0,)],
  276. [[(0,), (0, 0)], (0, 0)],
  277. [[(0,), (0, 1)], (0, 0)],
  278. [[(1,), (0, 0)], (0, 0)],
  279. [[(), (0, 0)], (0, 0)],
  280. [[(1, 1), (0,)], (1, 0)],
  281. [[(1,), (0, 1)], (0, 1)],
  282. [[(1,), (1, 0)], (1, 0)],
  283. [[(), (1, 0)], (1, 0)],
  284. [[(), (0, 1)], (0, 1)],
  285. [[(1,), (3,)], (3,)],
  286. [[2, (3, 2)], (3, 2)],
  287. ]
  288. for input_shapes, target_shape in data:
  289. assert_equal(broadcast_shapes(*input_shapes), target_shape)
  290. assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2))
  291. assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2))
  292. # regression tests for gh-5862
  293. assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,))
  294. def test_broadcast_shapes_raises():
  295. # tests public broadcast_shapes
  296. data = [
  297. [(3,), (4,)],
  298. [(2, 3), (2,)],
  299. [(3,), (3,), (4,)],
  300. [(1, 3, 4), (2, 3, 3)],
  301. [(1, 2), (3,1), (3,2), (10, 5)],
  302. [2, (2, 3)],
  303. ]
  304. for input_shapes in data:
  305. assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes))
  306. bad_args = [(2,)] * 32 + [(3,)] * 32
  307. assert_raises(ValueError, lambda: broadcast_shapes(*bad_args))
  308. def test_as_strided():
  309. a = np.array([None])
  310. a_view = as_strided(a)
  311. expected = np.array([None])
  312. assert_array_equal(a_view, np.array([None]))
  313. a = np.array([1, 2, 3, 4])
  314. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
  315. expected = np.array([1, 3])
  316. assert_array_equal(a_view, expected)
  317. a = np.array([1, 2, 3, 4])
  318. a_view = as_strided(a, shape=(3, 4), strides=(0, 1 * a.itemsize))
  319. expected = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])
  320. assert_array_equal(a_view, expected)
  321. # Regression test for gh-5081
  322. dt = np.dtype([('num', 'i4'), ('obj', 'O')])
  323. a = np.empty((4,), dtype=dt)
  324. a['num'] = np.arange(1, 5)
  325. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  326. expected_num = [[1, 2, 3, 4]] * 3
  327. expected_obj = [[None]*4]*3
  328. assert_equal(a_view.dtype, dt)
  329. assert_array_equal(expected_num, a_view['num'])
  330. assert_array_equal(expected_obj, a_view['obj'])
  331. # Make sure that void types without fields are kept unchanged
  332. a = np.empty((4,), dtype='V4')
  333. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  334. assert_equal(a.dtype, a_view.dtype)
  335. # Make sure that the only type that could fail is properly handled
  336. dt = np.dtype({'names': [''], 'formats': ['V4']})
  337. a = np.empty((4,), dtype=dt)
  338. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  339. assert_equal(a.dtype, a_view.dtype)
  340. # Custom dtypes should not be lost (gh-9161)
  341. r = [rational(i) for i in range(4)]
  342. a = np.array(r, dtype=rational)
  343. a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
  344. assert_equal(a.dtype, a_view.dtype)
  345. assert_array_equal([r] * 3, a_view)
  346. class TestSlidingWindowView:
  347. def test_1d(self):
  348. arr = np.arange(5)
  349. arr_view = sliding_window_view(arr, 2)
  350. expected = np.array([[0, 1],
  351. [1, 2],
  352. [2, 3],
  353. [3, 4]])
  354. assert_array_equal(arr_view, expected)
  355. def test_2d(self):
  356. i, j = np.ogrid[:3, :4]
  357. arr = 10*i + j
  358. shape = (2, 2)
  359. arr_view = sliding_window_view(arr, shape)
  360. expected = np.array([[[[0, 1], [10, 11]],
  361. [[1, 2], [11, 12]],
  362. [[2, 3], [12, 13]]],
  363. [[[10, 11], [20, 21]],
  364. [[11, 12], [21, 22]],
  365. [[12, 13], [22, 23]]]])
  366. assert_array_equal(arr_view, expected)
  367. def test_2d_with_axis(self):
  368. i, j = np.ogrid[:3, :4]
  369. arr = 10*i + j
  370. arr_view = sliding_window_view(arr, 3, 0)
  371. expected = np.array([[[0, 10, 20],
  372. [1, 11, 21],
  373. [2, 12, 22],
  374. [3, 13, 23]]])
  375. assert_array_equal(arr_view, expected)
  376. def test_2d_repeated_axis(self):
  377. i, j = np.ogrid[:3, :4]
  378. arr = 10*i + j
  379. arr_view = sliding_window_view(arr, (2, 3), (1, 1))
  380. expected = np.array([[[[0, 1, 2],
  381. [1, 2, 3]]],
  382. [[[10, 11, 12],
  383. [11, 12, 13]]],
  384. [[[20, 21, 22],
  385. [21, 22, 23]]]])
  386. assert_array_equal(arr_view, expected)
  387. def test_2d_without_axis(self):
  388. i, j = np.ogrid[:4, :4]
  389. arr = 10*i + j
  390. shape = (2, 3)
  391. arr_view = sliding_window_view(arr, shape)
  392. expected = np.array([[[[0, 1, 2], [10, 11, 12]],
  393. [[1, 2, 3], [11, 12, 13]]],
  394. [[[10, 11, 12], [20, 21, 22]],
  395. [[11, 12, 13], [21, 22, 23]]],
  396. [[[20, 21, 22], [30, 31, 32]],
  397. [[21, 22, 23], [31, 32, 33]]]])
  398. assert_array_equal(arr_view, expected)
  399. def test_errors(self):
  400. i, j = np.ogrid[:4, :4]
  401. arr = 10*i + j
  402. with pytest.raises(ValueError, match='cannot contain negative values'):
  403. sliding_window_view(arr, (-1, 3))
  404. with pytest.raises(
  405. ValueError,
  406. match='must provide window_shape for all dimensions of `x`'):
  407. sliding_window_view(arr, (1,))
  408. with pytest.raises(
  409. ValueError,
  410. match='Must provide matching length window_shape and axis'):
  411. sliding_window_view(arr, (1, 3, 4), axis=(0, 1))
  412. with pytest.raises(
  413. ValueError,
  414. match='window shape cannot be larger than input array'):
  415. sliding_window_view(arr, (5, 5))
  416. def test_writeable(self):
  417. arr = np.arange(5)
  418. view = sliding_window_view(arr, 2, writeable=False)
  419. assert_(not view.flags.writeable)
  420. with pytest.raises(
  421. ValueError,
  422. match='assignment destination is read-only'):
  423. view[0, 0] = 3
  424. view = sliding_window_view(arr, 2, writeable=True)
  425. assert_(view.flags.writeable)
  426. view[0, 1] = 3
  427. assert_array_equal(arr, np.array([0, 3, 2, 3, 4]))
  428. def test_subok(self):
  429. class MyArray(np.ndarray):
  430. pass
  431. arr = np.arange(5).view(MyArray)
  432. assert_(not isinstance(sliding_window_view(arr, 2,
  433. subok=False),
  434. MyArray))
  435. assert_(isinstance(sliding_window_view(arr, 2, subok=True), MyArray))
  436. # Default behavior
  437. assert_(not isinstance(sliding_window_view(arr, 2), MyArray))
  438. def as_strided_writeable():
  439. arr = np.ones(10)
  440. view = as_strided(arr, writeable=False)
  441. assert_(not view.flags.writeable)
  442. # Check that writeable also is fine:
  443. view = as_strided(arr, writeable=True)
  444. assert_(view.flags.writeable)
  445. view[...] = 3
  446. assert_array_equal(arr, np.full_like(arr, 3))
  447. # Test that things do not break down for readonly:
  448. arr.flags.writeable = False
  449. view = as_strided(arr, writeable=False)
  450. view = as_strided(arr, writeable=True)
  451. assert_(not view.flags.writeable)
  452. class VerySimpleSubClass(np.ndarray):
  453. def __new__(cls, *args, **kwargs):
  454. return np.array(*args, subok=True, **kwargs).view(cls)
  455. class SimpleSubClass(VerySimpleSubClass):
  456. def __new__(cls, *args, **kwargs):
  457. self = np.array(*args, subok=True, **kwargs).view(cls)
  458. self.info = 'simple'
  459. return self
  460. def __array_finalize__(self, obj):
  461. self.info = getattr(obj, 'info', '') + ' finalized'
  462. def test_subclasses():
  463. # test that subclass is preserved only if subok=True
  464. a = VerySimpleSubClass([1, 2, 3, 4])
  465. assert_(type(a) is VerySimpleSubClass)
  466. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
  467. assert_(type(a_view) is np.ndarray)
  468. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
  469. assert_(type(a_view) is VerySimpleSubClass)
  470. # test that if a subclass has __array_finalize__, it is used
  471. a = SimpleSubClass([1, 2, 3, 4])
  472. a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
  473. assert_(type(a_view) is SimpleSubClass)
  474. assert_(a_view.info == 'simple finalized')
  475. # similar tests for broadcast_arrays
  476. b = np.arange(len(a)).reshape(-1, 1)
  477. a_view, b_view = broadcast_arrays(a, b)
  478. assert_(type(a_view) is np.ndarray)
  479. assert_(type(b_view) is np.ndarray)
  480. assert_(a_view.shape == b_view.shape)
  481. a_view, b_view = broadcast_arrays(a, b, subok=True)
  482. assert_(type(a_view) is SimpleSubClass)
  483. assert_(a_view.info == 'simple finalized')
  484. assert_(type(b_view) is np.ndarray)
  485. assert_(a_view.shape == b_view.shape)
  486. # and for broadcast_to
  487. shape = (2, 4)
  488. a_view = broadcast_to(a, shape)
  489. assert_(type(a_view) is np.ndarray)
  490. assert_(a_view.shape == shape)
  491. a_view = broadcast_to(a, shape, subok=True)
  492. assert_(type(a_view) is SimpleSubClass)
  493. assert_(a_view.info == 'simple finalized')
  494. assert_(a_view.shape == shape)
  495. def test_writeable():
  496. # broadcast_to should return a readonly array
  497. original = np.array([1, 2, 3])
  498. result = broadcast_to(original, (2, 3))
  499. assert_equal(result.flags.writeable, False)
  500. assert_raises(ValueError, result.__setitem__, slice(None), 0)
  501. # but the result of broadcast_arrays needs to be writeable, to
  502. # preserve backwards compatibility
  503. for is_broadcast, results in [(False, broadcast_arrays(original,)),
  504. (True, broadcast_arrays(0, original))]:
  505. for result in results:
  506. # This will change to False in a future version
  507. if is_broadcast:
  508. with assert_warns(FutureWarning):
  509. assert_equal(result.flags.writeable, True)
  510. with assert_warns(DeprecationWarning):
  511. result[:] = 0
  512. # Warning not emitted, writing to the array resets it
  513. assert_equal(result.flags.writeable, True)
  514. else:
  515. # No warning:
  516. assert_equal(result.flags.writeable, True)
  517. for results in [broadcast_arrays(original),
  518. broadcast_arrays(0, original)]:
  519. for result in results:
  520. # resets the warn_on_write DeprecationWarning
  521. result.flags.writeable = True
  522. # check: no warning emitted
  523. assert_equal(result.flags.writeable, True)
  524. result[:] = 0
  525. # keep readonly input readonly
  526. original.flags.writeable = False
  527. _, result = broadcast_arrays(0, original)
  528. assert_equal(result.flags.writeable, False)
  529. # regression test for GH6491
  530. shape = (2,)
  531. strides = [0]
  532. tricky_array = as_strided(np.array(0), shape, strides)
  533. other = np.zeros((1,))
  534. first, second = broadcast_arrays(tricky_array, other)
  535. assert_(first.shape == second.shape)
  536. def test_writeable_memoryview():
  537. # The result of broadcast_arrays exports as a non-writeable memoryview
  538. # because otherwise there is no good way to opt in to the new behaviour
  539. # (i.e. you would need to set writeable to False explicitly).
  540. # See gh-13929.
  541. original = np.array([1, 2, 3])
  542. for is_broadcast, results in [(False, broadcast_arrays(original,)),
  543. (True, broadcast_arrays(0, original))]:
  544. for result in results:
  545. # This will change to False in a future version
  546. if is_broadcast:
  547. # memoryview(result, writable=True) will give warning but cannot
  548. # be tested using the python API.
  549. assert memoryview(result).readonly
  550. else:
  551. assert not memoryview(result).readonly
  552. def test_reference_types():
  553. input_array = np.array('a', dtype=object)
  554. expected = np.array(['a'] * 3, dtype=object)
  555. actual = broadcast_to(input_array, (3,))
  556. assert_array_equal(expected, actual)
  557. actual, _ = broadcast_arrays(input_array, np.ones(3))
  558. assert_array_equal(expected, actual)