_compressed.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318
"""Base class for sparse matrix formats using compressed storage.

Defines ``_cs_matrix``, the shared base for compressed row- and
column-oriented matrices, which keep their entries in ``data``,
``indices`` and ``indptr`` arrays.
"""
__all__ = []  # nothing is exported by ``from ... import *``
  3. from warnings import warn
  4. import operator
  5. import numpy as np
  6. from scipy._lib._util import _prune_array
  7. from ._base import spmatrix, isspmatrix, SparseEfficiencyWarning
  8. from ._data import _data_matrix, _minmax_mixin
  9. from . import _sparsetools
  10. from ._sparsetools import (get_csr_submatrix, csr_sample_offsets, csr_todense,
  11. csr_sample_values, csr_row_index, csr_row_slice,
  12. csr_column_index1, csr_column_index2)
  13. from ._index import IndexMixin
  14. from ._sputils import (upcast, upcast_char, to_native, isdense, isshape,
  15. getdtype, isscalarlike, isintlike, get_index_dtype,
  16. downcast_intp_index, get_sum_dtype, check_shape,
  17. is_pydata_spmatrix)
  18. class _cs_matrix(_data_matrix, _minmax_mixin, IndexMixin):
  19. """base matrix class for compressed row- and column-oriented matrices"""
    def __init__(self, arg1, shape=None, dtype=None, copy=False):
        """Build the matrix from `arg1`.

        Accepted forms of `arg1`, in the order they are tested below:

        - another sparse matrix: copied when it already has this format and
          `copy` is true, otherwise passed through ``asformat``;
        - a shape tuple ``(M, N)``: an empty matrix of that shape;
        - a 2-tuple ``(data, ij)``: routed through the COO container;
        - a 3-tuple ``(data, indices, indptr)``: used as the compressed
          arrays directly (copied only if `copy` is true);
        - anything else ``np.asarray`` accepts: treated as a dense array.

        An explicit `shape` overrides the stored shape. Without one, if no
        shape was set by the branch above, the shape is inferred from
        ``indptr`` (major dimension) and ``indices.max() + 1`` (minor
        dimension). An explicit `dtype` is applied to ``data`` at the end.
        ``check_format(full_check=False)`` runs last.

        Raises
        ------
        ValueError
            On unrecognized constructor input, or when the shape cannot be
            inferred (e.g. no stored indices and no `shape`).
        """
        _data_matrix.__init__(self)
        if isspmatrix(arg1):
            if arg1.format == self.format and copy:
                arg1 = arg1.copy()
            else:
                arg1 = arg1.asformat(self.format)
            self._set_self(arg1)
        elif isinstance(arg1, tuple):
            if isshape(arg1):
                # It's a tuple of matrix dimensions (M, N)
                # create empty matrix
                self._shape = check_shape(arg1)
                M, N = self.shape
                # Select index dtype large enough to pass array and
                # scalar parameters to sparsetools
                idx_dtype = get_index_dtype(maxval=max(M, N))
                self.data = np.zeros(0, getdtype(dtype, default=float))
                self.indices = np.zeros(0, idx_dtype)
                # One all-zero pointer per major slice, plus the end pointer.
                self.indptr = np.zeros(self._swap((M, N))[0] + 1,
                                       dtype=idx_dtype)
            else:
                if len(arg1) == 2:
                    # (data, ij) format
                    other = self.__class__(
                        self._coo_container(arg1, shape=shape, dtype=dtype)
                    )
                    self._set_self(other)
                elif len(arg1) == 3:
                    # (data, indices, indptr) format
                    (data, indices, indptr) = arg1
                    # Select index dtype large enough to pass array and
                    # scalar parameters to sparsetools
                    maxval = None
                    if shape is not None:
                        maxval = max(shape)
                    idx_dtype = get_index_dtype((indices, indptr),
                                                maxval=maxval,
                                                check_contents=True)
                    self.indices = np.array(indices, copy=copy,
                                            dtype=idx_dtype)
                    self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
                    self.data = np.array(data, copy=copy, dtype=dtype)
                else:
                    raise ValueError("unrecognized {}_matrix "
                                     "constructor usage".format(self.format))
        else:
            # must be dense
            try:
                arg1 = np.asarray(arg1)
            except Exception as e:
                raise ValueError("unrecognized {}_matrix constructor usage"
                                 "".format(self.format)) from e
            # Dense input is converted via COO, then adopted wholesale.
            self._set_self(self.__class__(
                self._coo_container(arg1, dtype=dtype)
            ))
        # Read matrix dimensions given, if any
        if shape is not None:
            self._shape = check_shape(shape)
        else:
            if self.shape is None:
                # shape not already set, try to infer dimensions.
                # The minor dimension comes from the largest stored index,
                # so it fails when there are no stored indices.
                try:
                    major_dim = len(self.indptr) - 1
                    minor_dim = self.indices.max() + 1
                except Exception as e:
                    raise ValueError('unable to infer matrix dimensions') from e
                else:
                    self._shape = check_shape(self._swap((major_dim,
                                                          minor_dim)))
        if dtype is not None:
            self.data = self.data.astype(dtype, copy=False)
        self.check_format(full_check=False)
  93. def getnnz(self, axis=None):
  94. if axis is None:
  95. return int(self.indptr[-1])
  96. else:
  97. if axis < 0:
  98. axis += 2
  99. axis, _ = self._swap((axis, 1 - axis))
  100. _, N = self._swap(self.shape)
  101. if axis == 0:
  102. return np.bincount(downcast_intp_index(self.indices),
  103. minlength=N)
  104. elif axis == 1:
  105. return np.diff(self.indptr)
  106. raise ValueError('axis out of bounds')
  107. getnnz.__doc__ = spmatrix.getnnz.__doc__
  108. def _set_self(self, other, copy=False):
  109. """take the member variables of other and assign them to self"""
  110. if copy:
  111. other = other.copy()
  112. self.data = other.data
  113. self.indices = other.indices
  114. self.indptr = other.indptr
  115. self._shape = check_shape(other.shape)
    def check_format(self, full_check=True):
        """Check whether the matrix format is valid, normalizing in place.

        The index arrays are converted to the dtype chosen by
        ``get_index_dtype``, ``data`` is passed through ``to_native`` and
        ``self.prune()`` is called before the optional full check.

        Parameters
        ----------
        full_check : bool, optional
            If `True` (the default), also run the rigorous O(nnz) checks:
            every stored index lies in ``[0, minor_dim)`` and the index
            pointer is non-decreasing. Otherwise only the basic checks on
            index dtypes, array dimensionality and array lengths are run.

        Warns
        -----
        UserWarning
            If ``indptr`` or ``indices`` has a non-integer dtype.

        Raises
        ------
        ValueError
            If any structural check fails.
        """
        # use _swap to determine proper bounds
        major_name, minor_name = self._swap(('row', 'column'))
        major_dim, minor_dim = self._swap(self.shape)
        # index arrays should have integer data types
        if self.indptr.dtype.kind != 'i':
            warn("indptr array has non-integer dtype ({})"
                 "".format(self.indptr.dtype.name), stacklevel=3)
        if self.indices.dtype.kind != 'i':
            warn("indices array has non-integer dtype ({})"
                 "".format(self.indices.dtype.name), stacklevel=3)
        idx_dtype = get_index_dtype((self.indptr, self.indices))
        self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
        self.indices = np.asarray(self.indices, dtype=idx_dtype)
        self.data = to_native(self.data)
        # check array shapes
        for x in [self.data.ndim, self.indices.ndim, self.indptr.ndim]:
            if x != 1:
                raise ValueError('data, indices, and indptr should be 1-D')
        # check index pointer
        if (len(self.indptr) != major_dim + 1):
            raise ValueError("index pointer size ({}) should be ({})"
                             "".format(len(self.indptr), major_dim + 1))
        if (self.indptr[0] != 0):
            raise ValueError("index pointer should start with 0")
        # check index and data arrays
        if (len(self.indices) != len(self.data)):
            raise ValueError("indices and data should have the same size")
        # Note: equality is accepted here, despite the "less than" wording
        # of the message.
        if (self.indptr[-1] > len(self.indices)):
            raise ValueError("Last value of index pointer should be less than "
                             "the size of index and data arrays")
        self.prune()
        if full_check:
            # check format validity (more expensive)
            if self.nnz > 0:
                if self.indices.max() >= minor_dim:
                    raise ValueError("{} index values must be < {}"
                                     "".format(minor_name, minor_dim))
                if self.indices.min() < 0:
                    raise ValueError("{} index values must be >= 0"
                                     "".format(minor_name))
                if np.diff(self.indptr).min() < 0:
                    raise ValueError("index pointer values must form a "
                                     "non-decreasing sequence")
        # Index sortedness is neither checked nor enforced here.
        # TODO check for duplicates?
  172. #######################
  173. # Boolean comparisons #
  174. #######################
  175. def _scalar_binopt(self, other, op):
  176. """Scalar version of self._binopt, for cases in which no new nonzeros
  177. are added. Produces a new spmatrix in canonical form.
  178. """
  179. self.sum_duplicates()
  180. res = self._with_data(op(self.data, other), copy=True)
  181. res.eliminate_zeros()
  182. return res
  183. def __eq__(self, other):
  184. # Scalar other.
  185. if isscalarlike(other):
  186. if np.isnan(other):
  187. return self.__class__(self.shape, dtype=np.bool_)
  188. if other == 0:
  189. warn("Comparing a sparse matrix with 0 using == is inefficient"
  190. ", try using != instead.", SparseEfficiencyWarning,
  191. stacklevel=3)
  192. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  193. inv = self._scalar_binopt(other, operator.ne)
  194. return all_true - inv
  195. else:
  196. return self._scalar_binopt(other, operator.eq)
  197. # Dense other.
  198. elif isdense(other):
  199. return self.todense() == other
  200. # Pydata sparse other.
  201. elif is_pydata_spmatrix(other):
  202. return NotImplemented
  203. # Sparse other.
  204. elif isspmatrix(other):
  205. warn("Comparing sparse matrices using == is inefficient, try using"
  206. " != instead.", SparseEfficiencyWarning, stacklevel=3)
  207. # TODO sparse broadcasting
  208. if self.shape != other.shape:
  209. return False
  210. elif self.format != other.format:
  211. other = other.asformat(self.format)
  212. res = self._binopt(other, '_ne_')
  213. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  214. return all_true - res
  215. else:
  216. return False
  217. def __ne__(self, other):
  218. # Scalar other.
  219. if isscalarlike(other):
  220. if np.isnan(other):
  221. warn("Comparing a sparse matrix with nan using != is"
  222. " inefficient", SparseEfficiencyWarning, stacklevel=3)
  223. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  224. return all_true
  225. elif other != 0:
  226. warn("Comparing a sparse matrix with a nonzero scalar using !="
  227. " is inefficient, try using == instead.",
  228. SparseEfficiencyWarning, stacklevel=3)
  229. all_true = self.__class__(np.ones(self.shape), dtype=np.bool_)
  230. inv = self._scalar_binopt(other, operator.eq)
  231. return all_true - inv
  232. else:
  233. return self._scalar_binopt(other, operator.ne)
  234. # Dense other.
  235. elif isdense(other):
  236. return self.todense() != other
  237. # Pydata sparse other.
  238. elif is_pydata_spmatrix(other):
  239. return NotImplemented
  240. # Sparse other.
  241. elif isspmatrix(other):
  242. # TODO sparse broadcasting
  243. if self.shape != other.shape:
  244. return True
  245. elif self.format != other.format:
  246. other = other.asformat(self.format)
  247. return self._binopt(other, '_ne_')
  248. else:
  249. return True
  250. def _inequality(self, other, op, op_name, bad_scalar_msg):
  251. # Scalar other.
  252. if isscalarlike(other):
  253. if 0 == other and op_name in ('_le_', '_ge_'):
  254. raise NotImplementedError(" >= and <= don't work with 0.")
  255. elif op(0, other):
  256. warn(bad_scalar_msg, SparseEfficiencyWarning)
  257. other_arr = np.empty(self.shape, dtype=np.result_type(other))
  258. other_arr.fill(other)
  259. other_arr = self.__class__(other_arr)
  260. return self._binopt(other_arr, op_name)
  261. else:
  262. return self._scalar_binopt(other, op)
  263. # Dense other.
  264. elif isdense(other):
  265. return op(self.todense(), other)
  266. # Sparse other.
  267. elif isspmatrix(other):
  268. # TODO sparse broadcasting
  269. if self.shape != other.shape:
  270. raise ValueError("inconsistent shapes")
  271. elif self.format != other.format:
  272. other = other.asformat(self.format)
  273. if op_name not in ('_ge_', '_le_'):
  274. return self._binopt(other, op_name)
  275. warn("Comparing sparse matrices using >= and <= is inefficient, "
  276. "using <, >, or !=, instead.", SparseEfficiencyWarning)
  277. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  278. res = self._binopt(other, '_gt_' if op_name == '_le_' else '_lt_')
  279. return all_true - res
  280. else:
  281. raise ValueError("Operands could not be compared.")
  282. def __lt__(self, other):
  283. return self._inequality(other, operator.lt, '_lt_',
  284. "Comparing a sparse matrix with a scalar "
  285. "greater than zero using < is inefficient, "
  286. "try using >= instead.")
  287. def __gt__(self, other):
  288. return self._inequality(other, operator.gt, '_gt_',
  289. "Comparing a sparse matrix with a scalar "
  290. "less than zero using > is inefficient, "
  291. "try using <= instead.")
  292. def __le__(self, other):
  293. return self._inequality(other, operator.le, '_le_',
  294. "Comparing a sparse matrix with a scalar "
  295. "greater than zero using <= is inefficient, "
  296. "try using > instead.")
  297. def __ge__(self, other):
  298. return self._inequality(other, operator.ge, '_ge_',
  299. "Comparing a sparse matrix with a scalar "
  300. "less than zero using >= is inefficient, "
  301. "try using < instead.")
  302. #################################
  303. # Arithmetic operator overrides #
  304. #################################
  305. def _add_dense(self, other):
  306. if other.shape != self.shape:
  307. raise ValueError('Incompatible shapes ({} and {})'
  308. .format(self.shape, other.shape))
  309. dtype = upcast_char(self.dtype.char, other.dtype.char)
  310. order = self._swap('CF')[0]
  311. result = np.array(other, dtype=dtype, order=order, copy=True)
  312. M, N = self._swap(self.shape)
  313. y = result if result.flags.c_contiguous else result.T
  314. csr_todense(M, N, self.indptr, self.indices, self.data, y)
  315. return self._container(result, copy=False)
  316. def _add_sparse(self, other):
  317. return self._binopt(other, '_plus_')
  318. def _sub_sparse(self, other):
  319. return self._binopt(other, '_minus_')
    def multiply(self, other):
        """Point-wise multiplication by another matrix, vector, or
        scalar.

        Dispatch, in order:

        - scalars scale the stored values (``_mul_scalar``);
        - sparse operands of equal shape use ``_binopt('_elmul_')``;
        - 1x1 sparse operands act as scalars;
        - row/column vector operands are broadcast through sparse matrix
          products (an outer product, or a product with a diagonal matrix
          built from the vector);
        - dense operands are multiplied against the COO entries and return
          a sparse COO result, except that a non-2-D operand or a 1x1
          ``self`` yields a dense result.

        Raises
        ------
        ValueError
            If the operand shapes cannot be broadcast ("inconsistent shapes").
        """
        # Scalar multiplication.
        if isscalarlike(other):
            return self._mul_scalar(other)
        # Sparse matrix or vector.
        if isspmatrix(other):
            if self.shape == other.shape:
                other = self.__class__(other)
                return self._binopt(other, '_elmul_')
            # Single element.
            elif other.shape == (1, 1):
                return self._mul_scalar(other.toarray()[0, 0])
            elif self.shape == (1, 1):
                return other._mul_scalar(self.toarray()[0, 0])
            # A row times a column (outer product via matrix multiplication).
            elif self.shape[1] == 1 and other.shape[0] == 1:
                return self._mul_sparse_matrix(other.tocsc())
            elif self.shape[0] == 1 and other.shape[1] == 1:
                return other._mul_sparse_matrix(self.tocsc())
            # Row vector times matrix. other is a row: scale each column by
            # right-multiplying with diag(other).
            elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
                other = self._dia_container(
                    (other.toarray().ravel(), [0]),
                    shape=(other.shape[1], other.shape[1])
                )
                return self._mul_sparse_matrix(other)
            # self is a row.
            elif self.shape[0] == 1 and self.shape[1] == other.shape[1]:
                copy = self._dia_container(
                    (self.toarray().ravel(), [0]),
                    shape=(self.shape[1], self.shape[1])
                )
                return other._mul_sparse_matrix(copy)
            # Column vector times matrix. other is a column: scale each row
            # by left-multiplying with diag(other).
            elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
                other = self._dia_container(
                    (other.toarray().ravel(), [0]),
                    shape=(other.shape[0], other.shape[0])
                )
                return other._mul_sparse_matrix(self)
            # self is a column.
            elif self.shape[1] == 1 and self.shape[0] == other.shape[0]:
                copy = self._dia_container(
                    (self.toarray().ravel(), [0]),
                    shape=(self.shape[0], self.shape[0])
                )
                return copy._mul_sparse_matrix(other)
            else:
                raise ValueError("inconsistent shapes")
        # Assume other is a dense matrix/array, which produces a single-item
        # object array if other isn't convertible to ndarray.
        other = np.atleast_2d(other)
        if other.ndim != 2:
            return np.multiply(self.toarray(), other)
        # Single element / wrapped object.
        if other.size == 1:
            return self._mul_scalar(other.flat[0])
        # Fast case for trivial sparse matrix.
        elif self.shape == (1, 1):
            return np.multiply(self.toarray()[0, 0], other)
        ret = self.tocoo()
        # Matching shapes.
        if self.shape == other.shape:
            data = np.multiply(ret.data, other[ret.row, ret.col])
        # Sparse row vector times...
        elif self.shape[0] == 1:
            if other.shape[1] == 1: # Dense column vector.
                data = np.multiply(ret.data, other)
            elif other.shape[1] == self.shape[1]: # Dense matrix.
                data = np.multiply(ret.data, other[:, ret.col])
            else:
                raise ValueError("inconsistent shapes")
            # The row's entries are replicated once per row of `other`, in
            # row-major order to match the ravel() of `data` below.
            row = np.repeat(np.arange(other.shape[0]), len(ret.row))
            col = np.tile(ret.col, other.shape[0])
            return self._coo_container(
                (data.view(np.ndarray).ravel(), (row, col)),
                shape=(other.shape[0], self.shape[1]),
                copy=False
            )
        # Sparse column vector times...
        elif self.shape[1] == 1:
            if other.shape[0] == 1: # Dense row vector.
                data = np.multiply(ret.data[:, None], other)
            elif other.shape[0] == self.shape[0]: # Dense matrix.
                data = np.multiply(ret.data[:, None], other[ret.row])
            else:
                raise ValueError("inconsistent shapes")
            # Each stored entry expands across all columns of `other`.
            row = np.repeat(ret.row, other.shape[1])
            col = np.tile(np.arange(other.shape[1]), len(ret.col))
            return self._coo_container(
                (data.view(np.ndarray).ravel(), (row, col)),
                shape=(self.shape[0], other.shape[1]),
                copy=False
            )
        # Sparse matrix times dense row vector.
        elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
            data = np.multiply(ret.data, other[:, ret.col].ravel())
        # Sparse matrix times dense column vector.
        elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
            data = np.multiply(ret.data, other[ret.row].ravel())
        else:
            raise ValueError("inconsistent shapes")
        # Same sparsity pattern as self: only the COO values are replaced.
        ret.data = data.view(np.ndarray).ravel()
        return ret
  427. ###########################
  428. # Multiplication handlers #
  429. ###########################
  430. def _mul_vector(self, other):
  431. M, N = self.shape
  432. # output array
  433. result = np.zeros(M, dtype=upcast_char(self.dtype.char,
  434. other.dtype.char))
  435. # csr_matvec or csc_matvec
  436. fn = getattr(_sparsetools, self.format + '_matvec')
  437. fn(M, N, self.indptr, self.indices, self.data, other, result)
  438. return result
  439. def _mul_multivector(self, other):
  440. M, N = self.shape
  441. n_vecs = other.shape[1] # number of column vectors
  442. result = np.zeros((M, n_vecs),
  443. dtype=upcast_char(self.dtype.char, other.dtype.char))
  444. # csr_matvecs or csc_matvecs
  445. fn = getattr(_sparsetools, self.format + '_matvecs')
  446. fn(M, N, n_vecs, self.indptr, self.indices, self.data,
  447. other.ravel(), result.ravel())
  448. return result
  449. def _mul_sparse_matrix(self, other):
  450. M, K1 = self.shape
  451. K2, N = other.shape
  452. major_axis = self._swap((M, N))[0]
  453. other = self.__class__(other) # convert to this format
  454. idx_dtype = get_index_dtype((self.indptr, self.indices,
  455. other.indptr, other.indices))
  456. fn = getattr(_sparsetools, self.format + '_matmat_maxnnz')
  457. nnz = fn(M, N,
  458. np.asarray(self.indptr, dtype=idx_dtype),
  459. np.asarray(self.indices, dtype=idx_dtype),
  460. np.asarray(other.indptr, dtype=idx_dtype),
  461. np.asarray(other.indices, dtype=idx_dtype))
  462. idx_dtype = get_index_dtype((self.indptr, self.indices,
  463. other.indptr, other.indices),
  464. maxval=nnz)
  465. indptr = np.empty(major_axis + 1, dtype=idx_dtype)
  466. indices = np.empty(nnz, dtype=idx_dtype)
  467. data = np.empty(nnz, dtype=upcast(self.dtype, other.dtype))
  468. fn = getattr(_sparsetools, self.format + '_matmat')
  469. fn(M, N, np.asarray(self.indptr, dtype=idx_dtype),
  470. np.asarray(self.indices, dtype=idx_dtype),
  471. self.data,
  472. np.asarray(other.indptr, dtype=idx_dtype),
  473. np.asarray(other.indices, dtype=idx_dtype),
  474. other.data,
  475. indptr, indices, data)
  476. return self.__class__((data, indices, indptr), shape=(M, N))
  477. def diagonal(self, k=0):
  478. rows, cols = self.shape
  479. if k <= -rows or k >= cols:
  480. return np.empty(0, dtype=self.data.dtype)
  481. fn = getattr(_sparsetools, self.format + "_diagonal")
  482. y = np.empty(min(rows + min(k, 0), cols - max(k, 0)),
  483. dtype=upcast(self.dtype))
  484. fn(k, self.shape[0], self.shape[1], self.indptr, self.indices,
  485. self.data, y)
  486. return y
  487. diagonal.__doc__ = spmatrix.diagonal.__doc__
  488. #####################
  489. # Other binary ops #
  490. #####################
  491. def _maximum_minimum(self, other, npop, op_name, dense_check):
  492. if isscalarlike(other):
  493. if dense_check(other):
  494. warn("Taking maximum (minimum) with > 0 (< 0) number results"
  495. " to a dense matrix.", SparseEfficiencyWarning,
  496. stacklevel=3)
  497. other_arr = np.empty(self.shape, dtype=np.asarray(other).dtype)
  498. other_arr.fill(other)
  499. other_arr = self.__class__(other_arr)
  500. return self._binopt(other_arr, op_name)
  501. else:
  502. self.sum_duplicates()
  503. new_data = npop(self.data, np.asarray(other))
  504. mat = self.__class__((new_data, self.indices, self.indptr),
  505. dtype=new_data.dtype, shape=self.shape)
  506. return mat
  507. elif isdense(other):
  508. return npop(self.todense(), other)
  509. elif isspmatrix(other):
  510. return self._binopt(other, op_name)
  511. else:
  512. raise ValueError("Operands not compatible.")
  513. def maximum(self, other):
  514. return self._maximum_minimum(other, np.maximum,
  515. '_maximum_', lambda x: np.asarray(x) > 0)
  516. maximum.__doc__ = spmatrix.maximum.__doc__
  517. def minimum(self, other):
  518. return self._maximum_minimum(other, np.minimum,
  519. '_minimum_', lambda x: np.asarray(x) < 0)
  520. minimum.__doc__ = spmatrix.minimum.__doc__
  521. #####################
  522. # Reduce operations #
  523. #####################
  524. def sum(self, axis=None, dtype=None, out=None):
  525. """Sum the matrix over the given axis. If the axis is None, sum
  526. over both rows and columns, returning a scalar.
  527. """
  528. # The spmatrix base class already does axis=0 and axis=1 efficiently
  529. # so we only do the case axis=None here
  530. if (not hasattr(self, 'blocksize') and
  531. axis in self._swap(((1, -1), (0, 2)))[0]):
  532. # faster than multiplication for large minor axis in CSC/CSR
  533. res_dtype = get_sum_dtype(self.dtype)
  534. ret = np.zeros(len(self.indptr) - 1, dtype=res_dtype)
  535. major_index, value = self._minor_reduce(np.add)
  536. ret[major_index] = value
  537. ret = self._ascontainer(ret)
  538. if axis % 2 == 1:
  539. ret = ret.T
  540. if out is not None and out.shape != ret.shape:
  541. raise ValueError('dimensions do not match')
  542. return ret.sum(axis=(), dtype=dtype, out=out)
  543. # spmatrix will handle the remaining situations when axis
  544. # is in {None, -1, 0, 1}
  545. else:
  546. return spmatrix.sum(self, axis=axis, dtype=dtype, out=out)
  547. sum.__doc__ = spmatrix.sum.__doc__
  548. def _minor_reduce(self, ufunc, data=None):
  549. """Reduce nonzeros with a ufunc over the minor axis when non-empty
  550. Can be applied to a function of self.data by supplying data parameter.
  551. Warning: this does not call sum_duplicates()
  552. Returns
  553. -------
  554. major_index : array of ints
  555. Major indices where nonzero
  556. value : array of self.dtype
  557. Reduce result for nonzeros in each major_index
  558. """
  559. if data is None:
  560. data = self.data
  561. major_index = np.flatnonzero(np.diff(self.indptr))
  562. value = ufunc.reduceat(data,
  563. downcast_intp_index(self.indptr[major_index]))
  564. return major_index, value
  565. #######################
  566. # Getting and Setting #
  567. #######################
  568. def _get_intXint(self, row, col):
  569. M, N = self._swap(self.shape)
  570. major, minor = self._swap((row, col))
  571. indptr, indices, data = get_csr_submatrix(
  572. M, N, self.indptr, self.indices, self.data,
  573. major, major + 1, minor, minor + 1)
  574. return data.sum(dtype=self.dtype)
  575. def _get_sliceXslice(self, row, col):
  576. major, minor = self._swap((row, col))
  577. if major.step in (1, None) and minor.step in (1, None):
  578. return self._get_submatrix(major, minor, copy=True)
  579. return self._major_slice(major)._minor_slice(minor)
  580. def _get_arrayXarray(self, row, col):
  581. # inner indexing
  582. idx_dtype = self.indices.dtype
  583. M, N = self._swap(self.shape)
  584. major, minor = self._swap((row, col))
  585. major = np.asarray(major, dtype=idx_dtype)
  586. minor = np.asarray(minor, dtype=idx_dtype)
  587. val = np.empty(major.size, dtype=self.dtype)
  588. csr_sample_values(M, N, self.indptr, self.indices, self.data,
  589. major.size, major.ravel(), minor.ravel(), val)
  590. if major.ndim == 1:
  591. return self._ascontainer(val)
  592. return self.__class__(val.reshape(major.shape))
  593. def _get_columnXarray(self, row, col):
  594. # outer indexing
  595. major, minor = self._swap((row, col))
  596. return self._major_index_fancy(major)._minor_index_fancy(minor)
  597. def _major_index_fancy(self, idx):
  598. """Index along the major axis where idx is an array of ints.
  599. """
  600. idx_dtype = self.indices.dtype
  601. indices = np.asarray(idx, dtype=idx_dtype).ravel()
  602. _, N = self._swap(self.shape)
  603. M = len(indices)
  604. new_shape = self._swap((M, N))
  605. if M == 0:
  606. return self.__class__(new_shape, dtype=self.dtype)
  607. row_nnz = self.indptr[indices + 1] - self.indptr[indices]
  608. idx_dtype = self.indices.dtype
  609. res_indptr = np.zeros(M+1, dtype=idx_dtype)
  610. np.cumsum(row_nnz, out=res_indptr[1:])
  611. nnz = res_indptr[-1]
  612. res_indices = np.empty(nnz, dtype=idx_dtype)
  613. res_data = np.empty(nnz, dtype=self.dtype)
  614. csr_row_index(M, indices, self.indptr, self.indices, self.data,
  615. res_indices, res_data)
  616. return self.__class__((res_data, res_indices, res_indptr),
  617. shape=new_shape, copy=False)
    def _major_slice(self, idx, copy=False):
        """Index along the major axis where idx is a slice object.

        With ``copy=False`` the result may share memory with ``self``. The
        full slice returns ``self`` itself, and a unit-step slice takes the
        contiguous stretch of ``indices``/``data`` without forcing a copy.
        """
        if idx == slice(None):
            return self.copy() if copy else self
        M, N = self._swap(self.shape)
        start, stop, step = idx.indices(M)
        M = len(range(start, stop, step))
        new_shape = self._swap((M, N))
        if M == 0:
            return self.__class__(new_shape, dtype=self.dtype)
        # Work out what slices are needed for `row_nnz`
        # start,stop can be -1, only if step is negative
        start0, stop0 = start, stop
        # As a slice bound, -1 would mean "the last element"; None lets a
        # negative-step slice run through index 0 instead.
        if stop == -1 and start >= 0:
            stop0 = None
        start1, stop1 = start + 1, stop + 1
        # Entries per selected major slice: indptr[i+1] - indptr[i].
        row_nnz = self.indptr[start1:stop1:step] - \
            self.indptr[start0:stop0:step]
        idx_dtype = self.indices.dtype
        res_indptr = np.zeros(M+1, dtype=idx_dtype)
        np.cumsum(row_nnz, out=res_indptr[1:])
        if step == 1:
            # Consecutive major slices are stored contiguously.
            all_idx = slice(self.indptr[start], self.indptr[stop])
            res_indices = np.array(self.indices[all_idx], copy=copy)
            res_data = np.array(self.data[all_idx], copy=copy)
        else:
            nnz = res_indptr[-1]
            res_indices = np.empty(nnz, dtype=idx_dtype)
            res_data = np.empty(nnz, dtype=self.dtype)
            csr_row_slice(start, stop, step, self.indptr, self.indices,
                          self.data, res_indices, res_data)
        return self.__class__((res_data, res_indices, res_indptr),
                              shape=new_shape, copy=False)
  def _minor_index_fancy(self, idx):
      """Index along the minor axis with an integer array.

      Parameters
      ----------
      idx : array_like of int
          Minor-axis positions to select. It is flattened before use, and
          the result has one minor slot per entry of ``idx``.

      Returns
      -------
      A new matrix of the same class. The major extent is unchanged and
      the minor extent equals ``len(idx)``.
      """
      index_dtype = self.indices.dtype
      positions = np.asarray(idx, dtype=index_dtype).ravel()
      n_major, n_minor = self._swap(self.shape)
      n_selected = len(positions)
      out_shape = self._swap((n_major, n_selected))
      if n_selected == 0:
          return self.__class__(out_shape, dtype=self.dtype)

      # Pass 1: csr_column_index1 fills the output indptr and a per-minor
      # bookkeeping array (col_offsets) that pass 2 consumes.
      col_offsets = np.zeros(n_minor, dtype=index_dtype)
      out_indptr = np.empty_like(self.indptr)
      csr_column_index1(n_selected, positions, n_major, n_minor,
                        self.indptr, self.indices, col_offsets, out_indptr)

      # Pass 2: copy indices/data for the selected positions. The kernel is
      # handed the argsort order of the requested positions.
      order = np.argsort(positions).astype(index_dtype, copy=False)
      total = out_indptr[-1]
      out_indices = np.empty(total, dtype=index_dtype)
      out_data = np.empty(total, dtype=self.dtype)
      csr_column_index2(order, col_offsets, len(self.indices),
                        self.indices, self.data, out_indices, out_data)
      return self.__class__((out_data, out_indices, out_indptr),
                            shape=out_shape, copy=False)
  676. def _minor_slice(self, idx, copy=False):
  677. """Index along the minor axis where idx is a slice object.
  678. """
  679. if idx == slice(None):
  680. return self.copy() if copy else self
  681. M, N = self._swap(self.shape)
  682. start, stop, step = idx.indices(N)
  683. N = len(range(start, stop, step))
  684. if N == 0:
  685. return self.__class__(self._swap((M, N)), dtype=self.dtype)
  686. if step == 1:
  687. return self._get_submatrix(minor=idx, copy=copy)
  688. # TODO: don't fall back to fancy indexing here
  689. return self._minor_index_fancy(np.arange(start, stop, step))
  def _get_submatrix(self, major=None, minor=None, copy=False):
      """Return a contiguous submatrix of this matrix.

      Parameters
      ----------
      major, minor : None, int, or slice with step 1
          Selection along the major and minor axes. ``None`` keeps the
          whole axis.
      copy : bool
          Only used when the selection covers the entire matrix. In that
          case ``self`` is returned, copied first if ``copy`` is true.
      """
      n_major, n_minor = self._swap(self.shape)
      lo_major, hi_major = _process_slice(major, n_major)
      lo_minor, hi_minor = _process_slice(minor, n_minor)
      whole = (lo_major == 0 and lo_minor == 0
               and hi_major == n_major and hi_minor == n_minor)
      if whole:
          return self.copy() if copy else self

      sub_indptr, sub_indices, sub_data = get_csr_submatrix(
          n_major, n_minor, self.indptr, self.indices, self.data,
          lo_major, hi_major, lo_minor, hi_minor)
      out_shape = self._swap((hi_major - lo_major, hi_minor - lo_minor))
      return self.__class__((sub_data, sub_indices, sub_indptr),
                            shape=out_shape, dtype=self.dtype, copy=False)
  704. def _set_intXint(self, row, col, x):
  705. i, j = self._swap((row, col))
  706. self._set_many(i, j, x)
  707. def _set_arrayXarray(self, row, col, x):
  708. i, j = self._swap((row, col))
  709. self._set_many(i, j, x)
  710. def _set_arrayXarray_sparse(self, row, col, x):
  711. # clear entries that will be overwritten
  712. self._zero_many(*self._swap((row, col)))
  713. M, N = row.shape # matches col.shape
  714. broadcast_row = M != 1 and x.shape[0] == 1
  715. broadcast_col = N != 1 and x.shape[1] == 1
  716. r, c = x.row, x.col
  717. x = np.asarray(x.data, dtype=self.dtype)
  718. if x.size == 0:
  719. return
  720. if broadcast_row:
  721. r = np.repeat(np.arange(M), len(r))
  722. c = np.tile(c, M)
  723. x = np.tile(x, M)
  724. if broadcast_col:
  725. r = np.repeat(r, N)
  726. c = np.tile(np.arange(N), len(c))
  727. x = np.repeat(x, N)
  728. # only assign entries in the new sparsity structure
  729. i, j = self._swap((row[r, c], col[r, c]))
  730. self._set_many(i, j, x)
  731. def _setdiag(self, values, k):
  732. if 0 in self.shape:
  733. return
  734. M, N = self.shape
  735. broadcast = (values.ndim == 0)
  736. if k < 0:
  737. if broadcast:
  738. max_index = min(M + k, N)
  739. else:
  740. max_index = min(M + k, N, len(values))
  741. i = np.arange(max_index, dtype=self.indices.dtype)
  742. j = np.arange(max_index, dtype=self.indices.dtype)
  743. i -= k
  744. else:
  745. if broadcast:
  746. max_index = min(M, N - k)
  747. else:
  748. max_index = min(M, N - k, len(values))
  749. i = np.arange(max_index, dtype=self.indices.dtype)
  750. j = np.arange(max_index, dtype=self.indices.dtype)
  751. j += k
  752. if not broadcast:
  753. values = values[:len(i)]
  754. self[i, j] = values
  755. def _prepare_indices(self, i, j):
  756. M, N = self._swap(self.shape)
  757. def check_bounds(indices, bound):
  758. idx = indices.max()
  759. if idx >= bound:
  760. raise IndexError('index (%d) out of range (>= %d)' %
  761. (idx, bound))
  762. idx = indices.min()
  763. if idx < -bound:
  764. raise IndexError('index (%d) out of range (< -%d)' %
  765. (idx, bound))
  766. i = np.array(i, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
  767. j = np.array(j, dtype=self.indices.dtype, copy=False, ndmin=1).ravel()
  768. check_bounds(i, M)
  769. check_bounds(j, N)
  770. return i, j, M, N
  def _set_many(self, i, j, x):
      """Set the value at each (i, j) position to the matching entry of x.

      (i, j) index the major and minor axes respectively and must not
      contain duplicate pairs.

      - Positions that are already stored are overwritten in place.
      - Positions that are not stored are inserted via ``_insert_many``,
        after a SparseEfficiencyWarning. Negative indices are wrapped
        before insertion.
      """
      i, j, M, N = self._prepare_indices(i, j)
      # NumPy >= 2 rejects np.array(..., copy=False) when the cast needs a
      # copy; asarray + atleast_1d behave the same on every version.
      x = np.atleast_1d(np.asarray(x, dtype=self.dtype)).ravel()

      n_samples = x.size
      offsets = np.empty(n_samples, dtype=self.indices.dtype)
      ret = csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
                               i, j, offsets)
      if ret == 1:
          # The kernel reported it could not resolve the offsets as-is.
          # Canonicalize (sum duplicates) and retry, i.e. rinse and repeat.
          self.sum_duplicates()
          csr_sample_offsets(M, N, self.indptr, self.indices, n_samples,
                             i, j, offsets)

      if -1 not in offsets:
          # Only affects existing non-zero cells.
          self.data[offsets] = x
          return

      warn("Changing the sparsity structure of a {}_matrix is expensive."
           " lil_matrix is more efficient.".format(self.format),
           SparseEfficiencyWarning, stacklevel=3)
      # Replace where possible (offset -1 marks a position not yet stored).
      mask = offsets > -1
      self.data[offsets[mask]] = x[mask]
      # Only insertions remain; wrap negative indices first.
      mask = ~mask
      i = i[mask]
      i[i < 0] += M
      j = j[mask]
      j[j < 0] += N
      self._insert_many(i, j, x[mask])
  def _zero_many(self, i, j):
      """Zero the value at each (i, j) without changing the sparsity pattern.

      (i, j) index the major and minor axes respectively. Positions that
      are not currently stored are skipped. Stored ones keep their slot
      and get the value 0.
      """
      i, j, M, N = self._prepare_indices(i, j)
      count = len(i)
      positions = np.empty(count, dtype=self.indices.dtype)
      status = csr_sample_offsets(M, N, self.indptr, self.indices, count,
                                  i, j, positions)
      if status == 1:
          # Rinse and repeat after summing duplicates.
          self.sum_duplicates()
          csr_sample_offsets(M, N, self.indptr, self.indices, count,
                             i, j, positions)
      # An offset of -1 means "not stored"; touch only stored entries.
      stored = positions[positions > -1]
      self.data[stored] = 0
  def _insert_many(self, i, j, x):
      """Insert a new stored entry with value x at each (i, j).

      (i, j) index the major and minor axes respectively. i, j and x must
      be non-empty 1-d arrays.

      - When the same (i, j) pair occurs more than once, the last value
        wins.
      - New entries are spliced in one major group (e.g. one CSR row) at
        a time.
      - If the indices were sorted beforehand, they are re-sorted
        afterwards, which maintains has_sorted_indices.

      The caller's arrays are not modified; they are reordered into fresh
      arrays with ``take`` before use. ``self.indptr``, ``self.indices``
      and ``self.data`` are replaced, possibly with a wider index dtype.
      """
      order = np.argsort(i, kind='mergesort')  # stable for duplicates
      i = i.take(order, mode='clip')
      j = j.take(order, mode='clip')
      x = x.take(order, mode='clip')

      do_sort = self.has_sorted_indices

      # Update index data type. Cast the current nnz to a Python int first.
      # Under NumPy's scalar promotion rules (NEP 50), adding x.size to an
      # int32 scalar stays int32 and can wrap, which would hide the need
      # for a wider index dtype.
      idx_dtype = get_index_dtype((self.indices, self.indptr),
                                  maxval=(int(self.indptr[-1]) + x.size))
      self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
      self.indices = np.asarray(self.indices, dtype=idx_dtype)
      i = np.asarray(i, dtype=idx_dtype)
      j = np.asarray(j, dtype=idx_dtype)

      # Collate old and new in chunks by major index.
      indices_parts = []
      data_parts = []
      ui, ui_indptr = np.unique(i, return_index=True)
      ui_indptr = np.append(ui_indptr, len(j))
      new_nnzs = np.diff(ui_indptr)
      prev = 0
      for c, (ii, js, je) in enumerate(zip(ui, ui_indptr, ui_indptr[1:])):
          # Old entries for major groups prev .. ii-1.
          start = self.indptr[prev]
          stop = self.indptr[ii]
          indices_parts.append(self.indices[start:stop])
          data_parts.append(self.data[start:stop])

          # Handle duplicate j: keep the last setting. Searching the
          # reversed run makes np.unique's first occurrence the last one.
          uj, uj_indptr = np.unique(j[js:je][::-1], return_index=True)
          if len(uj) == je - js:
              indices_parts.append(j[js:je])
              data_parts.append(x[js:je])
          else:
              indices_parts.append(j[js:je][::-1][uj_indptr])
              data_parts.append(x[js:je][::-1][uj_indptr])
              new_nnzs[c] = len(uj)

          prev = ii

      # Remaining old entries (from the last touched major group onward).
      start = self.indptr[ii]
      indices_parts.append(self.indices[start:])
      data_parts.append(self.data[start:])

      # Update attributes.
      self.indices = np.concatenate(indices_parts)
      self.data = np.concatenate(data_parts)
      nnzs = np.empty(self.indptr.shape, dtype=idx_dtype)
      nnzs[0] = idx_dtype(0)
      indptr_diff = np.diff(self.indptr)
      indptr_diff[ui] += new_nnzs
      nnzs[1:] = indptr_diff
      self.indptr = np.cumsum(nnzs, out=nnzs)

      if do_sort:
          # TODO: only sort where necessary
          self.has_sorted_indices = False
          self.sort_indices()

      self.check_format(full_check=False)
  882. ######################
  883. # Conversion methods #
  884. ######################
  def tocoo(self, copy=True):
      # The public docstring is taken from spmatrix.tocoo (assigned below).
      n_major = self._swap(self.shape)[0]
      minor_idx = self.indices
      # expandptr fills major_idx (one slot per stored entry) from indptr.
      major_idx = np.empty(len(minor_idx), dtype=self.indices.dtype)
      _sparsetools.expandptr(n_major, self.indptr, major_idx)
      coo_row, coo_col = self._swap((major_idx, minor_idx))
      return self._coo_container((self.data, (coo_row, coo_col)),
                                 self.shape, copy=copy, dtype=self.dtype)

  tocoo.__doc__ = spmatrix.tocoo.__doc__
  def toarray(self, order=None, out=None):
      # The public docstring is taken from spmatrix.toarray (assigned below).
      if order is None and out is None:
          order = self._swap('cf')[0]
      out = self._process_toarray_args(order, out)
      c_layout = out.flags.c_contiguous
      if not (c_layout or out.flags.f_contiguous):
          raise ValueError('Output array must be C or F contiguous')
      # Align the ideal order with the output array's order. Use CSR
      # written into out for C order, or CSC written into out.T for F order.
      if c_layout:
          src, dest = self.tocsr(), out
      else:
          src, dest = self.tocsc(), out.T
      n_major, n_minor = src._swap(src.shape)
      csr_todense(n_major, n_minor, src.indptr, src.indices, src.data, dest)
      return out

  toarray.__doc__ = spmatrix.toarray.__doc__
  913. ##############################################################
  914. # methods that examine or modify the internal data structure #
  915. ##############################################################
  def eliminate_zeros(self):
      """Remove zero entries from the matrix.

      This is an *in place* operation.
      """
      n_major, n_minor = self._swap(self.shape)
      _sparsetools.csr_eliminate_zeros(n_major, n_minor, self.indptr,
                                       self.indices, self.data)
      # nnz may have changed, so trim the storage to match.
      self.prune()
  924. def __get_has_canonical_format(self):
  925. """Determine whether the matrix has sorted indices and no duplicates
  926. Returns
  927. - True: if the above applies
  928. - False: otherwise
  929. has_canonical_format implies has_sorted_indices, so if the latter flag
  930. is False, so will the former be; if the former is found True, the
  931. latter flag is also set.
  932. """
  933. # first check to see if result was cached
  934. if not getattr(self, '_has_sorted_indices', True):
  935. # not sorted => not canonical
  936. self._has_canonical_format = False
  937. elif not hasattr(self, '_has_canonical_format'):
  938. self.has_canonical_format = bool(
  939. _sparsetools.csr_has_canonical_format(
  940. len(self.indptr) - 1, self.indptr, self.indices))
  941. return self._has_canonical_format
  942. def __set_has_canonical_format(self, val):
  943. self._has_canonical_format = bool(val)
  944. if val:
  945. self.has_sorted_indices = True
  946. has_canonical_format = property(fget=__get_has_canonical_format,
  947. fset=__set_has_canonical_format)
  def sum_duplicates(self):
      """Eliminate duplicate matrix entries by adding them together.

      This is an *in place* operation. If the matrix is already in
      canonical format it returns immediately. Otherwise it sorts the
      indices, sums duplicates, prunes the storage, and marks the matrix
      canonical.
      """
      if not self.has_canonical_format:
          self.sort_indices()
          n_major, n_minor = self._swap(self.shape)
          _sparsetools.csr_sum_duplicates(n_major, n_minor, self.indptr,
                                          self.indices, self.data)
          self.prune()  # nnz may have changed
          self.has_canonical_format = True
  960. def __get_sorted(self):
  961. """Determine whether the matrix has sorted indices
  962. Returns
  963. - True: if the indices of the matrix are in sorted order
  964. - False: otherwise
  965. """
  966. # first check to see if result was cached
  967. if not hasattr(self, '_has_sorted_indices'):
  968. self._has_sorted_indices = bool(
  969. _sparsetools.csr_has_sorted_indices(
  970. len(self.indptr) - 1, self.indptr, self.indices))
  971. return self._has_sorted_indices
  972. def __set_sorted(self, val):
  973. self._has_sorted_indices = bool(val)
  974. has_sorted_indices = property(fget=__get_sorted, fset=__set_sorted)
  def sorted_indices(self):
      """Return a copy of this matrix with sorted indices.

      The sort is applied to the object returned by ``self.copy()``.
      """
      result = self.copy()
      result.sort_indices()
      return result

      # an alternative that has linear complexity is the following
      # although the previous option is typically faster
      # return self.toother().toother()
  def sort_indices(self):
      """Sort the indices of this matrix *in place*.

      Does nothing when the indices are already flagged as sorted.
      """
      if self.has_sorted_indices:
          return
      n_major = len(self.indptr) - 1
      _sparsetools.csr_sort_indices(n_major, self.indptr, self.indices,
                                    self.data)
      self.has_sorted_indices = True
  def prune(self):
      """Remove empty space after all non-zero elements.

      Truncates ``indices`` and ``data`` to their first ``nnz`` entries.

      Raises
      ------
      ValueError
          If ``indptr`` does not have ``major_dim + 1`` entries, or if
          ``indices`` or ``data`` hold fewer than ``nnz`` elements.
      """
      n_major = self._swap(self.shape)[0]
      if len(self.indptr) != n_major + 1:
          raise ValueError('index pointer has invalid length')
      if len(self.indices) < self.nnz:
          raise ValueError('indices array has fewer than nnz elements')
      if len(self.data) < self.nnz:
          raise ValueError('data array has fewer than nnz elements')

      self.indices = _prune_array(self.indices[:self.nnz])
      self.data = _prune_array(self.data[:self.nnz])
  def resize(self, *shape):
      # The public docstring is taken from spmatrix.resize (assigned below).
      shape = check_shape(shape)
      if hasattr(self, 'blocksize'):
          # Block formats resize in whole blocks, so work in block units.
          bm, bn = self.blocksize
          new_M, rem_M = divmod(shape[0], bm)
          new_N, rem_N = divmod(shape[1], bn)
          if rem_M or rem_N:
              raise ValueError("shape must be divisible into %s blocks. "
                               "Got %s" % (self.blocksize, shape))
          M, N = self.shape[0] // bm, self.shape[1] // bn
      else:
          new_M, new_N = self._swap(shape)
          M, N = self._swap(self.shape)

      if new_M < M:
          # Drop trailing major slices together with their entries.
          cut = self.indptr[new_M]
          self.indices = self.indices[:cut]
          self.data = self.data[:cut]
          self.indptr = self.indptr[:new_M + 1]
      elif new_M > M:
          # Appended major slices start out empty.
          self.indptr = np.resize(self.indptr, new_M + 1)
          self.indptr[M + 1:].fill(self.indptr[M])

      if new_N < N:
          keep = self.indices < new_N
          if not np.all(keep):
              # Drop entries beyond the new minor extent, then rebuild indptr
              # from per-major-slice survivor counts. _minor_reduce is called
              # before indptr is reset, presumably so it groups `keep` by
              # the indptr that still matches it.
              self.indices = self.indices[keep]
              self.data = self.data[keep]
              major_index, counts = self._minor_reduce(np.add, keep)
              self.indptr.fill(0)
              self.indptr[1:][major_index] = counts
              np.cumsum(self.indptr, out=self.indptr)

      self._shape = shape

  resize.__doc__ = spmatrix.resize.__doc__
  1034. ###################
  1035. # utility methods #
  1036. ###################
  1037. # needed by _data_matrix
  1038. def _with_data(self, data, copy=True):
  1039. """Returns a matrix with the same sparsity structure as self,
  1040. but with different data. By default the structure arrays
  1041. (i.e. .indptr and .indices) are copied.
  1042. """
  1043. if copy:
  1044. return self.__class__((data, self.indices.copy(),
  1045. self.indptr.copy()),
  1046. shape=self.shape,
  1047. dtype=data.dtype)
  1048. else:
  1049. return self.__class__((data, self.indices, self.indptr),
  1050. shape=self.shape, dtype=data.dtype)
  def _binopt(self, other, op):
      """Apply the binary operation fn to two sparse matrices.

      ``other`` is first converted to this class. ``op`` is the infix of
      the compiled routine name, so ``self.format + op + self.format``
      names the kernel (e.g. csr_plus_csr). The comparison ops listed
      below produce boolean data; every other op produces the upcast of
      both dtypes.
      """
      other = self.__class__(other)

      # e.g. csr_plus_csr, csr_minus_csr, etc.
      kernel = getattr(_sparsetools, self.format + op + self.format)

      max_nnz = self.nnz + other.nnz
      idx_dtype = get_index_dtype((self.indptr, self.indices,
                                   other.indptr, other.indices),
                                  maxval=max_nnz)
      out_indptr = np.empty(self.indptr.shape, dtype=idx_dtype)
      out_indices = np.empty(max_nnz, dtype=idx_dtype)
      if op in ('_ne_', '_lt_', '_gt_', '_le_', '_ge_'):
          out_dtype = np.bool_
      else:
          out_dtype = upcast(self.dtype, other.dtype)
      out_data = np.empty(max_nnz, dtype=out_dtype)

      kernel(self.shape[0], self.shape[1],
             np.asarray(self.indptr, dtype=idx_dtype),
             np.asarray(self.indices, dtype=idx_dtype),
             self.data,
             np.asarray(other.indptr, dtype=idx_dtype),
             np.asarray(other.indices, dtype=idx_dtype),
             other.data,
             out_indptr, out_indices, out_data)

      result = self.__class__((out_data, out_indices, out_indptr),
                              shape=self.shape)
      # Outputs were sized for the worst case (self.nnz + other.nnz);
      # trim the unused tail.
      result.prune()
      return result
  1078. def _divide_sparse(self, other):
  1079. """
  1080. Divide this matrix by a second sparse matrix.
  1081. """
  1082. if other.shape != self.shape:
  1083. raise ValueError('inconsistent shapes')
  1084. r = self._binopt(other, '_eldiv_')
  1085. if np.issubdtype(r.dtype, np.inexact):
  1086. # Eldiv leaves entries outside the combined sparsity
  1087. # pattern empty, so they must be filled manually.
  1088. # Everything outside of other's sparsity is NaN, and everything
  1089. # inside it is either zero or defined by eldiv.
  1090. out = np.empty(self.shape, dtype=self.dtype)
  1091. out.fill(np.nan)
  1092. row, col = other.nonzero()
  1093. out[row, col] = 0
  1094. r = r.tocoo()
  1095. out[r.row, r.col] = r.data
  1096. out = self._container(out)
  1097. else:
  1098. # integers types go with nan <-> 0
  1099. out = r
  1100. return out
  1101. def _process_slice(sl, num):
  1102. if sl is None:
  1103. i0, i1 = 0, num
  1104. elif isinstance(sl, slice):
  1105. i0, i1, stride = sl.indices(num)
  1106. if stride != 1:
  1107. raise ValueError('slicing with step != 1 not supported')
  1108. i0 = min(i0, i1) # give an empty slice when i0 > i1
  1109. elif isintlike(sl):
  1110. if sl < 0:
  1111. sl += num
  1112. i0, i1 = sl, sl + 1
  1113. if i0 < 0 or i1 > num:
  1114. raise IndexError('index out of bounds: 0 <= %d < %d <= %d' %
  1115. (i0, i1, num))
  1116. else:
  1117. raise TypeError('expected slice or scalar')
  1118. return i0, i1