_index.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. """Indexing mixin for sparse matrix classes.
  2. """
  3. import numpy as np
  4. from ._sputils import isintlike
  5. try:
  6. INT_TYPES = (int, long, np.integer)
  7. except NameError:
  8. # long is not defined in Python3
  9. INT_TYPES = (int, np.integer)
  10. def _broadcast_arrays(a, b):
  11. """
  12. Same as np.broadcast_arrays(a, b) but old writeability rules.
  13. NumPy >= 1.17.0 transitions broadcast_arrays to return
  14. read-only arrays. Set writeability explicitly to avoid warnings.
  15. Retain the old writeability rules, as our Cython code assumes
  16. the old behavior.
  17. """
  18. x, y = np.broadcast_arrays(a, b)
  19. x.flags.writeable = a.flags.writeable
  20. y.flags.writeable = b.flags.writeable
  21. return x, y
  22. class IndexMixin:
  23. """
  24. This class provides common dispatching and validation logic for indexing.
  25. """
  26. def _raise_on_1d_array_slice(self):
  27. """We do not currently support 1D sparse arrays.
  28. This function is called each time that a 1D array would
  29. result, raising an error instead.
  30. Once 1D sparse arrays are implemented, it should be removed.
  31. """
  32. if self._is_array:
  33. raise NotImplementedError(
  34. 'We have not yet implemented 1D sparse slices; '
  35. 'please index using explicit indices, e.g. `x[:, [0]]`'
  36. )
  37. def __getitem__(self, key):
  38. row, col = self._validate_indices(key)
  39. # Dispatch to specialized methods.
  40. if isinstance(row, INT_TYPES):
  41. if isinstance(col, INT_TYPES):
  42. return self._get_intXint(row, col)
  43. elif isinstance(col, slice):
  44. self._raise_on_1d_array_slice()
  45. return self._get_intXslice(row, col)
  46. elif col.ndim == 1:
  47. self._raise_on_1d_array_slice()
  48. return self._get_intXarray(row, col)
  49. elif col.ndim == 2:
  50. return self._get_intXarray(row, col)
  51. raise IndexError('index results in >2 dimensions')
  52. elif isinstance(row, slice):
  53. if isinstance(col, INT_TYPES):
  54. self._raise_on_1d_array_slice()
  55. return self._get_sliceXint(row, col)
  56. elif isinstance(col, slice):
  57. if row == slice(None) and row == col:
  58. return self.copy()
  59. return self._get_sliceXslice(row, col)
  60. elif col.ndim == 1:
  61. return self._get_sliceXarray(row, col)
  62. raise IndexError('index results in >2 dimensions')
  63. elif row.ndim == 1:
  64. if isinstance(col, INT_TYPES):
  65. self._raise_on_1d_array_slice()
  66. return self._get_arrayXint(row, col)
  67. elif isinstance(col, slice):
  68. return self._get_arrayXslice(row, col)
  69. else: # row.ndim == 2
  70. if isinstance(col, INT_TYPES):
  71. return self._get_arrayXint(row, col)
  72. elif isinstance(col, slice):
  73. raise IndexError('index results in >2 dimensions')
  74. elif row.shape[1] == 1 and (col.ndim == 1 or col.shape[0] == 1):
  75. # special case for outer indexing
  76. return self._get_columnXarray(row[:,0], col.ravel())
  77. # The only remaining case is inner (fancy) indexing
  78. row, col = _broadcast_arrays(row, col)
  79. if row.shape != col.shape:
  80. raise IndexError('number of row and column indices differ')
  81. if row.size == 0:
  82. return self.__class__(np.atleast_2d(row).shape, dtype=self.dtype)
  83. return self._get_arrayXarray(row, col)
  84. def __setitem__(self, key, x):
  85. row, col = self._validate_indices(key)
  86. if isinstance(row, INT_TYPES) and isinstance(col, INT_TYPES):
  87. x = np.asarray(x, dtype=self.dtype)
  88. if x.size != 1:
  89. raise ValueError('Trying to assign a sequence to an item')
  90. self._set_intXint(row, col, x.flat[0])
  91. return
  92. if isinstance(row, slice):
  93. row = np.arange(*row.indices(self.shape[0]))[:, None]
  94. else:
  95. row = np.atleast_1d(row)
  96. if isinstance(col, slice):
  97. col = np.arange(*col.indices(self.shape[1]))[None, :]
  98. if row.ndim == 1:
  99. row = row[:, None]
  100. else:
  101. col = np.atleast_1d(col)
  102. i, j = _broadcast_arrays(row, col)
  103. if i.shape != j.shape:
  104. raise IndexError('number of row and column indices differ')
  105. from ._base import isspmatrix
  106. if isspmatrix(x):
  107. if i.ndim == 1:
  108. # Inner indexing, so treat them like row vectors.
  109. i = i[None]
  110. j = j[None]
  111. broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
  112. broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
  113. if not ((broadcast_row or x.shape[0] == i.shape[0]) and
  114. (broadcast_col or x.shape[1] == i.shape[1])):
  115. raise ValueError('shape mismatch in assignment')
  116. if x.shape[0] == 0 or x.shape[1] == 0:
  117. return
  118. x = x.tocoo(copy=True)
  119. x.sum_duplicates()
  120. self._set_arrayXarray_sparse(i, j, x)
  121. else:
  122. # Make x and i into the same shape
  123. x = np.asarray(x, dtype=self.dtype)
  124. if x.squeeze().shape != i.squeeze().shape:
  125. x = np.broadcast_to(x, i.shape)
  126. if x.size == 0:
  127. return
  128. x = x.reshape(i.shape)
  129. self._set_arrayXarray(i, j, x)
  130. def _validate_indices(self, key):
  131. M, N = self.shape
  132. row, col = _unpack_index(key)
  133. if isintlike(row):
  134. row = int(row)
  135. if row < -M or row >= M:
  136. raise IndexError('row index (%d) out of range' % row)
  137. if row < 0:
  138. row += M
  139. elif not isinstance(row, slice):
  140. row = self._asindices(row, M)
  141. if isintlike(col):
  142. col = int(col)
  143. if col < -N or col >= N:
  144. raise IndexError('column index (%d) out of range' % col)
  145. if col < 0:
  146. col += N
  147. elif not isinstance(col, slice):
  148. col = self._asindices(col, N)
  149. return row, col
  150. def _asindices(self, idx, length):
  151. """Convert `idx` to a valid index for an axis with a given length.
  152. Subclasses that need special validation can override this method.
  153. """
  154. try:
  155. x = np.asarray(idx)
  156. except (ValueError, TypeError, MemoryError) as e:
  157. raise IndexError('invalid index') from e
  158. if x.ndim not in (1, 2):
  159. raise IndexError('Index dimension must be 1 or 2')
  160. if x.size == 0:
  161. return x
  162. # Check bounds
  163. max_indx = x.max()
  164. if max_indx >= length:
  165. raise IndexError('index (%d) out of range' % max_indx)
  166. min_indx = x.min()
  167. if min_indx < 0:
  168. if min_indx < -length:
  169. raise IndexError('index (%d) out of range' % min_indx)
  170. if x is idx or not x.flags.owndata:
  171. x = x.copy()
  172. x[x < 0] += length
  173. return x
  174. def getrow(self, i):
  175. """Return a copy of row i of the matrix, as a (1 x n) row vector.
  176. """
  177. M, N = self.shape
  178. i = int(i)
  179. if i < -M or i >= M:
  180. raise IndexError('index (%d) out of range' % i)
  181. if i < 0:
  182. i += M
  183. return self._get_intXslice(i, slice(None))
  184. def getcol(self, i):
  185. """Return a copy of column i of the matrix, as a (m x 1) column vector.
  186. """
  187. M, N = self.shape
  188. i = int(i)
  189. if i < -N or i >= N:
  190. raise IndexError('index (%d) out of range' % i)
  191. if i < 0:
  192. i += N
  193. return self._get_sliceXint(slice(None), i)
  194. def _get_intXint(self, row, col):
  195. raise NotImplementedError()
  196. def _get_intXarray(self, row, col):
  197. raise NotImplementedError()
  198. def _get_intXslice(self, row, col):
  199. raise NotImplementedError()
  200. def _get_sliceXint(self, row, col):
  201. raise NotImplementedError()
  202. def _get_sliceXslice(self, row, col):
  203. raise NotImplementedError()
  204. def _get_sliceXarray(self, row, col):
  205. raise NotImplementedError()
  206. def _get_arrayXint(self, row, col):
  207. raise NotImplementedError()
  208. def _get_arrayXslice(self, row, col):
  209. raise NotImplementedError()
  210. def _get_columnXarray(self, row, col):
  211. raise NotImplementedError()
  212. def _get_arrayXarray(self, row, col):
  213. raise NotImplementedError()
  214. def _set_intXint(self, row, col, x):
  215. raise NotImplementedError()
  216. def _set_arrayXarray(self, row, col, x):
  217. raise NotImplementedError()
  218. def _set_arrayXarray_sparse(self, row, col, x):
  219. # Fall back to densifying x
  220. x = np.asarray(x.toarray(), dtype=self.dtype)
  221. x, _ = _broadcast_arrays(x, row)
  222. self._set_arrayXarray(row, col, x)
  223. def _unpack_index(index):
  224. """ Parse index. Always return a tuple of the form (row, col).
  225. Valid type for row/col is integer, slice, or array of integers.
  226. """
  227. # First, check if indexing with single boolean matrix.
  228. from ._base import spmatrix, isspmatrix
  229. if (isinstance(index, (spmatrix, np.ndarray)) and
  230. index.ndim == 2 and index.dtype.kind == 'b'):
  231. return index.nonzero()
  232. # Parse any ellipses.
  233. index = _check_ellipsis(index)
  234. # Next, parse the tuple or object
  235. if isinstance(index, tuple):
  236. if len(index) == 2:
  237. row, col = index
  238. elif len(index) == 1:
  239. row, col = index[0], slice(None)
  240. else:
  241. raise IndexError('invalid number of indices')
  242. else:
  243. idx = _compatible_boolean_index(index)
  244. if idx is None:
  245. row, col = index, slice(None)
  246. elif idx.ndim < 2:
  247. return _boolean_index_to_array(idx), slice(None)
  248. elif idx.ndim == 2:
  249. return idx.nonzero()
  250. # Next, check for validity and transform the index as needed.
  251. if isspmatrix(row) or isspmatrix(col):
  252. # Supporting sparse boolean indexing with both row and col does
  253. # not work because spmatrix.ndim is always 2.
  254. raise IndexError(
  255. 'Indexing with sparse matrices is not supported '
  256. 'except boolean indexing where matrix and index '
  257. 'are equal shapes.')
  258. bool_row = _compatible_boolean_index(row)
  259. bool_col = _compatible_boolean_index(col)
  260. if bool_row is not None:
  261. row = _boolean_index_to_array(bool_row)
  262. if bool_col is not None:
  263. col = _boolean_index_to_array(bool_col)
  264. return row, col
  265. def _check_ellipsis(index):
  266. """Process indices with Ellipsis. Returns modified index."""
  267. if index is Ellipsis:
  268. return (slice(None), slice(None))
  269. if not isinstance(index, tuple):
  270. return index
  271. # TODO: Deprecate this multiple-ellipsis handling,
  272. # as numpy no longer supports it.
  273. # Find first ellipsis.
  274. for j, v in enumerate(index):
  275. if v is Ellipsis:
  276. first_ellipsis = j
  277. break
  278. else:
  279. return index
  280. # Try to expand it using shortcuts for common cases
  281. if len(index) == 1:
  282. return (slice(None), slice(None))
  283. if len(index) == 2:
  284. if first_ellipsis == 0:
  285. if index[1] is Ellipsis:
  286. return (slice(None), slice(None))
  287. return (slice(None), index[1])
  288. return (index[0], slice(None))
  289. # Expand it using a general-purpose algorithm
  290. tail = []
  291. for v in index[first_ellipsis+1:]:
  292. if v is not Ellipsis:
  293. tail.append(v)
  294. nd = first_ellipsis + len(tail)
  295. nslice = max(0, 2 - nd)
  296. return index[:first_ellipsis] + (slice(None),)*nslice + tuple(tail)
  297. def _maybe_bool_ndarray(idx):
  298. """Returns a compatible array if elements are boolean.
  299. """
  300. idx = np.asanyarray(idx)
  301. if idx.dtype.kind == 'b':
  302. return idx
  303. return None
  304. def _first_element_bool(idx, max_dim=2):
  305. """Returns True if first element of the incompatible
  306. array type is boolean.
  307. """
  308. if max_dim < 1:
  309. return None
  310. try:
  311. first = next(iter(idx), None)
  312. except TypeError:
  313. return None
  314. if isinstance(first, bool):
  315. return True
  316. return _first_element_bool(first, max_dim-1)
  317. def _compatible_boolean_index(idx):
  318. """Returns a boolean index array that can be converted to
  319. integer array. Returns None if no such array exists.
  320. """
  321. # Presence of attribute `ndim` indicates a compatible array type.
  322. if hasattr(idx, 'ndim') or _first_element_bool(idx):
  323. return _maybe_bool_ndarray(idx)
  324. return None
  325. def _boolean_index_to_array(idx):
  326. if idx.ndim > 1:
  327. raise IndexError('invalid index shape')
  328. return np.where(idx)[0]