test_csr.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import numpy as np
  2. from numpy.testing import assert_array_almost_equal, assert_
  3. from scipy.sparse import csr_matrix, hstack
  4. import pytest
  5. def _check_csr_rowslice(i, sl, X, Xcsr):
  6. np_slice = X[i, sl]
  7. csr_slice = Xcsr[i, sl]
  8. assert_array_almost_equal(np_slice, csr_slice.toarray()[0])
  9. assert_(type(csr_slice) is csr_matrix)
  10. def test_csr_rowslice():
  11. N = 10
  12. np.random.seed(0)
  13. X = np.random.random((N, N))
  14. X[X > 0.7] = 0
  15. Xcsr = csr_matrix(X)
  16. slices = [slice(None, None, None),
  17. slice(None, None, -1),
  18. slice(1, -2, 2),
  19. slice(-2, 1, -2)]
  20. for i in range(N):
  21. for sl in slices:
  22. _check_csr_rowslice(i, sl, X, Xcsr)
  23. def test_csr_getrow():
  24. N = 10
  25. np.random.seed(0)
  26. X = np.random.random((N, N))
  27. X[X > 0.7] = 0
  28. Xcsr = csr_matrix(X)
  29. for i in range(N):
  30. arr_row = X[i:i + 1, :]
  31. csr_row = Xcsr.getrow(i)
  32. assert_array_almost_equal(arr_row, csr_row.toarray())
  33. assert_(type(csr_row) is csr_matrix)
  34. def test_csr_getcol():
  35. N = 10
  36. np.random.seed(0)
  37. X = np.random.random((N, N))
  38. X[X > 0.7] = 0
  39. Xcsr = csr_matrix(X)
  40. for i in range(N):
  41. arr_col = X[:, i:i + 1]
  42. csr_col = Xcsr.getcol(i)
  43. assert_array_almost_equal(arr_col, csr_col.toarray())
  44. assert_(type(csr_col) is csr_matrix)
  45. @pytest.mark.parametrize("matrix_input, axis, expected_shape",
  46. [(csr_matrix([[1, 0, 0, 0],
  47. [0, 0, 0, 0],
  48. [0, 2, 3, 0]]),
  49. 0, (0, 4)),
  50. (csr_matrix([[1, 0, 0, 0],
  51. [0, 0, 0, 0],
  52. [0, 2, 3, 0]]),
  53. 1, (3, 0)),
  54. (csr_matrix([[1, 0, 0, 0],
  55. [0, 0, 0, 0],
  56. [0, 2, 3, 0]]),
  57. 'both', (0, 0)),
  58. (csr_matrix([[0, 1, 0, 0, 0],
  59. [0, 0, 0, 0, 0],
  60. [0, 0, 2, 3, 0]]),
  61. 0, (0, 5))])
  62. def test_csr_empty_slices(matrix_input, axis, expected_shape):
  63. # see gh-11127 for related discussion
  64. slice_1 = matrix_input.A.shape[0] - 1
  65. slice_2 = slice_1
  66. slice_3 = slice_2 - 1
  67. if axis == 0:
  68. actual_shape_1 = matrix_input[slice_1:slice_2, :].A.shape
  69. actual_shape_2 = matrix_input[slice_1:slice_3, :].A.shape
  70. elif axis == 1:
  71. actual_shape_1 = matrix_input[:, slice_1:slice_2].A.shape
  72. actual_shape_2 = matrix_input[:, slice_1:slice_3].A.shape
  73. elif axis == 'both':
  74. actual_shape_1 = matrix_input[slice_1:slice_2, slice_1:slice_2].A.shape
  75. actual_shape_2 = matrix_input[slice_1:slice_3, slice_1:slice_3].A.shape
  76. assert actual_shape_1 == expected_shape
  77. assert actual_shape_1 == actual_shape_2
  78. def test_csr_bool_indexing():
  79. data = csr_matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
  80. list_indices1 = [False, True, False]
  81. array_indices1 = np.array(list_indices1)
  82. list_indices2 = [[False, True, False], [False, True, False], [False, True, False]]
  83. array_indices2 = np.array(list_indices2)
  84. list_indices3 = ([False, True, False], [False, True, False])
  85. array_indices3 = (np.array(list_indices3[0]), np.array(list_indices3[1]))
  86. slice_list1 = data[list_indices1].toarray()
  87. slice_array1 = data[array_indices1].toarray()
  88. slice_list2 = data[list_indices2]
  89. slice_array2 = data[array_indices2]
  90. slice_list3 = data[list_indices3]
  91. slice_array3 = data[array_indices3]
  92. assert (slice_list1 == slice_array1).all()
  93. assert (slice_list2 == slice_array2).all()
  94. assert (slice_list3 == slice_array3).all()
  95. def test_csr_hstack_int64():
  96. """
  97. Tests if hstack properly promotes to indices and indptr arrays to np.int64
  98. when using np.int32 during concatenation would result in either array
  99. overflowing.
  100. """
  101. max_int32 = np.iinfo(np.int32).max
  102. # First case: indices would overflow with int32
  103. data = [1.0]
  104. row = [0]
  105. max_indices_1 = max_int32 - 1
  106. max_indices_2 = 3
  107. # Individual indices arrays are representable with int32
  108. col_1 = [max_indices_1 - 1]
  109. col_2 = [max_indices_2 - 1]
  110. X_1 = csr_matrix((data, (row, col_1)))
  111. X_2 = csr_matrix((data, (row, col_2)))
  112. assert max(max_indices_1 - 1, max_indices_2 - 1) < max_int32
  113. assert X_1.indices.dtype == X_1.indptr.dtype == np.int32
  114. assert X_2.indices.dtype == X_2.indptr.dtype == np.int32
  115. # ... but when concatenating their CSR matrices, the resulting indices
  116. # array can't be represented with int32 and must be promoted to int64.
  117. X_hs = hstack([X_1, X_2], format="csr")
  118. assert X_hs.indices.max() == max_indices_1 + max_indices_2 - 1
  119. assert max_indices_1 + max_indices_2 - 1 > max_int32
  120. assert X_hs.indices.dtype == X_hs.indptr.dtype == np.int64
  121. # Even if the matrices are empty, we must account for their size
  122. # contribution so that we may safely set the final elements.
  123. X_1_empty = csr_matrix(X_1.shape)
  124. X_2_empty = csr_matrix(X_2.shape)
  125. X_hs_empty = hstack([X_1_empty, X_2_empty], format="csr")
  126. assert X_hs_empty.shape == X_hs.shape
  127. assert X_hs_empty.indices.dtype == np.int64
  128. # Should be just small enough to stay in int32 after stack. Note that
  129. # we theoretically could support indices.max() == max_int32, but due to an
  130. # edge-case in the underlying sparsetools code
  131. # (namely the `coo_tocsr` routine),
  132. # we require that max(X_hs_32.shape) < max_int32 as well.
  133. # Hence we can only support max_int32 - 1.
  134. col_3 = [max_int32 - max_indices_1 - 1]
  135. X_3 = csr_matrix((data, (row, col_3)))
  136. X_hs_32 = hstack([X_1, X_3], format="csr")
  137. assert X_hs_32.indices.dtype == np.int32
  138. assert X_hs_32.indices.max() == max_int32 - 1