dense.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  """
  Module for the ddm_* routines for operating on a matrix in the list of lists
  matrix representation.

  These routines are used internally by the DDM class which also provides a
  friendlier interface for them. The idea here is to implement core matrix
  routines in a way that can be applied to any simple list representation
  without the need to use any particular matrix class. For example we can
  compute the RREF of a matrix like:

  >>> from sympy.polys.matrices.dense import ddm_irref
  >>> M = [[1, 2, 3], [4, 5, 6]]
  >>> pivots = ddm_irref(M)
  >>> M
  [[1.0, 0.0, -1.0], [0, 1.0, 2.0]]

  These are lower-level routines that work mostly in place. The routines at this
  level should not need to know what the domain of the elements is but should
  ideally document what operations they will use and what functions they need to
  be provided with.

  The next level up is the DDM class which uses these routines but wraps them up
  with an interface that handles copying etc. and keeps track of the Domain of
  the elements of the matrix:

  >>> from sympy.polys.domains import QQ
  >>> from sympy.polys.matrices.ddm import DDM
  >>> M = DDM([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ)
  >>> M
  [[1, 2, 3], [4, 5, 6]]
  >>> Mrref, pivots = M.rref()
  >>> Mrref
  [[1, 0, -1], [0, 1, 2]]
  """
  30. from __future__ import annotations
  31. from operator import mul
  32. from .exceptions import (
  33. DMShapeError,
  34. DMNonInvertibleMatrixError,
  35. DMNonSquareMatrixError,
  36. )
  37. from typing import Sequence, TypeVar
  38. from sympy.polys.matrices._typing import RingElement
  # Unconstrained element type, for routines that only rearrange entries.
  T = TypeVar("T")
  # Element type for the arithmetic routines, bounded by RingElement.
  R = TypeVar("R", bound=RingElement)
  def ddm_transpose(matrix: Sequence[Sequence[T]]) -> list[list[T]]:
      """Return the transpose of ``matrix`` as a new list of lists.

      ``matrix`` itself is not modified. Columns are gathered with
      ``zip(*matrix)``, so an empty ``matrix`` gives ``[]`` and rows of
      unequal length are truncated to the shortest one.
      """
      return [list(column) for column in zip(*matrix)]
  def ddm_iadd(a: list[list[R]], b: Sequence[Sequence[R]]) -> None:
      """Add ``b`` into ``a`` entrywise, in place (``a += b``).

      Rows are paired with ``zip``, so surplus rows in either argument are
      ignored. Each entry of a row of ``b`` is added (with ``+=``) to the
      entry at the same index in the matching row of ``a``. No shape check
      is performed here.
      """
      for target, source in zip(a, b):
          for idx, value in enumerate(source):
              target[idx] += value
  def ddm_isub(a: list[list[R]], b: Sequence[Sequence[R]]) -> None:
      """Subtract ``b`` from ``a`` entrywise, in place (``a -= b``).

      Rows are paired with ``zip``, so surplus rows in either argument are
      ignored. Each entry of a row of ``b`` is subtracted (with ``-=``) from
      the entry at the same index in the matching row of ``a``. No shape
      check is performed here.
      """
      for target, source in zip(a, b):
          for idx, value in enumerate(source):
              target[idx] -= value
  def ddm_ineg(a: list[list[R]]) -> None:
      """Negate every entry of ``a`` in place (``a <-- -a``)."""
      for row in a:
          for idx, entry in enumerate(row):
              row[idx] = -entry
  def ddm_imul(a: list[list[R]], b: R) -> None:
      """Scale ``a`` on the right by ``b``, in place (``a <-- a * b``).

      Every entry is replaced by ``entry * b``. Compare ``ddm_irmul``, which
      computes ``b * entry``; the two differ only for element types whose
      multiplication is not commutative.
      """
      for row in a:
          for idx, entry in enumerate(row):
              row[idx] = entry * b
  def ddm_irmul(a: list[list[R]], b: R) -> None:
      """Scale ``a`` on the left by ``b``, in place (``a <-- b * a``).

      Every entry is replaced by ``b * entry``. Compare ``ddm_imul``, which
      computes ``entry * b``; the two differ only for element types whose
      multiplication is not commutative.
      """
      for row in a:
          for idx, entry in enumerate(row):
              row[idx] = b * entry
  def ddm_imatmul(
      a: list[list[R]], b: Sequence[Sequence[R]], c: Sequence[Sequence[R]]
  ) -> None:
      """Accumulate a matrix product into ``a`` in place (``a += b @ c``).

      With ``a`` of shape ``m x p``, ``b`` of shape ``m x k`` and ``c`` of
      shape ``k x p``, each ``a[i][j]`` is replaced by the dot product of row
      ``i`` of ``b`` with column ``j`` of ``c``, added onto the current
      value of ``a[i][j]``. Only ``*`` and ``+`` are applied to the entries.
      Shapes are not checked here.
      """
      # Materialise the columns of c once; each is reused for every row of b.
      columns = list(zip(*c))
      for row_b, row_a in zip(b, a):
          for idx, column in enumerate(columns):
              # sum() with a start value adds the dot product onto the
              # existing entry instead of starting from zero.
              products = (x * y for x, y in zip(row_b, column))
              row_a[idx] = sum(products, row_a[idx])
  def ddm_irref(a, _partial_pivot=False):
      """Reduce ``a`` to reduced row echelon form in place; return pivot columns.

      ``a`` is a list of ``m`` rows, each of length ``n`` (read from
      ``len(a[0])``). Rows of ``a`` may be swapped and their entries are
      overwritten. Entries must support truthiness (as a zero test),
      ``**-1``, ``*=``, ``-=`` and ``*``, plus ``abs`` when
      ``_partial_pivot`` is true.

      Returns the increasing list of column indices that received a pivot.
      An empty ``a`` gives ``[]``.

      When ``_partial_pivot`` is true, each step first moves the row (at or
      below the current row) whose entry in the current column has the
      largest absolute value into the pivot position.
      """
      nrows = len(a)
      if not nrows:
          return []
      ncols = len(a[0])

      pivot_cols = []
      row = 0  # index of the row that will receive the next pivot
      for col in range(ncols):
          if _partial_pivot:
              # Partial (row) pivoting on absolute value. It is only strictly
              # needed for inexact domains such as RR and CC; DDM.rref() uses
              # this path for those domains.
              best = max(range(row, nrows), key=lambda r: abs(a[r][col]))
              a[row], a[best] = a[best], a[row]

          pivot = a[row][col]
          if not pivot:
              # Bring up the first lower row with a nonzero entry here.
              below = next(
                  (r for r in range(row + 1, nrows) if a[r][col]), None
              )
              if below is None:
                  continue  # no pivot in this column
              a[row], a[below] = a[below], a[row]
              pivot = a[row][col]

          # Scale the pivot row so that the pivot becomes one. Entries left
          # of ``col`` in this row are already zero and are skipped.
          prow = a[row]
          inv = pivot**-1
          for t in range(col, ncols):
              prow[t] *= inv

          # Clear column ``col`` in every other row, above and below.
          for k, other in enumerate(a):
              if k != row and other[col]:
                  factor = other[col]
                  other[col] -= factor  # now zero
                  for t in range(col + 1, ncols):
                      other[t] -= factor * prow[t]

          pivot_cols.append(col)
          row += 1
          if row >= nrows:
              break  # every row already holds a pivot
      return pivot_cols
  def ddm_idet(a, K):
      """Return the determinant of ``a`` using fraction-free (Bareiss) elimination.

      ``a`` is modified in place. Rows may be swapped, and at step ``k`` the
      entries strictly below and to the right of ``a[k][k]`` are overwritten
      with Bareiss intermediates. Entries below the diagonal are never
      cleared, so ``a`` is not left in a clean echelon form. The determinant
      is the sign-corrected final entry ``a[-1][-1]``.

      The dimension is read from ``len(a[0])`` and squareness is not checked
      here (NOTE(review): callers presumably guarantee it). ``K`` must
      provide ``one``, ``zero`` and ``exquo``; the Bareiss divisions by the
      previous pivot are exact.

      Returns ``K.one`` for an empty ``a``. Returns ``K.zero`` early if some
      column ``k < n - 1`` is zero on and below the diagonal.
      """
      # Bareiss algorithm, see
      # https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf
      if not len(a):
          return K.one
      n = len(a[0])
      exquo = K.exquo
      sign = K.one  # negated on every row swap
      for k in range(n - 1):
          if not a[k][k]:
              swap = next((i for i in range(k + 1, n) if a[i][k]), None)
              if swap is None:
                  return K.zero
              a[k], a[swap] = a[swap], a[k]
              sign = -sign
          prev_pivot = a[k - 1][k - 1] if k else K.one
          pivot_row = a[k]
          pivot = pivot_row[k]
          for i in range(k + 1, n):
              row_i = a[i]
              lead = row_i[k]
              for j in range(k + 1, n):
                  row_i[j] = exquo(
                      row_i[j] * pivot - lead * pivot_row[j], prev_pivot
                  )
      return sign * a[-1][-1]
  def ddm_iinv(ainv, a, K):
      """Invert the square matrix ``a`` over the field ``K`` into ``ainv``.

      The augmented matrix ``[a | I]`` is built from fresh lists, so ``a`` is
      not modified, and is reduced with ``ddm_irref``. The right-hand half
      is then stored into ``ainv`` by slice assignment, so ``ainv`` must be
      a mutable list. If ``a`` is empty, ``ainv`` is left untouched.

      Raises
      ------
      ValueError
          If ``K.is_Field`` is false.
      DMNonSquareMatrixError
          If ``a`` is not square.
      DMNonInvertibleMatrixError
          If the reduction does not put a pivot in each of the first ``n``
          columns.
      """
      if not K.is_Field:
          raise ValueError('Not a field')
      nrows = len(a)
      if not nrows:
          return
      ncols = len(a[0])
      if nrows != ncols:
          raise DMNonSquareMatrixError
      one, zero = K.one, K.zero
      # [a | I]: ``row + [...]`` creates a new list, leaving ``a`` intact.
      augmented = [
          row + [one if c == r else zero for c in range(ncols)]
          for r, row in enumerate(a)
      ]
      # a is invertible exactly when each of its columns receives a pivot.
      if ddm_irref(augmented) != list(range(ncols)):
          raise DMNonInvertibleMatrixError('Matrix det == 0; not invertible.')
      # The left half is now the identity and the right half is the inverse.
      ainv[:] = [row[ncols:] for row in augmented]
  def ddm_ilu_split(L, U, K):
      """Factorise ``U`` in place with ``ddm_ilu`` and move the multipliers into ``L``.

      After ``ddm_ilu(U)``, the strictly lower part of each row ``i`` of ``U``
      (its first ``min(i, n)`` entries) is copied into the same positions
      of ``L`` and replaced by ``K.zero`` in ``U``. No other entries of
      ``L`` are written, so ``L`` is presumably expected to arrive
      pre-filled, e.g. as an identity matrix. NOTE(review): confirm against
      callers.

      Returns the row swaps reported by ``ddm_ilu``, or ``[]`` for an empty
      ``U``.
      """
      nrows = len(U)
      if not nrows:
          return []
      ncols = len(U[0])
      swaps = ddm_ilu(U)
      zero = K.zero
      for i in range(1, nrows):
          width = min(i, ncols)
          urow = U[i]
          L[i][:width] = urow[:width]
          urow[:width] = [zero] * width
      return swaps
  def ddm_ilu(a):
      """Compute a packed LU factorisation of ``a`` in place; return the row swaps.

      On return, the entries of ``a`` on and above the diagonal hold ``U``.
      The entries below the diagonal hold the multipliers of ``L``, whose
      unit diagonal is not stored.

      A row swap is made only when a diagonal entry is zero, and each swap
      ``(i, ip)`` is appended in the order it was applied. A column that is
      zero on and below the diagonal is skipped.

      Entries need truthiness, ``/``, ``*`` and ``-=``. Multipliers are
      formed with ``/``, so elements are presumably from a field.
      """
      nrows = len(a)
      if not nrows:
          return []
      ncols = len(a[0])
      swaps = []
      for i in range(min(nrows, ncols)):
          if not a[i][i]:
              below = next((r for r in range(i + 1, nrows) if a[r][i]), None)
              if below is None:
                  # Nothing to eliminate in this column, e.g. the second
                  # column of
                  # Matrix([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]])
                  continue
              swaps.append((i, below))
              a[i], a[below] = a[below], a[i]
          pivot_row = a[i]
          pivot = pivot_row[i]
          for r in range(i + 1, nrows):
              row = a[r]
              factor = row[i] / pivot
              row[i] = factor  # store the L multiplier in place
              for c in range(i + 1, ncols):
                  row[c] -= factor * pivot_row[c]
      return swaps
  def ddm_ilu_solve(x, L, U, swaps, b):
      """Solve ``L*U*x = P*b`` in place, where ``P`` applies ``swaps`` to ``b``.

      Parameters
      ----------
      x : list of lists
          Output. It must already have at least ``n`` rows of ``o`` entries,
          where ``n = len(U[0])`` and ``o = len(b[0])``; those entries are
          overwritten.
      L : list of lists
          Lower-triangular factor. Only its strictly lower entries are read;
          its diagonal is treated as one.
      U : list of lists
          ``m x n`` upper-triangular factor. If ``U`` is empty the function
          returns immediately without touching ``x`` or checking ``b``.
      swaps : list of (int, int)
          Row swaps, as returned by ``ddm_ilu``, applied in order to ``b``.
      b : list of lists
          ``m x o`` right-hand side. It is not modified.

      Raises
      ------
      DMShapeError
          If ``b`` is empty or its row count differs from ``U``'s.
      NotImplementedError
          If the system is underdetermined (``m < n``).
      DMNonInvertibleMatrixError
          If an overdetermined system (``m > n``) is inconsistent, or if
          ``U`` has a zero on its diagonal (checked only when ``o > 0``).
          In both cases ``x`` is left untouched.
      """
      m = len(U)
      if not m:
          return
      n = len(U[0])
      m2 = len(b)
      if not m2:
          raise DMShapeError("Shape mismatch")
      o = len(b[0])
      if m != m2:
          raise DMShapeError("Shape mismatch")
      if m < n:
          raise NotImplementedError("Underdetermined")

      if swaps:
          # Permute a shallow copy so the caller's b is left untouched. The
          # rows themselves are only read below, so they need no copying.
          b = list(b)
          for i1, i2 in swaps:
              b[i1], b[i2] = b[i2], b[i1]

      # Forward substitution: solve L*y = b using L's implicit unit diagonal.
      y = [[None] * o for _ in range(m)]
      for k in range(o):
          for i in range(m):
              rhs = b[i][k]
              for j in range(i):
                  rhs -= L[i][j] * y[j][k]
              y[i][k] = rhs

      # For m > n the rows of an upper-triangular U from n on are zero, so
      # the system is consistent only if the matching entries of y vanish.
      if m > n and any(y[i][j] for i in range(n, m) for j in range(o)):
          raise DMNonInvertibleMatrixError

      # A zero on U's diagonal makes the system singular. Checking once, up
      # front, avoids repeating the test per column and never leaves x
      # partially overwritten. Nothing is solved (or checked) when o == 0.
      if o and any(not U[i][i] for i in range(n)):
          raise DMNonInvertibleMatrixError

      # Back substitution: solve U*x = y.
      for k in range(o):
          for i in reversed(range(n)):
              rhs = y[i][k]
              for j in range(i + 1, n):
                  rhs -= U[i][j] * x[j][k]
              x[i][k] = rhs / U[i][i]
  def ddm_berk(M, K):
      """Return the characteristic polynomial coefficients of ``M`` (Berkowitz).

      The result is an ``(n + 1) x 1`` column, i.e. a list of one-element
      lists. It holds the coefficients of ``det(x*I - M)`` from the highest
      degree down, so its first entry is ``K.one``. An empty ``M`` gives
      ``[[K.one]]``.

      The algorithm recurses on the trailing ``(n-1) x (n-1)`` block. It
      then multiplies that block's coefficient column by a Toeplitz matrix
      built from the first row and column of ``M``. It is division free:
      only ``+``, ``*`` and negation are applied to the entries.

      Raises DMShapeError if ``M`` is not square.
      """
      size = len(M)
      if not size:
          return [[K.one]]
      n = len(M[0])
      if size != n:
          raise DMShapeError("Not square")
      if n == 1:
          return [[K.one], [-M[0][0]]]

      # Split M = [[a, row], [col, sub]]:
      # - a is the corner entry;
      # - row is 1 x (n-1) and col is (n-1) x 1;
      # - sub is the trailing (n-1) x (n-1) block.
      a = M[0][0]
      row = [M[0][1:]]
      col = [[r[0]] for r in M[1:]]
      sub = [r[1:] for r in M[1:]]

      # Coefficients of the characteristic polynomial of the trailing block.
      q = ddm_berk(sub, K)

      # (n+1) x n lower-triangular Toeplitz matrix:
      # - ones on the diagonal;
      # - -a on the first subdiagonal;
      # - -(row @ sub**(d-2) @ col) on subdiagonal d >= 2.
      toeplitz = [[K.zero] * n for _ in range(n + 1)]
      for i in range(n):
          toeplitz[i][i] = K.one
          toeplitz[i + 1][i] = -a

      power_col = col  # sub**(d-2) @ col for the current d
      for d in range(2, n + 1):
          if d > 2:
              nxt = [[K.zero] for _ in power_col]
              ddm_imatmul(nxt, sub, power_col)
              power_col = nxt
          entry = [[K.zero]]
          ddm_imatmul(entry, row, power_col)
          for j in range(n + 1 - d):
              toeplitz[d + j][j] = -entry[0][0]

      # Coefficients for M are toeplitz @ q.
      qout = [[K.zero] for _ in range(n + 1)]
      ddm_imatmul(qout, toeplitz, q)
      return qout