ddm.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
"""
Module for the DDM class.

The DDM class is an internal representation used by DomainMatrix. The letters
DDM stand for Dense Domain Matrix. A DDM instance represents a matrix using
elements from a polynomial Domain (e.g. ZZ, QQ, ...) in a dense-matrix
representation.

Basic usage:

>>> from sympy import ZZ, QQ
>>> from sympy.polys.matrices.ddm import DDM
>>> A = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ)
>>> A.shape
(2, 2)
>>> A
[[0, 1], [-1, 0]]
>>> type(A)
<class 'sympy.polys.matrices.ddm.DDM'>
>>> A @ A
[[-1, 0], [0, -1]]

The ddm_* functions are designed to operate on DDM as well as on an ordinary
list of lists:

>>> from sympy.polys.matrices.dense import ddm_idet
>>> ddm_idet(A, QQ)
1
>>> ddm_idet([[0, 1], [-1, 0]], QQ)
1
>>> A
[[-1, 0], [0, -1]]

Note that ddm_idet modifies the input matrix in-place. It is recommended to
use the DDM.det method instead, as a friendlier interface that takes care of
copying the matrix:

>>> B = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ)
>>> B.det()
1

Normally DDM would not be used directly and is just part of the internal
representation of DomainMatrix, which adds further functionality including
e.g. unifying domains.

The dense format used by DDM is a list of lists of elements, e.g. the 2x2
identity matrix is like [[1, 0], [0, 1]]. The DDM class itself is a subclass
of list and its list items are plain lists. Elements are accessed as e.g.
ddm[i][j] where ddm[i] gives the ith row and ddm[i][j] gets the element in the
jth column of that row. Subclassing list makes e.g. iteration and indexing
very efficient. We do not override __getitem__ because it would lose that
benefit.

The core routines are implemented by the ddm_* functions defined in dense.py.
Those functions are intended to be able to operate on a raw list-of-lists
representation of matrices, with most functions operating in-place. The DDM
class takes care of copying etc. and also stores a Domain object associated
with its elements. This makes it possible to implement things like A + B with
domain checking and also shape checking so that the list of lists
representation is friendlier.
"""
  52. from itertools import chain
  53. from .exceptions import DMBadInputError, DMShapeError, DMDomainError
  54. from .dense import (
  55. ddm_transpose,
  56. ddm_iadd,
  57. ddm_isub,
  58. ddm_ineg,
  59. ddm_imul,
  60. ddm_irmul,
  61. ddm_imatmul,
  62. ddm_irref,
  63. ddm_idet,
  64. ddm_iinv,
  65. ddm_ilu_split,
  66. ddm_ilu_solve,
  67. ddm_berk,
  68. )
  69. from sympy.polys.domains import QQ
  70. from .lll import ddm_lll, ddm_lll_transform
  71. class DDM(list):
  72. """Dense matrix based on polys domain elements
  73. This is a list subclass and is a wrapper for a list of lists that supports
  74. basic matrix arithmetic +, -, *, **.
  75. """
  76. fmt = 'dense'
  77. def __init__(self, rowslist, shape, domain):
  78. super().__init__(rowslist)
  79. self.shape = self.rows, self.cols = m, n = shape
  80. self.domain = domain
  81. if not (len(self) == m and all(len(row) == n for row in self)):
  82. raise DMBadInputError("Inconsistent row-list/shape")
  83. def getitem(self, i, j):
  84. return self[i][j]
  85. def setitem(self, i, j, value):
  86. self[i][j] = value
  87. def extract_slice(self, slice1, slice2):
  88. ddm = [row[slice2] for row in self[slice1]]
  89. rows = len(ddm)
  90. cols = len(ddm[0]) if ddm else len(range(self.shape[1])[slice2])
  91. return DDM(ddm, (rows, cols), self.domain)
  92. def extract(self, rows, cols):
  93. ddm = []
  94. for i in rows:
  95. rowi = self[i]
  96. ddm.append([rowi[j] for j in cols])
  97. return DDM(ddm, (len(rows), len(cols)), self.domain)
  98. def to_list(self):
  99. return list(self)
  100. def to_list_flat(self):
  101. flat = []
  102. for row in self:
  103. flat.extend(row)
  104. return flat
  105. def flatiter(self):
  106. return chain.from_iterable(self)
  107. def flat(self):
  108. items = []
  109. for row in self:
  110. items.extend(row)
  111. return items
  112. def to_dok(self):
  113. return {(i, j): e for i, row in enumerate(self) for j, e in enumerate(row)}
  114. def to_ddm(self):
  115. return self
  116. def to_sdm(self):
  117. return SDM.from_list(self, self.shape, self.domain)
  118. def convert_to(self, K):
  119. Kold = self.domain
  120. if K == Kold:
  121. return self.copy()
  122. rows = ([K.convert_from(e, Kold) for e in row] for row in self)
  123. return DDM(rows, self.shape, K)
  124. def __str__(self):
  125. rowsstr = ['[%s]' % ', '.join(map(str, row)) for row in self]
  126. return '[%s]' % ', '.join(rowsstr)
  127. def __repr__(self):
  128. cls = type(self).__name__
  129. rows = list.__repr__(self)
  130. return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain)
  131. def __eq__(self, other):
  132. if not isinstance(other, DDM):
  133. return False
  134. return (super().__eq__(other) and self.domain == other.domain)
  135. def __ne__(self, other):
  136. return not self.__eq__(other)
  137. @classmethod
  138. def zeros(cls, shape, domain):
  139. z = domain.zero
  140. m, n = shape
  141. rowslist = ([z] * n for _ in range(m))
  142. return DDM(rowslist, shape, domain)
  143. @classmethod
  144. def ones(cls, shape, domain):
  145. one = domain.one
  146. m, n = shape
  147. rowlist = ([one] * n for _ in range(m))
  148. return DDM(rowlist, shape, domain)
  149. @classmethod
  150. def eye(cls, size, domain):
  151. one = domain.one
  152. ddm = cls.zeros((size, size), domain)
  153. for i in range(size):
  154. ddm[i][i] = one
  155. return ddm
  156. def copy(self):
  157. copyrows = (row[:] for row in self)
  158. return DDM(copyrows, self.shape, self.domain)
  159. def transpose(self):
  160. rows, cols = self.shape
  161. if rows:
  162. ddmT = ddm_transpose(self)
  163. else:
  164. ddmT = [[]] * cols
  165. return DDM(ddmT, (cols, rows), self.domain)
  166. def __add__(a, b):
  167. if not isinstance(b, DDM):
  168. return NotImplemented
  169. return a.add(b)
  170. def __sub__(a, b):
  171. if not isinstance(b, DDM):
  172. return NotImplemented
  173. return a.sub(b)
  174. def __neg__(a):
  175. return a.neg()
  176. def __mul__(a, b):
  177. if b in a.domain:
  178. return a.mul(b)
  179. else:
  180. return NotImplemented
  181. def __rmul__(a, b):
  182. if b in a.domain:
  183. return a.mul(b)
  184. else:
  185. return NotImplemented
  186. def __matmul__(a, b):
  187. if isinstance(b, DDM):
  188. return a.matmul(b)
  189. else:
  190. return NotImplemented
  191. @classmethod
  192. def _check(cls, a, op, b, ashape, bshape):
  193. if a.domain != b.domain:
  194. msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain)
  195. raise DMDomainError(msg)
  196. if ashape != bshape:
  197. msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape)
  198. raise DMShapeError(msg)
  199. def add(a, b):
  200. """a + b"""
  201. a._check(a, '+', b, a.shape, b.shape)
  202. c = a.copy()
  203. ddm_iadd(c, b)
  204. return c
  205. def sub(a, b):
  206. """a - b"""
  207. a._check(a, '-', b, a.shape, b.shape)
  208. c = a.copy()
  209. ddm_isub(c, b)
  210. return c
  211. def neg(a):
  212. """-a"""
  213. b = a.copy()
  214. ddm_ineg(b)
  215. return b
  216. def mul(a, b):
  217. c = a.copy()
  218. ddm_imul(c, b)
  219. return c
  220. def rmul(a, b):
  221. c = a.copy()
  222. ddm_irmul(c, b)
  223. return c
  224. def matmul(a, b):
  225. """a @ b (matrix product)"""
  226. m, o = a.shape
  227. o2, n = b.shape
  228. a._check(a, '*', b, o, o2)
  229. c = a.zeros((m, n), a.domain)
  230. ddm_imatmul(c, a, b)
  231. return c
  232. def mul_elementwise(a, b):
  233. assert a.shape == b.shape
  234. assert a.domain == b.domain
  235. c = [[aij * bij for aij, bij in zip(ai, bi)] for ai, bi in zip(a, b)]
  236. return DDM(c, a.shape, a.domain)
  237. def hstack(A, *B):
  238. """Horizontally stacks :py:class:`~.DDM` matrices.
  239. Examples
  240. ========
  241. >>> from sympy import ZZ
  242. >>> from sympy.polys.matrices.sdm import DDM
  243. >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ)
  244. >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ)
  245. >>> A.hstack(B)
  246. [[1, 2, 5, 6], [3, 4, 7, 8]]
  247. >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ)
  248. >>> A.hstack(B, C)
  249. [[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]]
  250. """
  251. Anew = list(A.copy())
  252. rows, cols = A.shape
  253. domain = A.domain
  254. for Bk in B:
  255. Bkrows, Bkcols = Bk.shape
  256. assert Bkrows == rows
  257. assert Bk.domain == domain
  258. cols += Bkcols
  259. for i, Bki in enumerate(Bk):
  260. Anew[i].extend(Bki)
  261. return DDM(Anew, (rows, cols), A.domain)
  262. def vstack(A, *B):
  263. """Vertically stacks :py:class:`~.DDM` matrices.
  264. Examples
  265. ========
  266. >>> from sympy import ZZ
  267. >>> from sympy.polys.matrices.sdm import DDM
  268. >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ)
  269. >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ)
  270. >>> A.vstack(B)
  271. [[1, 2], [3, 4], [5, 6], [7, 8]]
  272. >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ)
  273. >>> A.vstack(B, C)
  274. [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
  275. """
  276. Anew = list(A.copy())
  277. rows, cols = A.shape
  278. domain = A.domain
  279. for Bk in B:
  280. Bkrows, Bkcols = Bk.shape
  281. assert Bkcols == cols
  282. assert Bk.domain == domain
  283. rows += Bkrows
  284. Anew.extend(Bk.copy())
  285. return DDM(Anew, (rows, cols), A.domain)
  286. def applyfunc(self, func, domain):
  287. elements = (list(map(func, row)) for row in self)
  288. return DDM(elements, self.shape, domain)
  289. def scc(a):
  290. """Strongly connected components of a square matrix *a*.
  291. Examples
  292. ========
  293. >>> from sympy import ZZ
  294. >>> from sympy.polys.matrices.sdm import DDM
  295. >>> A = DDM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ)
  296. >>> A.scc()
  297. [[0], [1]]
  298. See also
  299. ========
  300. sympy.polys.matrices.domainmatrix.DomainMatrix.scc
  301. """
  302. return a.to_sdm().scc()
  303. def rref(a):
  304. """Reduced-row echelon form of a and list of pivots"""
  305. b = a.copy()
  306. K = a.domain
  307. partial_pivot = K.is_RealField or K.is_ComplexField
  308. pivots = ddm_irref(b, _partial_pivot=partial_pivot)
  309. return b, pivots
  310. def nullspace(a):
  311. rref, pivots = a.rref()
  312. rows, cols = a.shape
  313. domain = a.domain
  314. basis = []
  315. nonpivots = []
  316. for i in range(cols):
  317. if i in pivots:
  318. continue
  319. nonpivots.append(i)
  320. vec = [domain.one if i == j else domain.zero for j in range(cols)]
  321. for ii, jj in enumerate(pivots):
  322. vec[jj] -= rref[ii][i]
  323. basis.append(vec)
  324. return DDM(basis, (len(basis), cols), domain), nonpivots
  325. def particular(a):
  326. return a.to_sdm().particular().to_ddm()
  327. def det(a):
  328. """Determinant of a"""
  329. m, n = a.shape
  330. if m != n:
  331. raise DMShapeError("Determinant of non-square matrix")
  332. b = a.copy()
  333. K = b.domain
  334. deta = ddm_idet(b, K)
  335. return deta
  336. def inv(a):
  337. """Inverse of a"""
  338. m, n = a.shape
  339. if m != n:
  340. raise DMShapeError("Determinant of non-square matrix")
  341. ainv = a.copy()
  342. K = a.domain
  343. ddm_iinv(ainv, a, K)
  344. return ainv
  345. def lu(a):
  346. """L, U decomposition of a"""
  347. m, n = a.shape
  348. K = a.domain
  349. U = a.copy()
  350. L = a.eye(m, K)
  351. swaps = ddm_ilu_split(L, U, K)
  352. return L, U, swaps
  353. def lu_solve(a, b):
  354. """x where a*x = b"""
  355. m, n = a.shape
  356. m2, o = b.shape
  357. a._check(a, 'lu_solve', b, m, m2)
  358. L, U, swaps = a.lu()
  359. x = a.zeros((n, o), a.domain)
  360. ddm_ilu_solve(x, L, U, swaps, b)
  361. return x
  362. def charpoly(a):
  363. """Coefficients of characteristic polynomial of a"""
  364. K = a.domain
  365. m, n = a.shape
  366. if m != n:
  367. raise DMShapeError("Charpoly of non-square matrix")
  368. vec = ddm_berk(a, K)
  369. coeffs = [vec[i][0] for i in range(n+1)]
  370. return coeffs
  371. def is_zero_matrix(self):
  372. """
  373. Says whether this matrix has all zero entries.
  374. """
  375. zero = self.domain.zero
  376. return all(Mij == zero for Mij in self.flatiter())
  377. def is_upper(self):
  378. """
  379. Says whether this matrix is upper-triangular. True can be returned
  380. even if the matrix is not square.
  381. """
  382. zero = self.domain.zero
  383. return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[:i])
  384. def is_lower(self):
  385. """
  386. Says whether this matrix is lower-triangular. True can be returned
  387. even if the matrix is not square.
  388. """
  389. zero = self.domain.zero
  390. return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[i+1:])
  391. def lll(A, delta=QQ(3, 4)):
  392. return ddm_lll(A, delta=delta)
  393. def lll_transform(A, delta=QQ(3, 4)):
  394. return ddm_lll_transform(A, delta=delta)
  395. from .sdm import SDM