_interface.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829
"""Abstract linear algebra library.

This module defines a class hierarchy that implements a kind of "lazy"
matrix representation, called the ``LinearOperator``. It can be used to do
linear algebra with extremely large sparse or structured matrices, without
representing those explicitly in memory. Such matrices can be added,
multiplied, transposed, etc.

As a motivating example, suppose you have a matrix where almost all of
the elements have the value one. The standard sparse matrix representation
skips the storage of zeros, but not ones. By contrast, a LinearOperator is
able to represent such matrices efficiently. First, we need a compact way to
represent an all-ones matrix::

    >>> import numpy as np
    >>> class Ones(LinearOperator):
    ...     def __init__(self, shape):
    ...         super().__init__(dtype=None, shape=shape)
    ...     def _matvec(self, x):
    ...         return np.repeat(x.sum(), self.shape[0])

Instances of this class emulate ``np.ones(shape)``, but using a constant
amount of storage, independent of ``shape``. The ``_matvec`` method specifies
how this linear operator multiplies with (operates on) a vector. We can now
add this operator to a sparse matrix that stores only offsets from one::

    >>> from scipy.sparse import csr_matrix
    >>> offsets = csr_matrix([[1, 0, 2], [0, -1, 0], [0, 0, 3]])
    >>> A = aslinearoperator(offsets) + Ones(offsets.shape)
    >>> A.dot([1, 2, 3])
    array([13,  4, 15])

The result is the same as that given by its dense, explicitly-stored
counterpart::

    >>> (np.ones(A.shape, A.dtype) + offsets.toarray()).dot([1, 2, 3])
    array([13,  4, 15])

Several algorithms in the ``scipy.sparse`` library are able to operate on
``LinearOperator`` instances.
"""
  34. import warnings
  35. import numpy as np
  36. from scipy.sparse import isspmatrix
  37. from scipy.sparse._sputils import isshape, isintlike, asmatrix, is_pydata_spmatrix
  38. __all__ = ['LinearOperator', 'aslinearoperator']
  39. class LinearOperator:
  40. """Common interface for performing matrix vector products
  41. Many iterative methods (e.g. cg, gmres) do not need to know the
  42. individual entries of a matrix to solve a linear system A*x=b.
  43. Such solvers only require the computation of matrix vector
  44. products, A*v where v is a dense vector. This class serves as
  45. an abstract interface between iterative solvers and matrix-like
  46. objects.
  47. To construct a concrete LinearOperator, either pass appropriate
  48. callables to the constructor of this class, or subclass it.
  49. A subclass must implement either one of the methods ``_matvec``
  50. and ``_matmat``, and the attributes/properties ``shape`` (pair of
  51. integers) and ``dtype`` (may be None). It may call the ``__init__``
  52. on this class to have these attributes validated. Implementing
  53. ``_matvec`` automatically implements ``_matmat`` (using a naive
  54. algorithm) and vice-versa.
  55. Optionally, a subclass may implement ``_rmatvec`` or ``_adjoint``
  56. to implement the Hermitian adjoint (conjugate transpose). As with
  57. ``_matvec`` and ``_matmat``, implementing either ``_rmatvec`` or
  58. ``_adjoint`` implements the other automatically. Implementing
  59. ``_adjoint`` is preferable; ``_rmatvec`` is mostly there for
  60. backwards compatibility.
  61. Parameters
  62. ----------
  63. shape : tuple
  64. Matrix dimensions (M, N).
  65. matvec : callable f(v)
  66. Returns returns A * v.
  67. rmatvec : callable f(v)
  68. Returns A^H * v, where A^H is the conjugate transpose of A.
  69. matmat : callable f(V)
  70. Returns A * V, where V is a dense matrix with dimensions (N, K).
  71. dtype : dtype
  72. Data type of the matrix.
  73. rmatmat : callable f(V)
  74. Returns A^H * V, where V is a dense matrix with dimensions (M, K).
  75. Attributes
  76. ----------
  77. args : tuple
  78. For linear operators describing products etc. of other linear
  79. operators, the operands of the binary operation.
  80. ndim : int
  81. Number of dimensions (this is always 2)
  82. See Also
  83. --------
  84. aslinearoperator : Construct LinearOperators
  85. Notes
  86. -----
  87. The user-defined matvec() function must properly handle the case
  88. where v has shape (N,) as well as the (N,1) case. The shape of
  89. the return type is handled internally by LinearOperator.
  90. LinearOperator instances can also be multiplied, added with each
  91. other and exponentiated, all lazily: the result of these operations
  92. is always a new, composite LinearOperator, that defers linear
  93. operations to the original operators and combines the results.
  94. More details regarding how to subclass a LinearOperator and several
  95. examples of concrete LinearOperator instances can be found in the
  96. external project `PyLops <https://pylops.readthedocs.io>`_.
  97. Examples
  98. --------
  99. >>> import numpy as np
  100. >>> from scipy.sparse.linalg import LinearOperator
  101. >>> def mv(v):
  102. ... return np.array([2*v[0], 3*v[1]])
  103. ...
  104. >>> A = LinearOperator((2,2), matvec=mv)
  105. >>> A
  106. <2x2 _CustomLinearOperator with dtype=float64>
  107. >>> A.matvec(np.ones(2))
  108. array([ 2., 3.])
  109. >>> A * np.ones(2)
  110. array([ 2., 3.])
  111. """
  112. ndim = 2
  113. def __new__(cls, *args, **kwargs):
  114. if cls is LinearOperator:
  115. # Operate as _CustomLinearOperator factory.
  116. return super(LinearOperator, cls).__new__(_CustomLinearOperator)
  117. else:
  118. obj = super(LinearOperator, cls).__new__(cls)
  119. if (type(obj)._matvec == LinearOperator._matvec
  120. and type(obj)._matmat == LinearOperator._matmat):
  121. warnings.warn("LinearOperator subclass should implement"
  122. " at least one of _matvec and _matmat.",
  123. category=RuntimeWarning, stacklevel=2)
  124. return obj
  125. def __init__(self, dtype, shape):
  126. """Initialize this LinearOperator.
  127. To be called by subclasses. ``dtype`` may be None; ``shape`` should
  128. be convertible to a length-2 tuple.
  129. """
  130. if dtype is not None:
  131. dtype = np.dtype(dtype)
  132. shape = tuple(shape)
  133. if not isshape(shape):
  134. raise ValueError("invalid shape %r (must be 2-d)" % (shape,))
  135. self.dtype = dtype
  136. self.shape = shape
  137. def _init_dtype(self):
  138. """Called from subclasses at the end of the __init__ routine.
  139. """
  140. if self.dtype is None:
  141. v = np.zeros(self.shape[-1])
  142. self.dtype = np.asarray(self.matvec(v)).dtype
  143. def _matmat(self, X):
  144. """Default matrix-matrix multiplication handler.
  145. Falls back on the user-defined _matvec method, so defining that will
  146. define matrix multiplication (though in a very suboptimal way).
  147. """
  148. return np.hstack([self.matvec(col.reshape(-1,1)) for col in X.T])
  149. def _matvec(self, x):
  150. """Default matrix-vector multiplication handler.
  151. If self is a linear operator of shape (M, N), then this method will
  152. be called on a shape (N,) or (N, 1) ndarray, and should return a
  153. shape (M,) or (M, 1) ndarray.
  154. This default implementation falls back on _matmat, so defining that
  155. will define matrix-vector multiplication as well.
  156. """
  157. return self.matmat(x.reshape(-1, 1))
  158. def matvec(self, x):
  159. """Matrix-vector multiplication.
  160. Performs the operation y=A*x where A is an MxN linear
  161. operator and x is a column vector or 1-d array.
  162. Parameters
  163. ----------
  164. x : {matrix, ndarray}
  165. An array with shape (N,) or (N,1).
  166. Returns
  167. -------
  168. y : {matrix, ndarray}
  169. A matrix or ndarray with shape (M,) or (M,1) depending
  170. on the type and shape of the x argument.
  171. Notes
  172. -----
  173. This matvec wraps the user-specified matvec routine or overridden
  174. _matvec method to ensure that y has the correct shape and type.
  175. """
  176. x = np.asanyarray(x)
  177. M,N = self.shape
  178. if x.shape != (N,) and x.shape != (N,1):
  179. raise ValueError('dimension mismatch')
  180. y = self._matvec(x)
  181. if isinstance(x, np.matrix):
  182. y = asmatrix(y)
  183. else:
  184. y = np.asarray(y)
  185. if x.ndim == 1:
  186. y = y.reshape(M)
  187. elif x.ndim == 2:
  188. y = y.reshape(M,1)
  189. else:
  190. raise ValueError('invalid shape returned by user-defined matvec()')
  191. return y
  192. def rmatvec(self, x):
  193. """Adjoint matrix-vector multiplication.
  194. Performs the operation y = A^H * x where A is an MxN linear
  195. operator and x is a column vector or 1-d array.
  196. Parameters
  197. ----------
  198. x : {matrix, ndarray}
  199. An array with shape (M,) or (M,1).
  200. Returns
  201. -------
  202. y : {matrix, ndarray}
  203. A matrix or ndarray with shape (N,) or (N,1) depending
  204. on the type and shape of the x argument.
  205. Notes
  206. -----
  207. This rmatvec wraps the user-specified rmatvec routine or overridden
  208. _rmatvec method to ensure that y has the correct shape and type.
  209. """
  210. x = np.asanyarray(x)
  211. M,N = self.shape
  212. if x.shape != (M,) and x.shape != (M,1):
  213. raise ValueError('dimension mismatch')
  214. y = self._rmatvec(x)
  215. if isinstance(x, np.matrix):
  216. y = asmatrix(y)
  217. else:
  218. y = np.asarray(y)
  219. if x.ndim == 1:
  220. y = y.reshape(N)
  221. elif x.ndim == 2:
  222. y = y.reshape(N,1)
  223. else:
  224. raise ValueError('invalid shape returned by user-defined rmatvec()')
  225. return y
  226. def _rmatvec(self, x):
  227. """Default implementation of _rmatvec; defers to adjoint."""
  228. if type(self)._adjoint == LinearOperator._adjoint:
  229. # _adjoint not overridden, prevent infinite recursion
  230. raise NotImplementedError
  231. else:
  232. return self.H.matvec(x)
  233. def matmat(self, X):
  234. """Matrix-matrix multiplication.
  235. Performs the operation y=A*X where A is an MxN linear
  236. operator and X dense N*K matrix or ndarray.
  237. Parameters
  238. ----------
  239. X : {matrix, ndarray}
  240. An array with shape (N,K).
  241. Returns
  242. -------
  243. Y : {matrix, ndarray}
  244. A matrix or ndarray with shape (M,K) depending on
  245. the type of the X argument.
  246. Notes
  247. -----
  248. This matmat wraps any user-specified matmat routine or overridden
  249. _matmat method to ensure that y has the correct type.
  250. """
  251. X = np.asanyarray(X)
  252. if X.ndim != 2:
  253. raise ValueError('expected 2-d ndarray or matrix, not %d-d'
  254. % X.ndim)
  255. if X.shape[0] != self.shape[1]:
  256. raise ValueError('dimension mismatch: %r, %r'
  257. % (self.shape, X.shape))
  258. Y = self._matmat(X)
  259. if isinstance(Y, np.matrix):
  260. Y = asmatrix(Y)
  261. return Y
  262. def rmatmat(self, X):
  263. """Adjoint matrix-matrix multiplication.
  264. Performs the operation y = A^H * x where A is an MxN linear
  265. operator and x is a column vector or 1-d array, or 2-d array.
  266. The default implementation defers to the adjoint.
  267. Parameters
  268. ----------
  269. X : {matrix, ndarray}
  270. A matrix or 2D array.
  271. Returns
  272. -------
  273. Y : {matrix, ndarray}
  274. A matrix or 2D array depending on the type of the input.
  275. Notes
  276. -----
  277. This rmatmat wraps the user-specified rmatmat routine.
  278. """
  279. X = np.asanyarray(X)
  280. if X.ndim != 2:
  281. raise ValueError('expected 2-d ndarray or matrix, not %d-d'
  282. % X.ndim)
  283. if X.shape[0] != self.shape[0]:
  284. raise ValueError('dimension mismatch: %r, %r'
  285. % (self.shape, X.shape))
  286. Y = self._rmatmat(X)
  287. if isinstance(Y, np.matrix):
  288. Y = asmatrix(Y)
  289. return Y
  290. def _rmatmat(self, X):
  291. """Default implementation of _rmatmat defers to rmatvec or adjoint."""
  292. if type(self)._adjoint == LinearOperator._adjoint:
  293. return np.hstack([self.rmatvec(col.reshape(-1, 1)) for col in X.T])
  294. else:
  295. return self.H.matmat(X)
  296. def __call__(self, x):
  297. return self*x
  298. def __mul__(self, x):
  299. return self.dot(x)
  300. def dot(self, x):
  301. """Matrix-matrix or matrix-vector multiplication.
  302. Parameters
  303. ----------
  304. x : array_like
  305. 1-d or 2-d array, representing a vector or matrix.
  306. Returns
  307. -------
  308. Ax : array
  309. 1-d or 2-d array (depending on the shape of x) that represents
  310. the result of applying this linear operator on x.
  311. """
  312. if isinstance(x, LinearOperator):
  313. return _ProductLinearOperator(self, x)
  314. elif np.isscalar(x):
  315. return _ScaledLinearOperator(self, x)
  316. else:
  317. x = np.asarray(x)
  318. if x.ndim == 1 or x.ndim == 2 and x.shape[1] == 1:
  319. return self.matvec(x)
  320. elif x.ndim == 2:
  321. return self.matmat(x)
  322. else:
  323. raise ValueError('expected 1-d or 2-d array or matrix, got %r'
  324. % x)
  325. def __matmul__(self, other):
  326. if np.isscalar(other):
  327. raise ValueError("Scalar operands are not allowed, "
  328. "use '*' instead")
  329. return self.__mul__(other)
  330. def __rmatmul__(self, other):
  331. if np.isscalar(other):
  332. raise ValueError("Scalar operands are not allowed, "
  333. "use '*' instead")
  334. return self.__rmul__(other)
  335. def __rmul__(self, x):
  336. if np.isscalar(x):
  337. return _ScaledLinearOperator(self, x)
  338. else:
  339. return NotImplemented
  340. def __pow__(self, p):
  341. if np.isscalar(p):
  342. return _PowerLinearOperator(self, p)
  343. else:
  344. return NotImplemented
  345. def __add__(self, x):
  346. if isinstance(x, LinearOperator):
  347. return _SumLinearOperator(self, x)
  348. else:
  349. return NotImplemented
  350. def __neg__(self):
  351. return _ScaledLinearOperator(self, -1)
  352. def __sub__(self, x):
  353. return self.__add__(-x)
  354. def __repr__(self):
  355. M,N = self.shape
  356. if self.dtype is None:
  357. dt = 'unspecified dtype'
  358. else:
  359. dt = 'dtype=' + str(self.dtype)
  360. return '<%dx%d %s with %s>' % (M, N, self.__class__.__name__, dt)
  361. def adjoint(self):
  362. """Hermitian adjoint.
  363. Returns the Hermitian adjoint of self, aka the Hermitian
  364. conjugate or Hermitian transpose. For a complex matrix, the
  365. Hermitian adjoint is equal to the conjugate transpose.
  366. Can be abbreviated self.H instead of self.adjoint().
  367. Returns
  368. -------
  369. A_H : LinearOperator
  370. Hermitian adjoint of self.
  371. """
  372. return self._adjoint()
  373. H = property(adjoint)
  374. def transpose(self):
  375. """Transpose this linear operator.
  376. Returns a LinearOperator that represents the transpose of this one.
  377. Can be abbreviated self.T instead of self.transpose().
  378. """
  379. return self._transpose()
  380. T = property(transpose)
  381. def _adjoint(self):
  382. """Default implementation of _adjoint; defers to rmatvec."""
  383. return _AdjointLinearOperator(self)
  384. def _transpose(self):
  385. """ Default implementation of _transpose; defers to rmatvec + conj"""
  386. return _TransposedLinearOperator(self)
  387. class _CustomLinearOperator(LinearOperator):
  388. """Linear operator defined in terms of user-specified operations."""
  389. def __init__(self, shape, matvec, rmatvec=None, matmat=None,
  390. dtype=None, rmatmat=None):
  391. super().__init__(dtype, shape)
  392. self.args = ()
  393. self.__matvec_impl = matvec
  394. self.__rmatvec_impl = rmatvec
  395. self.__rmatmat_impl = rmatmat
  396. self.__matmat_impl = matmat
  397. self._init_dtype()
  398. def _matmat(self, X):
  399. if self.__matmat_impl is not None:
  400. return self.__matmat_impl(X)
  401. else:
  402. return super()._matmat(X)
  403. def _matvec(self, x):
  404. return self.__matvec_impl(x)
  405. def _rmatvec(self, x):
  406. func = self.__rmatvec_impl
  407. if func is None:
  408. raise NotImplementedError("rmatvec is not defined")
  409. return self.__rmatvec_impl(x)
  410. def _rmatmat(self, X):
  411. if self.__rmatmat_impl is not None:
  412. return self.__rmatmat_impl(X)
  413. else:
  414. return super()._rmatmat(X)
  415. def _adjoint(self):
  416. return _CustomLinearOperator(shape=(self.shape[1], self.shape[0]),
  417. matvec=self.__rmatvec_impl,
  418. rmatvec=self.__matvec_impl,
  419. matmat=self.__rmatmat_impl,
  420. rmatmat=self.__matmat_impl,
  421. dtype=self.dtype)
  422. class _AdjointLinearOperator(LinearOperator):
  423. """Adjoint of arbitrary Linear Operator"""
  424. def __init__(self, A):
  425. shape = (A.shape[1], A.shape[0])
  426. super().__init__(dtype=A.dtype, shape=shape)
  427. self.A = A
  428. self.args = (A,)
  429. def _matvec(self, x):
  430. return self.A._rmatvec(x)
  431. def _rmatvec(self, x):
  432. return self.A._matvec(x)
  433. def _matmat(self, x):
  434. return self.A._rmatmat(x)
  435. def _rmatmat(self, x):
  436. return self.A._matmat(x)
  437. class _TransposedLinearOperator(LinearOperator):
  438. """Transposition of arbitrary Linear Operator"""
  439. def __init__(self, A):
  440. shape = (A.shape[1], A.shape[0])
  441. super().__init__(dtype=A.dtype, shape=shape)
  442. self.A = A
  443. self.args = (A,)
  444. def _matvec(self, x):
  445. # NB. np.conj works also on sparse matrices
  446. return np.conj(self.A._rmatvec(np.conj(x)))
  447. def _rmatvec(self, x):
  448. return np.conj(self.A._matvec(np.conj(x)))
  449. def _matmat(self, x):
  450. # NB. np.conj works also on sparse matrices
  451. return np.conj(self.A._rmatmat(np.conj(x)))
  452. def _rmatmat(self, x):
  453. return np.conj(self.A._matmat(np.conj(x)))
  454. def _get_dtype(operators, dtypes=None):
  455. if dtypes is None:
  456. dtypes = []
  457. for obj in operators:
  458. if obj is not None and hasattr(obj, 'dtype'):
  459. dtypes.append(obj.dtype)
  460. return np.result_type(*dtypes)
  461. class _SumLinearOperator(LinearOperator):
  462. def __init__(self, A, B):
  463. if not isinstance(A, LinearOperator) or \
  464. not isinstance(B, LinearOperator):
  465. raise ValueError('both operands have to be a LinearOperator')
  466. if A.shape != B.shape:
  467. raise ValueError('cannot add %r and %r: shape mismatch'
  468. % (A, B))
  469. self.args = (A, B)
  470. super().__init__(_get_dtype([A, B]), A.shape)
  471. def _matvec(self, x):
  472. return self.args[0].matvec(x) + self.args[1].matvec(x)
  473. def _rmatvec(self, x):
  474. return self.args[0].rmatvec(x) + self.args[1].rmatvec(x)
  475. def _rmatmat(self, x):
  476. return self.args[0].rmatmat(x) + self.args[1].rmatmat(x)
  477. def _matmat(self, x):
  478. return self.args[0].matmat(x) + self.args[1].matmat(x)
  479. def _adjoint(self):
  480. A, B = self.args
  481. return A.H + B.H
  482. class _ProductLinearOperator(LinearOperator):
  483. def __init__(self, A, B):
  484. if not isinstance(A, LinearOperator) or \
  485. not isinstance(B, LinearOperator):
  486. raise ValueError('both operands have to be a LinearOperator')
  487. if A.shape[1] != B.shape[0]:
  488. raise ValueError('cannot multiply %r and %r: shape mismatch'
  489. % (A, B))
  490. super().__init__(_get_dtype([A, B]),
  491. (A.shape[0], B.shape[1]))
  492. self.args = (A, B)
  493. def _matvec(self, x):
  494. return self.args[0].matvec(self.args[1].matvec(x))
  495. def _rmatvec(self, x):
  496. return self.args[1].rmatvec(self.args[0].rmatvec(x))
  497. def _rmatmat(self, x):
  498. return self.args[1].rmatmat(self.args[0].rmatmat(x))
  499. def _matmat(self, x):
  500. return self.args[0].matmat(self.args[1].matmat(x))
  501. def _adjoint(self):
  502. A, B = self.args
  503. return B.H * A.H
  504. class _ScaledLinearOperator(LinearOperator):
  505. def __init__(self, A, alpha):
  506. if not isinstance(A, LinearOperator):
  507. raise ValueError('LinearOperator expected as A')
  508. if not np.isscalar(alpha):
  509. raise ValueError('scalar expected as alpha')
  510. dtype = _get_dtype([A], [type(alpha)])
  511. super().__init__(dtype, A.shape)
  512. self.args = (A, alpha)
  513. def _matvec(self, x):
  514. return self.args[1] * self.args[0].matvec(x)
  515. def _rmatvec(self, x):
  516. return np.conj(self.args[1]) * self.args[0].rmatvec(x)
  517. def _rmatmat(self, x):
  518. return np.conj(self.args[1]) * self.args[0].rmatmat(x)
  519. def _matmat(self, x):
  520. return self.args[1] * self.args[0].matmat(x)
  521. def _adjoint(self):
  522. A, alpha = self.args
  523. return A.H * np.conj(alpha)
  524. class _PowerLinearOperator(LinearOperator):
  525. def __init__(self, A, p):
  526. if not isinstance(A, LinearOperator):
  527. raise ValueError('LinearOperator expected as A')
  528. if A.shape[0] != A.shape[1]:
  529. raise ValueError('square LinearOperator expected, got %r' % A)
  530. if not isintlike(p) or p < 0:
  531. raise ValueError('non-negative integer expected as p')
  532. super().__init__(_get_dtype([A]), A.shape)
  533. self.args = (A, p)
  534. def _power(self, fun, x):
  535. res = np.array(x, copy=True)
  536. for i in range(self.args[1]):
  537. res = fun(res)
  538. return res
  539. def _matvec(self, x):
  540. return self._power(self.args[0].matvec, x)
  541. def _rmatvec(self, x):
  542. return self._power(self.args[0].rmatvec, x)
  543. def _rmatmat(self, x):
  544. return self._power(self.args[0].rmatmat, x)
  545. def _matmat(self, x):
  546. return self._power(self.args[0].matmat, x)
  547. def _adjoint(self):
  548. A, p = self.args
  549. return A.H ** p
  550. class MatrixLinearOperator(LinearOperator):
  551. def __init__(self, A):
  552. super().__init__(A.dtype, A.shape)
  553. self.A = A
  554. self.__adj = None
  555. self.args = (A,)
  556. def _matmat(self, X):
  557. return self.A.dot(X)
  558. def _adjoint(self):
  559. if self.__adj is None:
  560. self.__adj = _AdjointMatrixOperator(self)
  561. return self.__adj
  562. class _AdjointMatrixOperator(MatrixLinearOperator):
  563. def __init__(self, adjoint):
  564. self.A = adjoint.A.T.conj()
  565. self.__adjoint = adjoint
  566. self.args = (adjoint,)
  567. self.shape = adjoint.shape[1], adjoint.shape[0]
  568. @property
  569. def dtype(self):
  570. return self.__adjoint.dtype
  571. def _adjoint(self):
  572. return self.__adjoint
  573. class IdentityOperator(LinearOperator):
  574. def __init__(self, shape, dtype=None):
  575. super().__init__(dtype, shape)
  576. def _matvec(self, x):
  577. return x
  578. def _rmatvec(self, x):
  579. return x
  580. def _rmatmat(self, x):
  581. return x
  582. def _matmat(self, x):
  583. return x
  584. def _adjoint(self):
  585. return self
  586. def aslinearoperator(A):
  587. """Return A as a LinearOperator.
  588. 'A' may be any of the following types:
  589. - ndarray
  590. - matrix
  591. - sparse matrix (e.g. csr_matrix, lil_matrix, etc.)
  592. - LinearOperator
  593. - An object with .shape and .matvec attributes
  594. See the LinearOperator documentation for additional information.
  595. Notes
  596. -----
  597. If 'A' has no .dtype attribute, the data type is determined by calling
  598. :func:`LinearOperator.matvec()` - set the .dtype attribute to prevent this
  599. call upon the linear operator creation.
  600. Examples
  601. --------
  602. >>> import numpy as np
  603. >>> from scipy.sparse.linalg import aslinearoperator
  604. >>> M = np.array([[1,2,3],[4,5,6]], dtype=np.int32)
  605. >>> aslinearoperator(M)
  606. <2x3 MatrixLinearOperator with dtype=int32>
  607. """
  608. if isinstance(A, LinearOperator):
  609. return A
  610. elif isinstance(A, np.ndarray) or isinstance(A, np.matrix):
  611. if A.ndim > 2:
  612. raise ValueError('array must have ndim <= 2')
  613. A = np.atleast_2d(np.asarray(A))
  614. return MatrixLinearOperator(A)
  615. elif isspmatrix(A) or is_pydata_spmatrix(A):
  616. return MatrixLinearOperator(A)
  617. else:
  618. if hasattr(A, 'shape') and hasattr(A, 'matvec'):
  619. rmatvec = None
  620. rmatmat = None
  621. dtype = None
  622. if hasattr(A, 'rmatvec'):
  623. rmatvec = A.rmatvec
  624. if hasattr(A, 'rmatmat'):
  625. rmatmat = A.rmatmat
  626. if hasattr(A, 'dtype'):
  627. dtype = A.dtype
  628. return LinearOperator(A.shape, A.matvec, rmatvec=rmatvec,
  629. rmatmat=rmatmat, dtype=dtype)
  630. else:
  631. raise TypeError('type not understood')