_basic.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815
  1. #
  2. # Author: Pearu Peterson, March 2002
  3. #
  4. # w/ additions by Travis Oliphant, March 2002
  5. # and Jake Vanderplas, August 2012
  6. from warnings import warn
  7. import numpy as np
  8. from numpy import atleast_1d, atleast_2d
  9. from ._flinalg_py import get_flinalg_funcs
  10. from .lapack import get_lapack_funcs, _compute_lwork
  11. from ._misc import LinAlgError, _datacopied, LinAlgWarning
  12. from ._decomp import _asarray_validated
  13. from . import _decomp, _decomp_svd
  14. from ._solve_toeplitz import levinson
# Names exported by ``from <this module> import *``.
__all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
           'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
           'pinv', 'pinvh', 'matrix_balance', 'matmul_toeplitz']
  18. # Linear equations
  19. def _solve_check(n, info, lamch=None, rcond=None):
  20. """ Check arguments during the different steps of the solution phase """
  21. if info < 0:
  22. raise ValueError('LAPACK reported an illegal value in {}-th argument'
  23. '.'.format(-info))
  24. elif 0 < info:
  25. raise LinAlgError('Matrix is singular.')
  26. if lamch is None:
  27. return
  28. E = lamch('E')
  29. if rcond < E:
  30. warn('Ill-conditioned matrix (rcond={:.6g}): '
  31. 'result may not be accurate.'.format(rcond),
  32. LinAlgWarning, stacklevel=3)
  33. def solve(a, b, sym_pos=False, lower=False, overwrite_a=False,
  34. overwrite_b=False, check_finite=True, assume_a='gen',
  35. transposed=False):
  36. """
  37. Solves the linear equation set ``a @ x == b`` for the unknown ``x``
  38. for square `a` matrix.
  39. If the data matrix is known to be a particular type then supplying the
  40. corresponding string to ``assume_a`` key chooses the dedicated solver.
  41. The available options are
  42. =================== ========
  43. generic matrix 'gen'
  44. symmetric 'sym'
  45. hermitian 'her'
  46. positive definite 'pos'
  47. =================== ========
  48. If omitted, ``'gen'`` is the default structure.
  49. The datatype of the arrays define which solver is called regardless
  50. of the values. In other words, even when the complex array entries have
  51. precisely zero imaginary parts, the complex solver will be called based
  52. on the data type of the array.
  53. Parameters
  54. ----------
  55. a : (N, N) array_like
  56. Square input data
  57. b : (N, NRHS) array_like
  58. Input data for the right hand side.
  59. sym_pos : bool, default: False, deprecated
  60. Assume `a` is symmetric and positive definite.
  61. .. deprecated:: 0.19.0
  62. This keyword is deprecated and should be replaced by using
  63. ``assume_a = 'pos'``. `sym_pos` will be removed in SciPy 1.11.0.
  64. lower : bool, default: False
  65. Ignored if ``assume_a == 'gen'`` (the default). If True, the
  66. calculation uses only the data in the lower triangle of `a`;
  67. entries above the diagonal are ignored. If False (default), the
  68. calculation uses only the data in the upper triangle of `a`; entries
  69. below the diagonal are ignored.
  70. overwrite_a : bool, default: False
  71. Allow overwriting data in `a` (may enhance performance).
  72. overwrite_b : bool, default: False
  73. Allow overwriting data in `b` (may enhance performance).
  74. check_finite : bool, default: True
  75. Whether to check that the input matrices contain only finite numbers.
  76. Disabling may give a performance gain, but may result in problems
  77. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  78. assume_a : str, {'gen', 'sym', 'her', 'pos'}
  79. Valid entries are explained above.
  80. transposed : bool, default: False
  81. If True, solve ``a.T @ x == b``. Raises `NotImplementedError`
  82. for complex `a`.
  83. Returns
  84. -------
  85. x : (N, NRHS) ndarray
  86. The solution array.
  87. Raises
  88. ------
  89. ValueError
  90. If size mismatches detected or input a is not square.
  91. LinAlgError
  92. If the matrix is singular.
  93. LinAlgWarning
  94. If an ill-conditioned input a is detected.
  95. NotImplementedError
  96. If transposed is True and input a is a complex matrix.
  97. Notes
  98. -----
  99. If the input b matrix is a 1-D array with N elements, when supplied
  100. together with an NxN input a, it is assumed as a valid column vector
  101. despite the apparent size mismatch. This is compatible with the
  102. numpy.dot() behavior and the returned result is still 1-D array.
  103. The generic, symmetric, Hermitian and positive definite solutions are
  104. obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
  105. LAPACK respectively.
  106. Examples
  107. --------
  108. Given `a` and `b`, solve for `x`:
  109. >>> import numpy as np
  110. >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
  111. >>> b = np.array([2, 4, -1])
  112. >>> from scipy import linalg
  113. >>> x = linalg.solve(a, b)
  114. >>> x
  115. array([ 2., -2., 9.])
  116. >>> np.dot(a, x) == b
  117. array([ True, True, True], dtype=bool)
  118. """
  119. # Flags for 1-D or N-D right-hand side
  120. b_is_1D = False
  121. a1 = atleast_2d(_asarray_validated(a, check_finite=check_finite))
  122. b1 = atleast_1d(_asarray_validated(b, check_finite=check_finite))
  123. n = a1.shape[0]
  124. overwrite_a = overwrite_a or _datacopied(a1, a)
  125. overwrite_b = overwrite_b or _datacopied(b1, b)
  126. if a1.shape[0] != a1.shape[1]:
  127. raise ValueError('Input a needs to be a square matrix.')
  128. if n != b1.shape[0]:
  129. # Last chance to catch 1x1 scalar a and 1-D b arrays
  130. if not (n == 1 and b1.size != 0):
  131. raise ValueError('Input b has to have same number of rows as '
  132. 'input a')
  133. # accommodate empty arrays
  134. if b1.size == 0:
  135. return np.asfortranarray(b1.copy())
  136. # regularize 1-D b arrays to 2D
  137. if b1.ndim == 1:
  138. if n == 1:
  139. b1 = b1[None, :]
  140. else:
  141. b1 = b1[:, None]
  142. b_is_1D = True
  143. # Backwards compatibility - old keyword.
  144. if sym_pos:
  145. message = ("The 'sym_pos' keyword is deprecated and should be "
  146. "replaced by using 'assume_a = \"pos\"'. 'sym_pos' will be"
  147. " removed in SciPy 1.11.0.")
  148. warn(message, DeprecationWarning, stacklevel=2)
  149. assume_a = 'pos'
  150. if assume_a not in ('gen', 'sym', 'her', 'pos'):
  151. raise ValueError('{} is not a recognized matrix structure'
  152. ''.format(assume_a))
  153. # for a real matrix, describe it as "symmetric", not "hermitian"
  154. # (lapack doesn't know what to do with real hermitian matrices)
  155. if assume_a == 'her' and not np.iscomplexobj(a1):
  156. assume_a = 'sym'
  157. # Get the correct lamch function.
  158. # The LAMCH functions only exists for S and D
  159. # So for complex values we have to convert to real/double.
  160. if a1.dtype.char in 'fF': # single precision
  161. lamch = get_lapack_funcs('lamch', dtype='f')
  162. else:
  163. lamch = get_lapack_funcs('lamch', dtype='d')
  164. # Currently we do not have the other forms of the norm calculators
  165. # lansy, lanpo, lanhe.
  166. # However, in any case they only reduce computations slightly...
  167. lange = get_lapack_funcs('lange', (a1,))
  168. # Since the I-norm and 1-norm are the same for symmetric matrices
  169. # we can collect them all in this one call
  170. # Note however, that when issuing 'gen' and form!='none', then
  171. # the I-norm should be used
  172. if transposed:
  173. trans = 1
  174. norm = 'I'
  175. if np.iscomplexobj(a1):
  176. raise NotImplementedError('scipy.linalg.solve can currently '
  177. 'not solve a^T x = b or a^H x = b '
  178. 'for complex matrices.')
  179. else:
  180. trans = 0
  181. norm = '1'
  182. anorm = lange(norm, a1)
  183. # Generalized case 'gesv'
  184. if assume_a == 'gen':
  185. gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
  186. (a1, b1))
  187. lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
  188. _solve_check(n, info)
  189. x, info = getrs(lu, ipvt, b1,
  190. trans=trans, overwrite_b=overwrite_b)
  191. _solve_check(n, info)
  192. rcond, info = gecon(lu, anorm, norm=norm)
  193. # Hermitian case 'hesv'
  194. elif assume_a == 'her':
  195. hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
  196. 'hesv_lwork'), (a1, b1))
  197. lwork = _compute_lwork(hesv_lw, n, lower)
  198. lu, ipvt, x, info = hesv(a1, b1, lwork=lwork,
  199. lower=lower,
  200. overwrite_a=overwrite_a,
  201. overwrite_b=overwrite_b)
  202. _solve_check(n, info)
  203. rcond, info = hecon(lu, ipvt, anorm)
  204. # Symmetric case 'sysv'
  205. elif assume_a == 'sym':
  206. sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
  207. 'sysv_lwork'), (a1, b1))
  208. lwork = _compute_lwork(sysv_lw, n, lower)
  209. lu, ipvt, x, info = sysv(a1, b1, lwork=lwork,
  210. lower=lower,
  211. overwrite_a=overwrite_a,
  212. overwrite_b=overwrite_b)
  213. _solve_check(n, info)
  214. rcond, info = sycon(lu, ipvt, anorm)
  215. # Positive definite case 'posv'
  216. else:
  217. pocon, posv = get_lapack_funcs(('pocon', 'posv'),
  218. (a1, b1))
  219. lu, x, info = posv(a1, b1, lower=lower,
  220. overwrite_a=overwrite_a,
  221. overwrite_b=overwrite_b)
  222. _solve_check(n, info)
  223. rcond, info = pocon(lu, anorm)
  224. _solve_check(n, info, lamch, rcond)
  225. if b_is_1D:
  226. x = x.ravel()
  227. return x
  228. def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
  229. overwrite_b=False, check_finite=True):
  230. """
  231. Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.
  232. Parameters
  233. ----------
  234. a : (M, M) array_like
  235. A triangular matrix
  236. b : (M,) or (M, N) array_like
  237. Right-hand side matrix in `a x = b`
  238. lower : bool, optional
  239. Use only data contained in the lower triangle of `a`.
  240. Default is to use upper triangle.
  241. trans : {0, 1, 2, 'N', 'T', 'C'}, optional
  242. Type of system to solve:
  243. ======== =========
  244. trans system
  245. ======== =========
  246. 0 or 'N' a x = b
  247. 1 or 'T' a^T x = b
  248. 2 or 'C' a^H x = b
  249. ======== =========
  250. unit_diagonal : bool, optional
  251. If True, diagonal elements of `a` are assumed to be 1 and
  252. will not be referenced.
  253. overwrite_b : bool, optional
  254. Allow overwriting data in `b` (may enhance performance)
  255. check_finite : bool, optional
  256. Whether to check that the input matrices contain only finite numbers.
  257. Disabling may give a performance gain, but may result in problems
  258. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  259. Returns
  260. -------
  261. x : (M,) or (M, N) ndarray
  262. Solution to the system `a x = b`. Shape of return matches `b`.
  263. Raises
  264. ------
  265. LinAlgError
  266. If `a` is singular
  267. Notes
  268. -----
  269. .. versionadded:: 0.9.0
  270. Examples
  271. --------
  272. Solve the lower triangular system a x = b, where::
  273. [3 0 0 0] [4]
  274. a = [2 1 0 0] b = [2]
  275. [1 0 1 0] [4]
  276. [1 1 1 1] [2]
  277. >>> import numpy as np
  278. >>> from scipy.linalg import solve_triangular
  279. >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
  280. >>> b = np.array([4, 2, 4, 2])
  281. >>> x = solve_triangular(a, b, lower=True)
  282. >>> x
  283. array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333])
  284. >>> a.dot(x) # Check the result
  285. array([ 4., 2., 4., 2.])
  286. """
  287. a1 = _asarray_validated(a, check_finite=check_finite)
  288. b1 = _asarray_validated(b, check_finite=check_finite)
  289. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  290. raise ValueError('expected square matrix')
  291. if a1.shape[0] != b1.shape[0]:
  292. raise ValueError('shapes of a {} and b {} are incompatible'
  293. .format(a1.shape, b1.shape))
  294. overwrite_b = overwrite_b or _datacopied(b1, b)
  295. trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
  296. trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))
  297. if a1.flags.f_contiguous or trans == 2:
  298. x, info = trtrs(a1, b1, overwrite_b=overwrite_b, lower=lower,
  299. trans=trans, unitdiag=unit_diagonal)
  300. else:
  301. # transposed system is solved since trtrs expects Fortran ordering
  302. x, info = trtrs(a1.T, b1, overwrite_b=overwrite_b, lower=not lower,
  303. trans=not trans, unitdiag=unit_diagonal)
  304. if info == 0:
  305. return x
  306. if info > 0:
  307. raise LinAlgError("singular matrix: resolution failed at diagonal %d" %
  308. (info-1))
  309. raise ValueError('illegal value in %dth argument of internal trtrs' %
  310. (-info))
  311. def solve_banded(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
  312. check_finite=True):
  313. """
  314. Solve the equation a x = b for x, assuming a is banded matrix.
  315. The matrix a is stored in `ab` using the matrix diagonal ordered form::
  316. ab[u + i - j, j] == a[i,j]
  317. Example of `ab` (shape of a is (6,6), `u` =1, `l` =2)::
  318. * a01 a12 a23 a34 a45
  319. a00 a11 a22 a33 a44 a55
  320. a10 a21 a32 a43 a54 *
  321. a20 a31 a42 a53 * *
  322. Parameters
  323. ----------
  324. (l, u) : (integer, integer)
  325. Number of non-zero lower and upper diagonals
  326. ab : (`l` + `u` + 1, M) array_like
  327. Banded matrix
  328. b : (M,) or (M, K) array_like
  329. Right-hand side
  330. overwrite_ab : bool, optional
  331. Discard data in `ab` (may enhance performance)
  332. overwrite_b : bool, optional
  333. Discard data in `b` (may enhance performance)
  334. check_finite : bool, optional
  335. Whether to check that the input matrices contain only finite numbers.
  336. Disabling may give a performance gain, but may result in problems
  337. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  338. Returns
  339. -------
  340. x : (M,) or (M, K) ndarray
  341. The solution to the system a x = b. Returned shape depends on the
  342. shape of `b`.
  343. Examples
  344. --------
  345. Solve the banded system a x = b, where::
  346. [5 2 -1 0 0] [0]
  347. [1 4 2 -1 0] [1]
  348. a = [0 1 3 2 -1] b = [2]
  349. [0 0 1 2 2] [2]
  350. [0 0 0 1 1] [3]
  351. There is one nonzero diagonal below the main diagonal (l = 1), and
  352. two above (u = 2). The diagonal banded form of the matrix is::
  353. [* * -1 -1 -1]
  354. ab = [* 2 2 2 2]
  355. [5 4 3 2 1]
  356. [1 1 1 1 *]
  357. >>> import numpy as np
  358. >>> from scipy.linalg import solve_banded
  359. >>> ab = np.array([[0, 0, -1, -1, -1],
  360. ... [0, 2, 2, 2, 2],
  361. ... [5, 4, 3, 2, 1],
  362. ... [1, 1, 1, 1, 0]])
  363. >>> b = np.array([0, 1, 2, 2, 3])
  364. >>> x = solve_banded((1, 2), ab, b)
  365. >>> x
  366. array([-2.37288136, 3.93220339, -4. , 4.3559322 , -1.3559322 ])
  367. """
  368. a1 = _asarray_validated(ab, check_finite=check_finite, as_inexact=True)
  369. b1 = _asarray_validated(b, check_finite=check_finite, as_inexact=True)
  370. # Validate shapes.
  371. if a1.shape[-1] != b1.shape[0]:
  372. raise ValueError("shapes of ab and b are not compatible.")
  373. (nlower, nupper) = l_and_u
  374. if nlower + nupper + 1 != a1.shape[0]:
  375. raise ValueError("invalid values for the number of lower and upper "
  376. "diagonals: l+u+1 (%d) does not equal ab.shape[0] "
  377. "(%d)" % (nlower + nupper + 1, ab.shape[0]))
  378. overwrite_b = overwrite_b or _datacopied(b1, b)
  379. if a1.shape[-1] == 1:
  380. b2 = np.array(b1, copy=(not overwrite_b))
  381. b2 /= a1[1, 0]
  382. return b2
  383. if nlower == nupper == 1:
  384. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  385. gtsv, = get_lapack_funcs(('gtsv',), (a1, b1))
  386. du = a1[0, 1:]
  387. d = a1[1, :]
  388. dl = a1[2, :-1]
  389. du2, d, du, x, info = gtsv(dl, d, du, b1, overwrite_ab, overwrite_ab,
  390. overwrite_ab, overwrite_b)
  391. else:
  392. gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
  393. a2 = np.zeros((2*nlower + nupper + 1, a1.shape[1]), dtype=gbsv.dtype)
  394. a2[nlower:, :] = a1
  395. lu, piv, x, info = gbsv(nlower, nupper, a2, b1, overwrite_ab=True,
  396. overwrite_b=overwrite_b)
  397. if info == 0:
  398. return x
  399. if info > 0:
  400. raise LinAlgError("singular matrix")
  401. raise ValueError('illegal value in %d-th argument of internal '
  402. 'gbsv/gtsv' % -info)
  403. def solveh_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
  404. check_finite=True):
  405. """
  406. Solve equation a x = b. a is Hermitian positive-definite banded matrix.
  407. Uses Thomas' Algorithm, which is more efficient than standard LU
  408. factorization, but should only be used for Hermitian positive-definite
  409. matrices.
  410. The matrix ``a`` is stored in `ab` either in lower diagonal or upper
  411. diagonal ordered form:
  412. ab[u + i - j, j] == a[i,j] (if upper form; i <= j)
  413. ab[ i - j, j] == a[i,j] (if lower form; i >= j)
  414. Example of `ab` (shape of ``a`` is (6, 6), number of upper diagonals,
  415. ``u`` =2)::
  416. upper form:
  417. * * a02 a13 a24 a35
  418. * a01 a12 a23 a34 a45
  419. a00 a11 a22 a33 a44 a55
  420. lower form:
  421. a00 a11 a22 a33 a44 a55
  422. a10 a21 a32 a43 a54 *
  423. a20 a31 a42 a53 * *
  424. Cells marked with * are not used.
  425. Parameters
  426. ----------
  427. ab : (``u`` + 1, M) array_like
  428. Banded matrix
  429. b : (M,) or (M, K) array_like
  430. Right-hand side
  431. overwrite_ab : bool, optional
  432. Discard data in `ab` (may enhance performance)
  433. overwrite_b : bool, optional
  434. Discard data in `b` (may enhance performance)
  435. lower : bool, optional
  436. Is the matrix in the lower form. (Default is upper form)
  437. check_finite : bool, optional
  438. Whether to check that the input matrices contain only finite numbers.
  439. Disabling may give a performance gain, but may result in problems
  440. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  441. Returns
  442. -------
  443. x : (M,) or (M, K) ndarray
  444. The solution to the system ``a x = b``. Shape of return matches shape
  445. of `b`.
  446. Notes
  447. -----
  448. In the case of a non-positive definite matrix ``a``, the solver
  449. `solve_banded` may be used.
  450. Examples
  451. --------
  452. Solve the banded system ``A x = b``, where::
  453. [ 4 2 -1 0 0 0] [1]
  454. [ 2 5 2 -1 0 0] [2]
  455. A = [-1 2 6 2 -1 0] b = [2]
  456. [ 0 -1 2 7 2 -1] [3]
  457. [ 0 0 -1 2 8 2] [3]
  458. [ 0 0 0 -1 2 9] [3]
  459. >>> import numpy as np
  460. >>> from scipy.linalg import solveh_banded
  461. ``ab`` contains the main diagonal and the nonzero diagonals below the
  462. main diagonal. That is, we use the lower form:
  463. >>> ab = np.array([[ 4, 5, 6, 7, 8, 9],
  464. ... [ 2, 2, 2, 2, 2, 0],
  465. ... [-1, -1, -1, -1, 0, 0]])
  466. >>> b = np.array([1, 2, 2, 3, 3, 3])
  467. >>> x = solveh_banded(ab, b, lower=True)
  468. >>> x
  469. array([ 0.03431373, 0.45938375, 0.05602241, 0.47759104, 0.17577031,
  470. 0.34733894])
  471. Solve the Hermitian banded system ``H x = b``, where::
  472. [ 8 2-1j 0 0 ] [ 1 ]
  473. H = [2+1j 5 1j 0 ] b = [1+1j]
  474. [ 0 -1j 9 -2-1j] [1-2j]
  475. [ 0 0 -2+1j 6 ] [ 0 ]
  476. In this example, we put the upper diagonals in the array ``hb``:
  477. >>> hb = np.array([[0, 2-1j, 1j, -2-1j],
  478. ... [8, 5, 9, 6 ]])
  479. >>> b = np.array([1, 1+1j, 1-2j, 0])
  480. >>> x = solveh_banded(hb, b)
  481. >>> x
  482. array([ 0.07318536-0.02939412j, 0.11877624+0.17696461j,
  483. 0.10077984-0.23035393j, -0.00479904-0.09358128j])
  484. """
  485. a1 = _asarray_validated(ab, check_finite=check_finite)
  486. b1 = _asarray_validated(b, check_finite=check_finite)
  487. # Validate shapes.
  488. if a1.shape[-1] != b1.shape[0]:
  489. raise ValueError("shapes of ab and b are not compatible.")
  490. overwrite_b = overwrite_b or _datacopied(b1, b)
  491. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  492. if a1.shape[0] == 2:
  493. ptsv, = get_lapack_funcs(('ptsv',), (a1, b1))
  494. if lower:
  495. d = a1[0, :].real
  496. e = a1[1, :-1]
  497. else:
  498. d = a1[1, :].real
  499. e = a1[0, 1:].conj()
  500. d, du, x, info = ptsv(d, e, b1, overwrite_ab, overwrite_ab,
  501. overwrite_b)
  502. else:
  503. pbsv, = get_lapack_funcs(('pbsv',), (a1, b1))
  504. c, x, info = pbsv(a1, b1, lower=lower, overwrite_ab=overwrite_ab,
  505. overwrite_b=overwrite_b)
  506. if info > 0:
  507. raise LinAlgError("%dth leading minor not positive definite" % info)
  508. if info < 0:
  509. raise ValueError('illegal value in %dth argument of internal '
  510. 'pbsv' % -info)
  511. return x
  512. def solve_toeplitz(c_or_cr, b, check_finite=True):
  513. """Solve a Toeplitz system using Levinson Recursion
  514. The Toeplitz matrix has constant diagonals, with c as its first column
  515. and r as its first row. If r is not given, ``r == conjugate(c)`` is
  516. assumed.
  517. Parameters
  518. ----------
  519. c_or_cr : array_like or tuple of (array_like, array_like)
  520. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  521. actual shape of ``c``, it will be converted to a 1-D array. If not
  522. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  523. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  524. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  525. of ``r``, it will be converted to a 1-D array.
  526. b : (M,) or (M, K) array_like
  527. Right-hand side in ``T x = b``.
  528. check_finite : bool, optional
  529. Whether to check that the input matrices contain only finite numbers.
  530. Disabling may give a performance gain, but may result in problems
  531. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  532. Returns
  533. -------
  534. x : (M,) or (M, K) ndarray
  535. The solution to the system ``T x = b``. Shape of return matches shape
  536. of `b`.
  537. See Also
  538. --------
  539. toeplitz : Toeplitz matrix
  540. Notes
  541. -----
  542. The solution is computed using Levinson-Durbin recursion, which is faster
  543. than generic least-squares methods, but can be less numerically stable.
  544. Examples
  545. --------
  546. Solve the Toeplitz system T x = b, where::
  547. [ 1 -1 -2 -3] [1]
  548. T = [ 3 1 -1 -2] b = [2]
  549. [ 6 3 1 -1] [2]
  550. [10 6 3 1] [5]
  551. To specify the Toeplitz matrix, only the first column and the first
  552. row are needed.
  553. >>> import numpy as np
  554. >>> c = np.array([1, 3, 6, 10]) # First column of T
  555. >>> r = np.array([1, -1, -2, -3]) # First row of T
  556. >>> b = np.array([1, 2, 2, 5])
  557. >>> from scipy.linalg import solve_toeplitz, toeplitz
  558. >>> x = solve_toeplitz((c, r), b)
  559. >>> x
  560. array([ 1.66666667, -1. , -2.66666667, 2.33333333])
  561. Check the result by creating the full Toeplitz matrix and
  562. multiplying it by `x`. We should get `b`.
  563. >>> T = toeplitz(c, r)
  564. >>> T.dot(x)
  565. array([ 1., 2., 2., 5.])
  566. """
  567. # If numerical stability of this algorithm is a problem, a future
  568. # developer might consider implementing other O(N^2) Toeplitz solvers,
  569. # such as GKO (https://www.jstor.org/stable/2153371) or Bareiss.
  570. r, c, b, dtype, b_shape = _validate_args_for_toeplitz_ops(
  571. c_or_cr, b, check_finite, keep_b_shape=True)
  572. # Form a 1-D array of values to be used in the matrix, containing a
  573. # reversed copy of r[1:], followed by c.
  574. vals = np.concatenate((r[-1:0:-1], c))
  575. if b is None:
  576. raise ValueError('illegal value, `b` is a required argument')
  577. if b.ndim == 1:
  578. x, _ = levinson(vals, np.ascontiguousarray(b))
  579. else:
  580. x = np.column_stack([levinson(vals, np.ascontiguousarray(b[:, i]))[0]
  581. for i in range(b.shape[1])])
  582. x = x.reshape(*b_shape)
  583. return x
  584. def _get_axis_len(aname, a, axis):
  585. ax = axis
  586. if ax < 0:
  587. ax += a.ndim
  588. if 0 <= ax < a.ndim:
  589. return a.shape[ax]
  590. raise ValueError("'%saxis' entry is out of bounds" % (aname,))
  591. def solve_circulant(c, b, singular='raise', tol=None,
  592. caxis=-1, baxis=0, outaxis=0):
  593. """Solve C x = b for x, where C is a circulant matrix.
  594. `C` is the circulant matrix associated with the vector `c`.
  595. The system is solved by doing division in Fourier space. The
  596. calculation is::
  597. x = ifft(fft(b) / fft(c))
  598. where `fft` and `ifft` are the fast Fourier transform and its inverse,
  599. respectively. For a large vector `c`, this is *much* faster than
  600. solving the system with the full circulant matrix.
  601. Parameters
  602. ----------
  603. c : array_like
  604. The coefficients of the circulant matrix.
  605. b : array_like
  606. Right-hand side matrix in ``a x = b``.
  607. singular : str, optional
  608. This argument controls how a near singular circulant matrix is
  609. handled. If `singular` is "raise" and the circulant matrix is
  610. near singular, a `LinAlgError` is raised. If `singular` is
  611. "lstsq", the least squares solution is returned. Default is "raise".
  612. tol : float, optional
  613. If any eigenvalue of the circulant matrix has an absolute value
  614. that is less than or equal to `tol`, the matrix is considered to be
  615. near singular. If not given, `tol` is set to::
  616. tol = abs_eigs.max() * abs_eigs.size * np.finfo(np.float64).eps
  617. where `abs_eigs` is the array of absolute values of the eigenvalues
  618. of the circulant matrix.
  619. caxis : int
  620. When `c` has dimension greater than 1, it is viewed as a collection
  621. of circulant vectors. In this case, `caxis` is the axis of `c` that
  622. holds the vectors of circulant coefficients.
  623. baxis : int
  624. When `b` has dimension greater than 1, it is viewed as a collection
  625. of vectors. In this case, `baxis` is the axis of `b` that holds the
  626. right-hand side vectors.
  627. outaxis : int
  628. When `c` or `b` are multidimensional, the value returned by
  629. `solve_circulant` is multidimensional. In this case, `outaxis` is
  630. the axis of the result that holds the solution vectors.
  631. Returns
  632. -------
  633. x : ndarray
  634. Solution to the system ``C x = b``.
  635. Raises
  636. ------
  637. LinAlgError
  638. If the circulant matrix associated with `c` is near singular.
  639. See Also
  640. --------
  641. circulant : circulant matrix
  642. Notes
  643. -----
  644. For a 1-D vector `c` with length `m`, and an array `b`
  645. with shape ``(m, ...)``,
  646. solve_circulant(c, b)
  647. returns the same result as
  648. solve(circulant(c), b)
  649. where `solve` and `circulant` are from `scipy.linalg`.
  650. .. versionadded:: 0.16.0
  651. Examples
  652. --------
  653. >>> import numpy as np
  654. >>> from scipy.linalg import solve_circulant, solve, circulant, lstsq
  655. >>> c = np.array([2, 2, 4])
  656. >>> b = np.array([1, 2, 3])
  657. >>> solve_circulant(c, b)
  658. array([ 0.75, -0.25, 0.25])
  659. Compare that result to solving the system with `scipy.linalg.solve`:
  660. >>> solve(circulant(c), b)
  661. array([ 0.75, -0.25, 0.25])
  662. A singular example:
  663. >>> c = np.array([1, 1, 0, 0])
  664. >>> b = np.array([1, 2, 3, 4])
  665. Calling ``solve_circulant(c, b)`` will raise a `LinAlgError`. For the
  666. least square solution, use the option ``singular='lstsq'``:
  667. >>> solve_circulant(c, b, singular='lstsq')
  668. array([ 0.25, 1.25, 2.25, 1.25])
  669. Compare to `scipy.linalg.lstsq`:
  670. >>> x, resid, rnk, s = lstsq(circulant(c), b)
  671. >>> x
  672. array([ 0.25, 1.25, 2.25, 1.25])
  673. A broadcasting example:
  674. Suppose we have the vectors of two circulant matrices stored in an array
  675. with shape (2, 5), and three `b` vectors stored in an array with shape
  676. (3, 5). For example,
  677. >>> c = np.array([[1.5, 2, 3, 0, 0], [1, 1, 4, 3, 2]])
  678. >>> b = np.arange(15).reshape(-1, 5)
  679. We want to solve all combinations of circulant matrices and `b` vectors,
  680. with the result stored in an array with shape (2, 3, 5). When we
  681. disregard the axes of `c` and `b` that hold the vectors of coefficients,
  682. the shapes of the collections are (2,) and (3,), respectively, which are
  683. not compatible for broadcasting. To have a broadcast result with shape
  684. (2, 3), we add a trivial dimension to `c`: ``c[:, np.newaxis, :]`` has
  685. shape (2, 1, 5). The last dimension holds the coefficients of the
  686. circulant matrices, so when we call `solve_circulant`, we can use the
  687. default ``caxis=-1``. The coefficients of the `b` vectors are in the last
  688. dimension of the array `b`, so we use ``baxis=-1``. If we use the
  689. default `outaxis`, the result will have shape (5, 2, 3), so we'll use
  690. ``outaxis=-1`` to put the solution vectors in the last dimension.
  691. >>> x = solve_circulant(c[:, np.newaxis, :], b, baxis=-1, outaxis=-1)
  692. >>> x.shape
  693. (2, 3, 5)
  694. >>> np.set_printoptions(precision=3) # For compact output of numbers.
  695. >>> x
  696. array([[[-0.118, 0.22 , 1.277, -0.142, 0.302],
  697. [ 0.651, 0.989, 2.046, 0.627, 1.072],
  698. [ 1.42 , 1.758, 2.816, 1.396, 1.841]],
  699. [[ 0.401, 0.304, 0.694, -0.867, 0.377],
  700. [ 0.856, 0.758, 1.149, -0.412, 0.831],
  701. [ 1.31 , 1.213, 1.603, 0.042, 1.286]]])
  702. Check by solving one pair of `c` and `b` vectors (cf. ``x[1, 1, :]``):
  703. >>> solve_circulant(c[1], b[1, :])
  704. array([ 0.856, 0.758, 1.149, -0.412, 0.831])
  705. """
  706. c = np.atleast_1d(c)
  707. nc = _get_axis_len("c", c, caxis)
  708. b = np.atleast_1d(b)
  709. nb = _get_axis_len("b", b, baxis)
  710. if nc != nb:
  711. raise ValueError('Shapes of c {} and b {} are incompatible'
  712. .format(c.shape, b.shape))
  713. fc = np.fft.fft(np.moveaxis(c, caxis, -1), axis=-1)
  714. abs_fc = np.abs(fc)
  715. if tol is None:
  716. # This is the same tolerance as used in np.linalg.matrix_rank.
  717. tol = abs_fc.max(axis=-1) * nc * np.finfo(np.float64).eps
  718. if tol.shape != ():
  719. tol.shape = tol.shape + (1,)
  720. else:
  721. tol = np.atleast_1d(tol)
  722. near_zeros = abs_fc <= tol
  723. is_near_singular = np.any(near_zeros)
  724. if is_near_singular:
  725. if singular == 'raise':
  726. raise LinAlgError("near singular circulant matrix.")
  727. else:
  728. # Replace the small values with 1 to avoid errors in the
  729. # division fb/fc below.
  730. fc[near_zeros] = 1
  731. fb = np.fft.fft(np.moveaxis(b, baxis, -1), axis=-1)
  732. q = fb / fc
  733. if is_near_singular:
  734. # `near_zeros` is a boolean array, same shape as `c`, that is
  735. # True where `fc` is (near) zero. `q` is the broadcasted result
  736. # of fb / fc, so to set the values of `q` to 0 where `fc` is near
  737. # zero, we use a mask that is the broadcast result of an array
  738. # of True values shaped like `b` with `near_zeros`.
  739. mask = np.ones_like(b, dtype=bool) & near_zeros
  740. q[mask] = 0
  741. x = np.fft.ifft(q, axis=-1)
  742. if not (np.iscomplexobj(c) or np.iscomplexobj(b)):
  743. x = x.real
  744. if outaxis != -1:
  745. x = np.moveaxis(x, -1, outaxis)
  746. return x
  747. # matrix inversion
  748. def inv(a, overwrite_a=False, check_finite=True):
  749. """
  750. Compute the inverse of a matrix.
  751. Parameters
  752. ----------
  753. a : array_like
  754. Square matrix to be inverted.
  755. overwrite_a : bool, optional
  756. Discard data in `a` (may improve performance). Default is False.
  757. check_finite : bool, optional
  758. Whether to check that the input matrix contains only finite numbers.
  759. Disabling may give a performance gain, but may result in problems
  760. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  761. Returns
  762. -------
  763. ainv : ndarray
  764. Inverse of the matrix `a`.
  765. Raises
  766. ------
  767. LinAlgError
  768. If `a` is singular.
  769. ValueError
  770. If `a` is not square, or not 2D.
  771. Examples
  772. --------
  773. >>> import numpy as np
  774. >>> from scipy import linalg
  775. >>> a = np.array([[1., 2.], [3., 4.]])
  776. >>> linalg.inv(a)
  777. array([[-2. , 1. ],
  778. [ 1.5, -0.5]])
  779. >>> np.dot(a, linalg.inv(a))
  780. array([[ 1., 0.],
  781. [ 0., 1.]])
  782. """
  783. a1 = _asarray_validated(a, check_finite=check_finite)
  784. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  785. raise ValueError('expected square matrix')
  786. overwrite_a = overwrite_a or _datacopied(a1, a)
  787. # XXX: I found no advantage or disadvantage of using finv.
  788. # finv, = get_flinalg_funcs(('inv',),(a1,))
  789. # if finv is not None:
  790. # a_inv,info = finv(a1,overwrite_a=overwrite_a)
  791. # if info==0:
  792. # return a_inv
  793. # if info>0: raise LinAlgError, "singular matrix"
  794. # if info<0: raise ValueError('illegal value in %d-th argument of '
  795. # 'internal inv.getrf|getri'%(-info))
  796. getrf, getri, getri_lwork = get_lapack_funcs(('getrf', 'getri',
  797. 'getri_lwork'),
  798. (a1,))
  799. lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
  800. if info == 0:
  801. lwork = _compute_lwork(getri_lwork, a1.shape[0])
  802. # XXX: the following line fixes curious SEGFAULT when
  803. # benchmarking 500x500 matrix inverse. This seems to
  804. # be a bug in LAPACK ?getri routine because if lwork is
  805. # minimal (when using lwork[0] instead of lwork[1]) then
  806. # all tests pass. Further investigation is required if
  807. # more such SEGFAULTs occur.
  808. lwork = int(1.01 * lwork)
  809. inv_a, info = getri(lu, piv, lwork=lwork, overwrite_lu=1)
  810. if info > 0:
  811. raise LinAlgError("singular matrix")
  812. if info < 0:
  813. raise ValueError('illegal value in %d-th argument of internal '
  814. 'getrf|getri' % -info)
  815. return inv_a
  816. # Determinant
  817. def det(a, overwrite_a=False, check_finite=True):
  818. """
  819. Compute the determinant of a matrix
  820. The determinant of a square matrix is a value derived arithmetically
  821. from the coefficients of the matrix.
  822. The determinant for a 3x3 matrix, for example, is computed as follows::
  823. a b c
  824. d e f = A
  825. g h i
  826. det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
  827. Parameters
  828. ----------
  829. a : (M, M) array_like
  830. A square matrix.
  831. overwrite_a : bool, optional
  832. Allow overwriting data in a (may enhance performance).
  833. check_finite : bool, optional
  834. Whether to check that the input matrix contains only finite numbers.
  835. Disabling may give a performance gain, but may result in problems
  836. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  837. Returns
  838. -------
  839. det : float or complex
  840. Determinant of `a`.
  841. Notes
  842. -----
  843. The determinant is computed via LU factorization, LAPACK routine z/dgetrf.
  844. Examples
  845. --------
  846. >>> import numpy as np
  847. >>> from scipy import linalg
  848. >>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
  849. >>> linalg.det(a)
  850. 0.0
  851. >>> a = np.array([[0,2,3], [4,5,6], [7,8,9]])
  852. >>> linalg.det(a)
  853. 3.0
  854. """
  855. a1 = _asarray_validated(a, check_finite=check_finite)
  856. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  857. raise ValueError('expected square matrix')
  858. overwrite_a = overwrite_a or _datacopied(a1, a)
  859. fdet, = get_flinalg_funcs(('det',), (a1,))
  860. a_det, info = fdet(a1, overwrite_a=overwrite_a)
  861. if info < 0:
  862. raise ValueError('illegal value in %d-th argument of internal '
  863. 'det.getrf' % -info)
  864. return a_det
  865. # Linear Least Squares
  866. def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
  867. check_finite=True, lapack_driver=None):
  868. """
  869. Compute least-squares solution to equation Ax = b.
  870. Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
  871. Parameters
  872. ----------
  873. a : (M, N) array_like
  874. Left-hand side array
  875. b : (M,) or (M, K) array_like
  876. Right hand side array
  877. cond : float, optional
  878. Cutoff for 'small' singular values; used to determine effective
  879. rank of a. Singular values smaller than
  880. ``cond * largest_singular_value`` are considered zero.
  881. overwrite_a : bool, optional
  882. Discard data in `a` (may enhance performance). Default is False.
  883. overwrite_b : bool, optional
  884. Discard data in `b` (may enhance performance). Default is False.
  885. check_finite : bool, optional
  886. Whether to check that the input matrices contain only finite numbers.
  887. Disabling may give a performance gain, but may result in problems
  888. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  889. lapack_driver : str, optional
  890. Which LAPACK driver is used to solve the least-squares problem.
  891. Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
  892. (``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly
  893. faster on many problems. ``'gelss'`` was used historically. It is
  894. generally slow but uses less memory.
  895. .. versionadded:: 0.17.0
  896. Returns
  897. -------
  898. x : (N,) or (N, K) ndarray
  899. Least-squares solution.
  900. residues : (K,) ndarray or float
  901. Square of the 2-norm for each column in ``b - a x``, if ``M > N`` and
  902. ``ndim(A) == n`` (returns a scalar if ``b`` is 1-D). Otherwise a
  903. (0,)-shaped array is returned.
  904. rank : int
  905. Effective rank of `a`.
  906. s : (min(M, N),) ndarray or None
  907. Singular values of `a`. The condition number of ``a`` is
  908. ``s[0] / s[-1]``.
  909. Raises
  910. ------
  911. LinAlgError
  912. If computation does not converge.
  913. ValueError
  914. When parameters are not compatible.
  915. See Also
  916. --------
  917. scipy.optimize.nnls : linear least squares with non-negativity constraint
  918. Notes
  919. -----
  920. When ``'gelsy'`` is used as a driver, `residues` is set to a (0,)-shaped
  921. array and `s` is always ``None``.
  922. Examples
  923. --------
  924. >>> import numpy as np
  925. >>> from scipy.linalg import lstsq
  926. >>> import matplotlib.pyplot as plt
  927. Suppose we have the following data:
  928. >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
  929. >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])
  930. We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
  931. to this data. We first form the "design matrix" M, with a constant
  932. column of 1s and a column containing ``x**2``:
  933. >>> M = x[:, np.newaxis]**[0, 2]
  934. >>> M
  935. array([[ 1. , 1. ],
  936. [ 1. , 6.25],
  937. [ 1. , 12.25],
  938. [ 1. , 16. ],
  939. [ 1. , 25. ],
  940. [ 1. , 49. ],
  941. [ 1. , 72.25]])
  942. We want to find the least-squares solution to ``M.dot(p) = y``,
  943. where ``p`` is a vector with length 2 that holds the parameters
  944. ``a`` and ``b``.
  945. >>> p, res, rnk, s = lstsq(M, y)
  946. >>> p
  947. array([ 0.20925829, 0.12013861])
  948. Plot the data and the fitted curve.
  949. >>> plt.plot(x, y, 'o', label='data')
  950. >>> xx = np.linspace(0, 9, 101)
  951. >>> yy = p[0] + p[1]*xx**2
  952. >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
  953. >>> plt.xlabel('x')
  954. >>> plt.ylabel('y')
  955. >>> plt.legend(framealpha=1, shadow=True)
  956. >>> plt.grid(alpha=0.25)
  957. >>> plt.show()
  958. """
  959. a1 = _asarray_validated(a, check_finite=check_finite)
  960. b1 = _asarray_validated(b, check_finite=check_finite)
  961. if len(a1.shape) != 2:
  962. raise ValueError('Input array a should be 2D')
  963. m, n = a1.shape
  964. if len(b1.shape) == 2:
  965. nrhs = b1.shape[1]
  966. else:
  967. nrhs = 1
  968. if m != b1.shape[0]:
  969. raise ValueError('Shape mismatch: a and b should have the same number'
  970. ' of rows ({} != {}).'.format(m, b1.shape[0]))
  971. if m == 0 or n == 0: # Zero-sized problem, confuses LAPACK
  972. x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
  973. if n == 0:
  974. residues = np.linalg.norm(b1, axis=0)**2
  975. else:
  976. residues = np.empty((0,))
  977. return x, residues, 0, np.empty((0,))
  978. driver = lapack_driver
  979. if driver is None:
  980. driver = lstsq.default_lapack_driver
  981. if driver not in ('gelsd', 'gelsy', 'gelss'):
  982. raise ValueError('LAPACK driver "%s" is not found' % driver)
  983. lapack_func, lapack_lwork = get_lapack_funcs((driver,
  984. '%s_lwork' % driver),
  985. (a1, b1))
  986. real_data = True if (lapack_func.dtype.kind == 'f') else False
  987. if m < n:
  988. # need to extend b matrix as it will be filled with
  989. # a larger solution matrix
  990. if len(b1.shape) == 2:
  991. b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
  992. b2[:m, :] = b1
  993. else:
  994. b2 = np.zeros(n, dtype=lapack_func.dtype)
  995. b2[:m] = b1
  996. b1 = b2
  997. overwrite_a = overwrite_a or _datacopied(a1, a)
  998. overwrite_b = overwrite_b or _datacopied(b1, b)
  999. if cond is None:
  1000. cond = np.finfo(lapack_func.dtype).eps
  1001. if driver in ('gelss', 'gelsd'):
  1002. if driver == 'gelss':
  1003. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1004. v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
  1005. overwrite_a=overwrite_a,
  1006. overwrite_b=overwrite_b)
  1007. elif driver == 'gelsd':
  1008. if real_data:
  1009. lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1010. x, s, rank, info = lapack_func(a1, b1, lwork,
  1011. iwork, cond, False, False)
  1012. else: # complex data
  1013. lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
  1014. nrhs, cond)
  1015. x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
  1016. cond, False, False)
  1017. if info > 0:
  1018. raise LinAlgError("SVD did not converge in Linear Least Squares")
  1019. if info < 0:
  1020. raise ValueError('illegal value in %d-th argument of internal %s'
  1021. % (-info, lapack_driver))
  1022. resids = np.asarray([], dtype=x.dtype)
  1023. if m > n:
  1024. x1 = x[:n]
  1025. if rank == n:
  1026. resids = np.sum(np.abs(x[n:])**2, axis=0)
  1027. x = x1
  1028. return x, resids, rank, s
  1029. elif driver == 'gelsy':
  1030. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1031. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  1032. v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
  1033. lwork, False, False)
  1034. if info < 0:
  1035. raise ValueError("illegal value in %d-th argument of internal "
  1036. "gelsy" % -info)
  1037. if m > n:
  1038. x1 = x[:n]
  1039. x = x1
  1040. return x, np.array([], x.dtype), rank, None
  1041. lstsq.default_lapack_driver = 'gelsd'
  1042. def pinv(a, atol=None, rtol=None, return_rank=False, check_finite=True,
  1043. cond=None, rcond=None):
  1044. """
  1045. Compute the (Moore-Penrose) pseudo-inverse of a matrix.
  1046. Calculate a generalized inverse of a matrix using its
  1047. singular-value decomposition ``U @ S @ V`` in the economy mode and picking
  1048. up only the columns/rows that are associated with significant singular
  1049. values.
  1050. If ``s`` is the maximum singular value of ``a``, then the
  1051. significance cut-off value is determined by ``atol + rtol * s``. Any
  1052. singular value below this value is assumed insignificant.
  1053. Parameters
  1054. ----------
  1055. a : (M, N) array_like
  1056. Matrix to be pseudo-inverted.
  1057. atol : float, optional
  1058. Absolute threshold term, default value is 0.
  1059. .. versionadded:: 1.7.0
  1060. rtol : float, optional
  1061. Relative threshold term, default value is ``max(M, N) * eps`` where
  1062. ``eps`` is the machine precision value of the datatype of ``a``.
  1063. .. versionadded:: 1.7.0
  1064. return_rank : bool, optional
  1065. If True, return the effective rank of the matrix.
  1066. check_finite : bool, optional
  1067. Whether to check that the input matrix contains only finite numbers.
  1068. Disabling may give a performance gain, but may result in problems
  1069. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1070. cond, rcond : float, optional
  1071. In older versions, these values were meant to be used as ``atol`` with
  1072. ``rtol=0``. If both were given ``rcond`` overwrote ``cond`` and hence
  1073. the code was not correct. Thus using these are strongly discouraged and
  1074. the tolerances above are recommended instead. In fact, if provided,
  1075. atol, rtol takes precedence over these keywords.
  1076. .. versionchanged:: 1.7.0
  1077. Deprecated in favor of ``rtol`` and ``atol`` parameters above and
  1078. will be removed in future versions of SciPy.
  1079. .. versionchanged:: 1.3.0
  1080. Previously the default cutoff value was just ``eps*f`` where ``f``
  1081. was ``1e3`` for single precision and ``1e6`` for double precision.
  1082. Returns
  1083. -------
  1084. B : (N, M) ndarray
  1085. The pseudo-inverse of matrix `a`.
  1086. rank : int
  1087. The effective rank of the matrix. Returned if `return_rank` is True.
  1088. Raises
  1089. ------
  1090. LinAlgError
  1091. If SVD computation does not converge.
  1092. Examples
  1093. --------
  1094. >>> import numpy as np
  1095. >>> from scipy import linalg
  1096. >>> rng = np.random.default_rng()
  1097. >>> a = rng.standard_normal((9, 6))
  1098. >>> B = linalg.pinv(a)
  1099. >>> np.allclose(a, a @ B @ a)
  1100. True
  1101. >>> np.allclose(B, B @ a @ B)
  1102. True
  1103. """
  1104. a = _asarray_validated(a, check_finite=check_finite)
  1105. u, s, vh = _decomp_svd.svd(a, full_matrices=False, check_finite=False)
  1106. t = u.dtype.char.lower()
  1107. maxS = np.max(s)
  1108. if rcond or cond:
  1109. warn('Use of the "cond" and "rcond" keywords are deprecated and '
  1110. 'will be removed in future versions of SciPy. Use "atol" and '
  1111. '"rtol" keywords instead', DeprecationWarning, stacklevel=2)
  1112. # backwards compatible only atol and rtol are both missing
  1113. if (rcond or cond) and (atol is None) and (rtol is None):
  1114. atol = rcond or cond
  1115. rtol = 0.
  1116. atol = 0. if atol is None else atol
  1117. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1118. if (atol < 0.) or (rtol < 0.):
  1119. raise ValueError("atol and rtol values must be positive.")
  1120. val = atol + maxS * rtol
  1121. rank = np.sum(s > val)
  1122. u = u[:, :rank]
  1123. u /= s[:rank]
  1124. B = (u @ vh[:rank]).conj().T
  1125. if return_rank:
  1126. return B, rank
  1127. else:
  1128. return B
  1129. def pinvh(a, atol=None, rtol=None, lower=True, return_rank=False,
  1130. check_finite=True):
  1131. """
  1132. Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
  1133. Calculate a generalized inverse of a complex Hermitian/real symmetric
  1134. matrix using its eigenvalue decomposition and including all eigenvalues
  1135. with 'large' absolute value.
  1136. Parameters
  1137. ----------
  1138. a : (N, N) array_like
  1139. Real symmetric or complex hermetian matrix to be pseudo-inverted
  1140. atol : float, optional
  1141. Absolute threshold term, default value is 0.
  1142. .. versionadded:: 1.7.0
  1143. rtol : float, optional
  1144. Relative threshold term, default value is ``N * eps`` where
  1145. ``eps`` is the machine precision value of the datatype of ``a``.
  1146. .. versionadded:: 1.7.0
  1147. lower : bool, optional
  1148. Whether the pertinent array data is taken from the lower or upper
  1149. triangle of `a`. (Default: lower)
  1150. return_rank : bool, optional
  1151. If True, return the effective rank of the matrix.
  1152. check_finite : bool, optional
  1153. Whether to check that the input matrix contains only finite numbers.
  1154. Disabling may give a performance gain, but may result in problems
  1155. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1156. Returns
  1157. -------
  1158. B : (N, N) ndarray
  1159. The pseudo-inverse of matrix `a`.
  1160. rank : int
  1161. The effective rank of the matrix. Returned if `return_rank` is True.
  1162. Raises
  1163. ------
  1164. LinAlgError
  1165. If eigenvalue algorithm does not converge.
  1166. Examples
  1167. --------
  1168. >>> import numpy as np
  1169. >>> from scipy.linalg import pinvh
  1170. >>> rng = np.random.default_rng()
  1171. >>> a = rng.standard_normal((9, 6))
  1172. >>> a = np.dot(a, a.T)
  1173. >>> B = pinvh(a)
  1174. >>> np.allclose(a, a @ B @ a)
  1175. True
  1176. >>> np.allclose(B, B @ a @ B)
  1177. True
  1178. """
  1179. a = _asarray_validated(a, check_finite=check_finite)
  1180. s, u = _decomp.eigh(a, lower=lower, check_finite=False)
  1181. t = u.dtype.char.lower()
  1182. maxS = np.max(np.abs(s))
  1183. atol = 0. if atol is None else atol
  1184. rtol = max(a.shape) * np.finfo(t).eps if (rtol is None) else rtol
  1185. if (atol < 0.) or (rtol < 0.):
  1186. raise ValueError("atol and rtol values must be positive.")
  1187. val = atol + maxS * rtol
  1188. above_cutoff = (abs(s) > val)
  1189. psigma_diag = 1.0 / s[above_cutoff]
  1190. u = u[:, above_cutoff]
  1191. B = (u * psigma_diag) @ u.conj().T
  1192. if return_rank:
  1193. return B, len(psigma_diag)
  1194. else:
  1195. return B
  1196. def matrix_balance(A, permute=True, scale=True, separate=False,
  1197. overwrite_a=False):
  1198. """
  1199. Compute a diagonal similarity transformation for row/column balancing.
  1200. The balancing tries to equalize the row and column 1-norms by applying
  1201. a similarity transformation such that the magnitude variation of the
  1202. matrix entries is reflected to the scaling matrices.
  1203. Moreover, if enabled, the matrix is first permuted to isolate the upper
  1204. triangular parts of the matrix and, again if scaling is also enabled,
  1205. only the remaining subblocks are subjected to scaling.
  1206. The balanced matrix satisfies the following equality
  1207. .. math::
  1208. B = T^{-1} A T
  1209. The scaling coefficients are approximated to the nearest power of 2
  1210. to avoid round-off errors.
  1211. Parameters
  1212. ----------
  1213. A : (n, n) array_like
  1214. Square data matrix for the balancing.
  1215. permute : bool, optional
  1216. The selector to define whether permutation of A is also performed
  1217. prior to scaling.
  1218. scale : bool, optional
  1219. The selector to turn on and off the scaling. If False, the matrix
  1220. will not be scaled.
  1221. separate : bool, optional
  1222. This switches from returning a full matrix of the transformation
  1223. to a tuple of two separate 1-D permutation and scaling arrays.
  1224. overwrite_a : bool, optional
  1225. This is passed to xGEBAL directly. Essentially, overwrites the result
  1226. to the data. It might increase the space efficiency. See LAPACK manual
  1227. for details. This is False by default.
  1228. Returns
  1229. -------
  1230. B : (n, n) ndarray
  1231. Balanced matrix
  1232. T : (n, n) ndarray
  1233. A possibly permuted diagonal matrix whose nonzero entries are
  1234. integer powers of 2 to avoid numerical truncation errors.
  1235. scale, perm : (n,) ndarray
  1236. If ``separate`` keyword is set to True then instead of the array
  1237. ``T`` above, the scaling and the permutation vectors are given
  1238. separately as a tuple without allocating the full array ``T``.
  1239. Notes
  1240. -----
  1241. This algorithm is particularly useful for eigenvalue and matrix
  1242. decompositions and in many cases it is already called by various
  1243. LAPACK routines.
  1244. The algorithm is based on the well-known technique of [1]_ and has
  1245. been modified to account for special cases. See [2]_ for details
  1246. which have been implemented since LAPACK v3.5.0. Before this version
  1247. there are corner cases where balancing can actually worsen the
  1248. conditioning. See [3]_ for such examples.
  1249. The code is a wrapper around LAPACK's xGEBAL routine family for matrix
  1250. balancing.
  1251. .. versionadded:: 0.19.0
  1252. References
  1253. ----------
  1254. .. [1] B.N. Parlett and C. Reinsch, "Balancing a Matrix for
  1255. Calculation of Eigenvalues and Eigenvectors", Numerische Mathematik,
  1256. Vol.13(4), 1969, :doi:`10.1007/BF02165404`
  1257. .. [2] R. James, J. Langou, B.R. Lowery, "On matrix balancing and
  1258. eigenvector computation", 2014, :arxiv:`1401.5766`
  1259. .. [3] D.S. Watkins. A case where balancing is harmful.
  1260. Electron. Trans. Numer. Anal, Vol.23, 2006.
  1261. Examples
  1262. --------
  1263. >>> import numpy as np
  1264. >>> from scipy import linalg
  1265. >>> x = np.array([[1,2,0], [9,1,0.01], [1,2,10*np.pi]])
  1266. >>> y, permscale = linalg.matrix_balance(x)
  1267. >>> np.abs(x).sum(axis=0) / np.abs(x).sum(axis=1)
  1268. array([ 3.66666667, 0.4995005 , 0.91312162])
  1269. >>> np.abs(y).sum(axis=0) / np.abs(y).sum(axis=1)
  1270. array([ 1.2 , 1.27041742, 0.92658316]) # may vary
  1271. >>> permscale # only powers of 2 (0.5 == 2^(-1))
  1272. array([[ 0.5, 0. , 0. ], # may vary
  1273. [ 0. , 1. , 0. ],
  1274. [ 0. , 0. , 1. ]])
  1275. """
  1276. A = np.atleast_2d(_asarray_validated(A, check_finite=True))
  1277. if not np.equal(*A.shape):
  1278. raise ValueError('The data matrix for balancing should be square.')
  1279. gebal = get_lapack_funcs(('gebal'), (A,))
  1280. B, lo, hi, ps, info = gebal(A, scale=scale, permute=permute,
  1281. overwrite_a=overwrite_a)
  1282. if info < 0:
  1283. raise ValueError('xGEBAL exited with the internal error '
  1284. '"illegal value in argument number {}.". See '
  1285. 'LAPACK documentation for the xGEBAL error codes.'
  1286. ''.format(-info))
  1287. # Separate the permutations from the scalings and then convert to int
  1288. scaling = np.ones_like(ps, dtype=float)
  1289. scaling[lo:hi+1] = ps[lo:hi+1]
  1290. # gebal uses 1-indexing
  1291. ps = ps.astype(int, copy=False) - 1
  1292. n = A.shape[0]
  1293. perm = np.arange(n)
  1294. # LAPACK permutes with the ordering n --> hi, then 0--> lo
  1295. if hi < n:
  1296. for ind, x in enumerate(ps[hi+1:][::-1], 1):
  1297. if n-ind == x:
  1298. continue
  1299. perm[[x, n-ind]] = perm[[n-ind, x]]
  1300. if lo > 0:
  1301. for ind, x in enumerate(ps[:lo]):
  1302. if ind == x:
  1303. continue
  1304. perm[[x, ind]] = perm[[ind, x]]
  1305. if separate:
  1306. return B, (scaling, perm)
  1307. # get the inverse permutation
  1308. iperm = np.empty_like(perm)
  1309. iperm[perm] = np.arange(n)
  1310. return B, np.diag(scaling)[iperm, :]
  1311. def _validate_args_for_toeplitz_ops(c_or_cr, b, check_finite, keep_b_shape,
  1312. enforce_square=True):
  1313. """Validate arguments and format inputs for toeplitz functions
  1314. Parameters
  1315. ----------
  1316. c_or_cr : array_like or tuple of (array_like, array_like)
  1317. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  1318. actual shape of ``c``, it will be converted to a 1-D array. If not
  1319. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1320. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1321. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  1322. of ``r``, it will be converted to a 1-D array.
  1323. b : (M,) or (M, K) array_like
  1324. Right-hand side in ``T x = b``.
  1325. check_finite : bool
  1326. Whether to check that the input matrices contain only finite numbers.
  1327. Disabling may give a performance gain, but may result in problems
  1328. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1329. keep_b_shape : bool
  1330. Whether to convert a (M,) dimensional b into a (M, 1) dimensional
  1331. matrix.
  1332. enforce_square : bool, optional
  1333. If True (default), this verifies that the Toeplitz matrix is square.
  1334. Returns
  1335. -------
  1336. r : array
  1337. 1d array corresponding to the first row of the Toeplitz matrix.
  1338. c: array
  1339. 1d array corresponding to the first column of the Toeplitz matrix.
  1340. b: array
  1341. (M,), (M, 1) or (M, K) dimensional array, post validation,
  1342. corresponding to ``b``.
  1343. dtype: numpy datatype
  1344. ``dtype`` stores the datatype of ``r``, ``c`` and ``b``. If any of
  1345. ``r``, ``c`` or ``b`` are complex, ``dtype`` is ``np.complex128``,
  1346. otherwise, it is ``np.float``.
  1347. b_shape: tuple
  1348. Shape of ``b`` after passing it through ``_asarray_validated``.
  1349. """
  1350. if isinstance(c_or_cr, tuple):
  1351. c, r = c_or_cr
  1352. c = _asarray_validated(c, check_finite=check_finite).ravel()
  1353. r = _asarray_validated(r, check_finite=check_finite).ravel()
  1354. else:
  1355. c = _asarray_validated(c_or_cr, check_finite=check_finite).ravel()
  1356. r = c.conjugate()
  1357. if b is None:
  1358. raise ValueError('`b` must be an array, not None.')
  1359. b = _asarray_validated(b, check_finite=check_finite)
  1360. b_shape = b.shape
  1361. is_not_square = r.shape[0] != c.shape[0]
  1362. if (enforce_square and is_not_square) or b.shape[0] != r.shape[0]:
  1363. raise ValueError('Incompatible dimensions.')
  1364. is_cmplx = np.iscomplexobj(r) or np.iscomplexobj(c) or np.iscomplexobj(b)
  1365. dtype = np.complex128 if is_cmplx else np.double
  1366. r, c, b = (np.asarray(i, dtype=dtype) for i in (r, c, b))
  1367. if b.ndim == 1 and not keep_b_shape:
  1368. b = b.reshape(-1, 1)
  1369. elif b.ndim != 1:
  1370. b = b.reshape(b.shape[0], -1)
  1371. return r, c, b, dtype, b_shape
  1372. def matmul_toeplitz(c_or_cr, x, check_finite=False, workers=None):
  1373. """Efficient Toeplitz Matrix-Matrix Multiplication using FFT
  1374. This function returns the matrix multiplication between a Toeplitz
  1375. matrix and a dense matrix.
  1376. The Toeplitz matrix has constant diagonals, with c as its first column
  1377. and r as its first row. If r is not given, ``r == conjugate(c)`` is
  1378. assumed.
  1379. Parameters
  1380. ----------
  1381. c_or_cr : array_like or tuple of (array_like, array_like)
  1382. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  1383. actual shape of ``c``, it will be converted to a 1-D array. If not
  1384. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  1385. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  1386. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  1387. of ``r``, it will be converted to a 1-D array.
  1388. x : (M,) or (M, K) array_like
  1389. Matrix with which to multiply.
  1390. check_finite : bool, optional
  1391. Whether to check that the input matrices contain only finite numbers.
  1392. Disabling may give a performance gain, but may result in problems
  1393. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  1394. workers : int, optional
  1395. To pass to scipy.fft.fft and ifft. Maximum number of workers to use
  1396. for parallel computation. If negative, the value wraps around from
  1397. ``os.cpu_count()``. See scipy.fft.fft for more details.
  1398. Returns
  1399. -------
  1400. T @ x : (M,) or (M, K) ndarray
  1401. The result of the matrix multiplication ``T @ x``. Shape of return
  1402. matches shape of `x`.
  1403. See Also
  1404. --------
  1405. toeplitz : Toeplitz matrix
  1406. solve_toeplitz : Solve a Toeplitz system using Levinson Recursion
  1407. Notes
  1408. -----
  1409. The Toeplitz matrix is embedded in a circulant matrix and the FFT is used
  1410. to efficiently calculate the matrix-matrix product.
  1411. Because the computation is based on the FFT, integer inputs will
  1412. result in floating point outputs. This is unlike NumPy's `matmul`,
  1413. which preserves the data type of the input.
  1414. This is partly based on the implementation that can be found in [1]_,
  1415. licensed under the MIT license. More information about the method can be
  1416. found in reference [2]_. References [3]_ and [4]_ have more reference
  1417. implementations in Python.
  1418. .. versionadded:: 1.6.0
  1419. References
  1420. ----------
  1421. .. [1] Jacob R Gardner, Geoff Pleiss, David Bindel, Kilian
  1422. Q Weinberger, Andrew Gordon Wilson, "GPyTorch: Blackbox Matrix-Matrix
  1423. Gaussian Process Inference with GPU Acceleration" with contributions
  1424. from Max Balandat and Ruihan Wu. Available online:
  1425. https://github.com/cornellius-gp/gpytorch
  1426. .. [2] J. Demmel, P. Koev, and X. Li, "A Brief Survey of Direct Linear
  1427. Solvers". In Z. Bai, J. Demmel, J. Dongarra, A. Ruhe, and H. van der
  1428. Vorst, editors. Templates for the Solution of Algebraic Eigenvalue
  1429. Problems: A Practical Guide. SIAM, Philadelphia, 2000. Available at:
  1430. http://www.netlib.org/utk/people/JackDongarra/etemplates/node384.html
  1431. .. [3] R. Scheibler, E. Bezzam, I. Dokmanic, Pyroomacoustics: A Python
  1432. package for audio room simulations and array processing algorithms,
  1433. Proc. IEEE ICASSP, Calgary, CA, 2018.
  1434. https://github.com/LCAV/pyroomacoustics/blob/pypi-release/
  1435. pyroomacoustics/adaptive/util.py
  1436. .. [4] Marano S, Edwards B, Ferrari G and Fah D (2017), "Fitting
  1437. Earthquake Spectra: Colored Noise and Incomplete Data", Bulletin of
  1438. the Seismological Society of America., January, 2017. Vol. 107(1),
  1439. pp. 276-291.
  1440. Examples
  1441. --------
  1442. Multiply the Toeplitz matrix T with matrix x::
  1443. [ 1 -1 -2 -3] [1 10]
  1444. T = [ 3 1 -1 -2] x = [2 11]
  1445. [ 6 3 1 -1] [2 11]
  1446. [10 6 3 1] [5 19]
  1447. To specify the Toeplitz matrix, only the first column and the first
  1448. row are needed.
  1449. >>> import numpy as np
  1450. >>> c = np.array([1, 3, 6, 10]) # First column of T
  1451. >>> r = np.array([1, -1, -2, -3]) # First row of T
  1452. >>> x = np.array([[1, 10], [2, 11], [2, 11], [5, 19]])
  1453. >>> from scipy.linalg import toeplitz, matmul_toeplitz
  1454. >>> matmul_toeplitz((c, r), x)
  1455. array([[-20., -80.],
  1456. [ -7., -8.],
  1457. [ 9., 85.],
  1458. [ 33., 218.]])
  1459. Check the result by creating the full Toeplitz matrix and
  1460. multiplying it by ``x``.
  1461. >>> toeplitz(c, r) @ x
  1462. array([[-20, -80],
  1463. [ -7, -8],
  1464. [ 9, 85],
  1465. [ 33, 218]])
  1466. The full matrix is never formed explicitly, so this routine
  1467. is suitable for very large Toeplitz matrices.
  1468. >>> n = 1000000
  1469. >>> matmul_toeplitz([1] + [0]*(n-1), np.ones(n))
  1470. array([1., 1., 1., ..., 1., 1., 1.])
  1471. """
  1472. from ..fft import fft, ifft, rfft, irfft
  1473. r, c, x, dtype, x_shape = _validate_args_for_toeplitz_ops(
  1474. c_or_cr, x, check_finite, keep_b_shape=False, enforce_square=False)
  1475. n, m = x.shape
  1476. T_nrows = len(c)
  1477. T_ncols = len(r)
  1478. p = T_nrows + T_ncols - 1 # equivalent to len(embedded_col)
  1479. embedded_col = np.concatenate((c, r[-1:0:-1]))
  1480. if np.iscomplexobj(embedded_col) or np.iscomplexobj(x):
  1481. fft_mat = fft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1482. fft_x = fft(x, n=p, axis=0, workers=workers)
  1483. mat_times_x = ifft(fft_mat*fft_x, axis=0,
  1484. workers=workers)[:T_nrows, :]
  1485. else:
  1486. # Real inputs; using rfft is faster
  1487. fft_mat = rfft(embedded_col, axis=0, workers=workers).reshape(-1, 1)
  1488. fft_x = rfft(x, n=p, axis=0, workers=workers)
  1489. mat_times_x = irfft(fft_mat*fft_x, axis=0,
  1490. workers=workers, n=p)[:T_nrows, :]
  1491. return_shape = (T_nrows,) if len(x_shape) == 1 else (T_nrows, m)
  1492. return mat_times_x.reshape(*return_shape)