_sputils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
"""Utility functions for the sparse matrix module.

Helpers shared by the sparse matrix classes: dtype selection and upcasting,
validation of shape / axis / index-like arguments, and access to the values
stored in sparse matrices.
"""
  3. import sys
  4. import operator
  5. import numpy as np
  6. from scipy._lib._util import prod
  7. import scipy.sparse as sp
  8. __all__ = ['upcast', 'getdtype', 'getdata', 'isscalarlike', 'isintlike',
  9. 'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype']
  10. supported_dtypes = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc,
  11. np.uintc, np.int_, np.uint, np.longlong, np.ulonglong,
  12. np.single, np.double,
  13. np.longdouble, np.csingle, np.cdouble, np.clongdouble]
  14. _upcast_memo = {}
  15. def upcast(*args):
  16. """Returns the nearest supported sparse dtype for the
  17. combination of one or more types.
  18. upcast(t0, t1, ..., tn) -> T where T is a supported dtype
  19. Examples
  20. --------
  21. >>> upcast('int32')
  22. <type 'numpy.int32'>
  23. >>> upcast('bool')
  24. <type 'numpy.bool_'>
  25. >>> upcast('int32','float32')
  26. <type 'numpy.float64'>
  27. >>> upcast('bool',complex,float)
  28. <type 'numpy.complex128'>
  29. """
  30. t = _upcast_memo.get(hash(args))
  31. if t is not None:
  32. return t
  33. upcast = np.result_type(*args)
  34. for t in supported_dtypes:
  35. if np.can_cast(upcast, t):
  36. _upcast_memo[hash(args)] = t
  37. return t
  38. raise TypeError('no supported conversion for types: %r' % (args,))
  39. def upcast_char(*args):
  40. """Same as `upcast` but taking dtype.char as input (faster)."""
  41. t = _upcast_memo.get(args)
  42. if t is not None:
  43. return t
  44. t = upcast(*map(np.dtype, args))
  45. _upcast_memo[args] = t
  46. return t
  47. def upcast_scalar(dtype, scalar):
  48. """Determine data type for binary operation between an array of
  49. type `dtype` and a scalar.
  50. """
  51. return (np.array([0], dtype=dtype) * scalar).dtype
  52. def downcast_intp_index(arr):
  53. """
  54. Down-cast index array to np.intp dtype if it is of a larger dtype.
  55. Raise an error if the array contains a value that is too large for
  56. intp.
  57. """
  58. if arr.dtype.itemsize > np.dtype(np.intp).itemsize:
  59. if arr.size == 0:
  60. return arr.astype(np.intp)
  61. maxval = arr.max()
  62. minval = arr.min()
  63. if maxval > np.iinfo(np.intp).max or minval < np.iinfo(np.intp).min:
  64. raise ValueError("Cannot deal with arrays with indices larger "
  65. "than the machine maximum address size "
  66. "(e.g. 64-bit indices on 32-bit machine).")
  67. return arr.astype(np.intp)
  68. return arr
  69. def to_native(A):
  70. """
  71. Ensure that the data type of the NumPy array `A` has native byte order.
  72. `A` must be a NumPy array. If the data type of `A` does not have native
  73. byte order, a copy of `A` with a native byte order is returned. Otherwise
  74. `A` is returned.
  75. """
  76. dt = A.dtype
  77. if dt.isnative:
  78. # Don't call `asarray()` if A is already native, to avoid unnecessarily
  79. # creating a view of the input array.
  80. return A
  81. return np.asarray(A, dtype=dt.newbyteorder('native'))
  82. def getdtype(dtype, a=None, default=None):
  83. """Function used to simplify argument processing. If 'dtype' is not
  84. specified (is None), returns a.dtype; otherwise returns a np.dtype
  85. object created from the specified dtype argument. If 'dtype' and 'a'
  86. are both None, construct a data type out of the 'default' parameter.
  87. Furthermore, 'dtype' must be in 'allowed' set.
  88. """
  89. # TODO is this really what we want?
  90. if dtype is None:
  91. try:
  92. newdtype = a.dtype
  93. except AttributeError as e:
  94. if default is not None:
  95. newdtype = np.dtype(default)
  96. else:
  97. raise TypeError("could not interpret data type") from e
  98. else:
  99. newdtype = np.dtype(dtype)
  100. if newdtype == np.object_:
  101. raise ValueError(
  102. "object dtype is not supported by sparse matrices"
  103. )
  104. return newdtype
  105. def getdata(obj, dtype=None, copy=False):
  106. """
  107. This is a wrapper of `np.array(obj, dtype=dtype, copy=copy)`
  108. that will generate a warning if the result is an object array.
  109. """
  110. data = np.array(obj, dtype=dtype, copy=copy)
  111. # Defer to getdtype for checking that the dtype is OK.
  112. # This is called for the validation only; we don't need the return value.
  113. getdtype(data.dtype)
  114. return data
  115. def get_index_dtype(arrays=(), maxval=None, check_contents=False):
  116. """
  117. Based on input (integer) arrays `a`, determine a suitable index data
  118. type that can hold the data in the arrays.
  119. Parameters
  120. ----------
  121. arrays : tuple of array_like
  122. Input arrays whose types/contents to check
  123. maxval : float, optional
  124. Maximum value needed
  125. check_contents : bool, optional
  126. Whether to check the values in the arrays and not just their types.
  127. Default: False (check only the types)
  128. Returns
  129. -------
  130. dtype : dtype
  131. Suitable index data type (int32 or int64)
  132. """
  133. int32min = np.int32(np.iinfo(np.int32).min)
  134. int32max = np.int32(np.iinfo(np.int32).max)
  135. # not using intc directly due to misinteractions with pythran
  136. dtype = np.int32 if np.intc().itemsize == 4 else np.int64
  137. if maxval is not None:
  138. maxval = np.int64(maxval)
  139. if maxval > int32max:
  140. dtype = np.int64
  141. if isinstance(arrays, np.ndarray):
  142. arrays = (arrays,)
  143. for arr in arrays:
  144. arr = np.asarray(arr)
  145. if not np.can_cast(arr.dtype, np.int32):
  146. if check_contents:
  147. if arr.size == 0:
  148. # a bigger type not needed
  149. continue
  150. elif np.issubdtype(arr.dtype, np.integer):
  151. maxval = arr.max()
  152. minval = arr.min()
  153. if minval >= int32min and maxval <= int32max:
  154. # a bigger type not needed
  155. continue
  156. dtype = np.int64
  157. break
  158. return dtype
  159. def get_sum_dtype(dtype):
  160. """Mimic numpy's casting for np.sum"""
  161. if dtype.kind == 'u' and np.can_cast(dtype, np.uint):
  162. return np.uint
  163. if np.can_cast(dtype, np.int_):
  164. return np.int_
  165. return dtype
  166. def isscalarlike(x):
  167. """Is x either a scalar, an array scalar, or a 0-dim array?"""
  168. return np.isscalar(x) or (isdense(x) and x.ndim == 0)
  169. def isintlike(x):
  170. """Is x appropriate as an index into a sparse matrix? Returns True
  171. if it can be cast safely to a machine int.
  172. """
  173. # Fast-path check to eliminate non-scalar values. operator.index would
  174. # catch this case too, but the exception catching is slow.
  175. if np.ndim(x) != 0:
  176. return False
  177. try:
  178. operator.index(x)
  179. except (TypeError, ValueError):
  180. try:
  181. loose_int = bool(int(x) == x)
  182. except (TypeError, ValueError):
  183. return False
  184. if loose_int:
  185. msg = "Inexact indices into sparse matrices are not allowed"
  186. raise ValueError(msg)
  187. return loose_int
  188. return True
  189. def isshape(x, nonneg=False):
  190. """Is x a valid 2-tuple of dimensions?
  191. If nonneg, also checks that the dimensions are non-negative.
  192. """
  193. try:
  194. # Assume it's a tuple of matrix dimensions (M, N)
  195. (M, N) = x
  196. except Exception:
  197. return False
  198. else:
  199. if isintlike(M) and isintlike(N):
  200. if np.ndim(M) == 0 and np.ndim(N) == 0:
  201. if not nonneg or (M >= 0 and N >= 0):
  202. return True
  203. return False
  204. def issequence(t):
  205. return ((isinstance(t, (list, tuple)) and
  206. (len(t) == 0 or np.isscalar(t[0]))) or
  207. (isinstance(t, np.ndarray) and (t.ndim == 1)))
  208. def ismatrix(t):
  209. return ((isinstance(t, (list, tuple)) and
  210. len(t) > 0 and issequence(t[0])) or
  211. (isinstance(t, np.ndarray) and t.ndim == 2))
  212. def isdense(x):
  213. return isinstance(x, np.ndarray)
  214. def validateaxis(axis):
  215. if axis is not None:
  216. axis_type = type(axis)
  217. # In NumPy, you can pass in tuples for 'axis', but they are
  218. # not very useful for sparse matrices given their limited
  219. # dimensions, so let's make it explicit that they are not
  220. # allowed to be passed in
  221. if axis_type == tuple:
  222. raise TypeError(("Tuples are not accepted for the 'axis' "
  223. "parameter. Please pass in one of the "
  224. "following: {-2, -1, 0, 1, None}."))
  225. # If not a tuple, check that the provided axis is actually
  226. # an integer and raise a TypeError similar to NumPy's
  227. if not np.issubdtype(np.dtype(axis_type), np.integer):
  228. raise TypeError("axis must be an integer, not {name}"
  229. .format(name=axis_type.__name__))
  230. if not (-2 <= axis <= 1):
  231. raise ValueError("axis out of range")
  232. def check_shape(args, current_shape=None):
  233. """Imitate numpy.matrix handling of shape arguments"""
  234. if len(args) == 0:
  235. raise TypeError("function missing 1 required positional argument: "
  236. "'shape'")
  237. elif len(args) == 1:
  238. try:
  239. shape_iter = iter(args[0])
  240. except TypeError:
  241. new_shape = (operator.index(args[0]), )
  242. else:
  243. new_shape = tuple(operator.index(arg) for arg in shape_iter)
  244. else:
  245. new_shape = tuple(operator.index(arg) for arg in args)
  246. if current_shape is None:
  247. if len(new_shape) != 2:
  248. raise ValueError('shape must be a 2-tuple of positive integers')
  249. elif any(d < 0 for d in new_shape):
  250. raise ValueError("'shape' elements cannot be negative")
  251. else:
  252. # Check the current size only if needed
  253. current_size = prod(current_shape)
  254. # Check for negatives
  255. negative_indexes = [i for i, x in enumerate(new_shape) if x < 0]
  256. if len(negative_indexes) == 0:
  257. new_size = prod(new_shape)
  258. if new_size != current_size:
  259. raise ValueError('cannot reshape array of size {} into shape {}'
  260. .format(current_size, new_shape))
  261. elif len(negative_indexes) == 1:
  262. skip = negative_indexes[0]
  263. specified = prod(new_shape[0:skip] + new_shape[skip+1:])
  264. unspecified, remainder = divmod(current_size, specified)
  265. if remainder != 0:
  266. err_shape = tuple('newshape' if x < 0 else x for x in new_shape)
  267. raise ValueError('cannot reshape array of size {} into shape {}'
  268. ''.format(current_size, err_shape))
  269. new_shape = new_shape[0:skip] + (unspecified,) + new_shape[skip+1:]
  270. else:
  271. raise ValueError('can only specify one unknown dimension')
  272. if len(new_shape) != 2:
  273. raise ValueError('matrix shape must be two-dimensional')
  274. return new_shape
  275. def check_reshape_kwargs(kwargs):
  276. """Unpack keyword arguments for reshape function.
  277. This is useful because keyword arguments after star arguments are not
  278. allowed in Python 2, but star keyword arguments are. This function unpacks
  279. 'order' and 'copy' from the star keyword arguments (with defaults) and
  280. throws an error for any remaining.
  281. """
  282. order = kwargs.pop('order', 'C')
  283. copy = kwargs.pop('copy', False)
  284. if kwargs: # Some unused kwargs remain
  285. raise TypeError('reshape() got unexpected keywords arguments: {}'
  286. .format(', '.join(kwargs.keys())))
  287. return order, copy
  288. def is_pydata_spmatrix(m):
  289. """
  290. Check whether object is pydata/sparse matrix, avoiding importing the module.
  291. """
  292. base_cls = getattr(sys.modules.get('sparse'), 'SparseArray', None)
  293. return base_cls is not None and isinstance(m, base_cls)
  294. ###############################################################################
# Wrappers for deprecated NumPy types.
# NumPy's own versions of these functions raise deprecation warnings;
# the wrappers below do not.
  298. def matrix(*args, **kwargs):
  299. return np.array(*args, **kwargs).view(np.matrix)
  300. def asmatrix(data, dtype=None):
  301. if isinstance(data, np.matrix) and (dtype is None or data.dtype == dtype):
  302. return data
  303. return np.asarray(data, dtype=dtype).view(np.matrix)
  304. ###############################################################################
  305. def _todata(s: 'sp.spmatrix') -> np.ndarray:
  306. """Access nonzero values, possibly after summing duplicates.
  307. Parameters
  308. ----------
  309. s : sparse matrix
  310. Input sparse matrix.
  311. Returns
  312. -------
  313. data: ndarray
  314. Nonzero values of the array, with shape (s.nnz,)
  315. """
  316. if isinstance(s, sp._data._data_matrix):
  317. return s._deduped_data()
  318. if isinstance(s, sp.dok_matrix):
  319. return np.fromiter(s.values(), dtype=s.dtype, count=s.nnz)
  320. if isinstance(s, sp.lil_matrix):
  321. data = np.empty(s.nnz, dtype=s.dtype)
  322. sp._csparsetools.lil_flatten_to_array(s.data, data)
  323. return data
  324. return s.tocoo()._deduped_data()