sdm.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241
  1. """
  2. Module for the SDM class.
  3. """
  4. from operator import add, neg, pos, sub, mul
  5. from collections import defaultdict
  6. from sympy.utilities.iterables import _strongly_connected_components
  7. from .exceptions import DMBadInputError, DMDomainError, DMShapeError
  8. from .ddm import DDM
  9. from .lll import ddm_lll, ddm_lll_transform
  10. from sympy.polys.domains import QQ
  class SDM(dict):
      r"""Sparse matrix based on polys domain elements

      This is a dict subclass and is a wrapper for a dict of dicts that supports
      basic matrix arithmetic +, -, *.

      In order to create a new :py:class:`~.SDM`, a dict
      of dicts mapping non-zero elements to their
      corresponding row and column in the matrix is needed.

      We also need to specify the shape and :py:class:`~.Domain`
      of our :py:class:`~.SDM` object.

      We declare a 2x2 :py:class:`~.SDM` matrix belonging
      to QQ domain as shown below.
      The 2x2 Matrix in the example is

      .. math::
          A = \left[\begin{array}{ccc}
                  0 & \frac{1}{2} \\
                  0 & 0 \end{array} \right]

      >>> from sympy.polys.matrices.sdm import SDM
      >>> from sympy import QQ
      >>> elemsdict = {0:{1:QQ(1, 2)}}
      >>> A = SDM(elemsdict, (2, 2), QQ)
      >>> A
      {0: {1: 1/2}}

      We can manipulate :py:class:`~.SDM` the same way
      as a Matrix class

      >>> from sympy import ZZ
      >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
      >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
      >>> A + B
      {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}

      Multiplication

      >>> A*B
      {0: {1: 8}, 1: {0: 3}}
      >>> A*ZZ(2)
      {0: {1: 4}, 1: {0: 2}}

      """

      fmt = 'sparse'

      def __init__(self, elemsdict, shape, domain):
          """Store *elemsdict* as the row dicts and validate all indices.

          Raises :py:class:`DMBadInputError` if any row or column key lies
          outside ``shape``.
          """
          super().__init__(elemsdict)
          self.shape = self.rows, self.cols = m, n = shape
          self.domain = domain

          if not all(0 <= r < m for r in self):
              raise DMBadInputError("Row out of range")
          if not all(0 <= c < n for row in self.values() for c in row):
              raise DMBadInputError("Column out of range")

      def getitem(self, i, j):
          """Return the element at ``(i, j)``; negative indices wrap around.

          Entries that are not stored are reported as ``self.domain.zero``.
          Raises ``IndexError`` when the index is outside the shape.
          """
          try:
              # Fast path: the element is stored under exactly these keys.
              return self[i][j]
          except KeyError:
              m, n = self.shape
              if -m <= i < m and -n <= j < n:
                  try:
                      # Normalise negative indices before the second lookup.
                      return self[i % m][j % n]
                  except KeyError:
                      return self.domain.zero
              else:
                  raise IndexError("index out of range")

      def setitem(self, i, j, value):
          """Set the element at ``(i, j)`` in place; negative indices wrap.

          A falsy *value* removes the stored entry (and the row dict if it
          becomes empty) so that zeros are never stored by this method.
          Raises ``IndexError`` when the index is outside the shape.
          """
          m, n = self.shape
          if not (-m <= i < m and -n <= j < n):
              raise IndexError("index out of range")
          i, j = i % m, j % n
          if value:
              try:
                  self[i][j] = value
              except KeyError:
                  self[i] = {j: value}
          else:
              rowi = self.get(i, None)
              if rowi is not None:
                  try:
                      del rowi[j]
                  except KeyError:
                      pass
                  else:
                      # Drop the row entirely once its last entry is gone.
                      if not rowi:
                          del self[i]

      def extract_slice(self, slice1, slice2):
          """Return the submatrix selected by a row slice and a column slice."""
          m, n = self.shape
          ri = range(m)[slice1]
          ci = range(n)[slice2]

          sdm = {}
          for i, row in self.items():
              if i in ri:
                  # range.index maps the old index to its position in the slice
                  row = {ci.index(j): e for j, e in row.items() if j in ci}
                  if row:
                      sdm[ri.index(i)] = row

          return self.new(sdm, (len(ri), len(ci)), self.domain)

      def extract(self, rows, cols):
          """Return the submatrix made of the listed *rows* and *cols*.

          Indices may be negative and may repeat. Raises ``IndexError`` if an
          index is out of range (this check is skipped when the matrix has no
          stored entries or one of the index lists is empty).
          """
          if not (self and rows and cols):
              return self.zeros((len(rows), len(cols)), self.domain)

          m, n = self.shape
          if not (-m <= min(rows) <= max(rows) < m):
              raise IndexError('Row index out of range')
          if not (-n <= min(cols) <= max(cols) < n):
              raise IndexError('Column index out of range')

          # rows and cols can contain duplicates e.g. M[[1, 2, 2], [0, 1]]
          # Build a map from row/col in self to list of rows/cols in output
          rowmap = defaultdict(list)
          colmap = defaultdict(list)
          for i2, i1 in enumerate(rows):
              rowmap[i1 % m].append(i2)
          for j2, j1 in enumerate(cols):
              colmap[j1 % n].append(j2)

          # Used to efficiently skip zero rows/cols
          rowset = set(rowmap)
          colset = set(colmap)

          sdm1 = self
          sdm2 = {}
          for i1 in rowset & set(sdm1):
              row1 = sdm1[i1]
              row2 = {}
              for j1 in colset & set(row1):
                  row1_j1 = row1[j1]
                  for j2 in colmap[j1]:
                      row2[j2] = row1_j1
              if row2:
                  # Each output row gets its own dict so rows never alias.
                  for i2 in rowmap[i1]:
                      sdm2[i2] = row2.copy()

          return self.new(sdm2, (len(rows), len(cols)), self.domain)

      def __str__(self):
          rowsstr = []
          for i, row in self.items():
              elemsstr = ', '.join('%s: %s' % (j, elem) for j, elem in row.items())
              rowsstr.append('%s: {%s}' % (i, elemsstr))
          return '{%s}' % ', '.join(rowsstr)

      def __repr__(self):
          cls = type(self).__name__
          rows = dict.__repr__(self)
          return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain)

      @classmethod
      def new(cls, sdm, shape, domain):
          """

          Parameters
          ==========

          sdm: A dict of dicts for non-zero elements in SDM
          shape: tuple representing dimension of SDM
          domain: Represents :py:class:`~.Domain` of SDM

          Returns
          =======

          An :py:class:`~.SDM` object

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> elemsdict = {0:{1: QQ(2)}}
          >>> A = SDM.new(elemsdict, (2, 2), QQ)
          >>> A
          {0: {1: 2}}

          """
          return cls(sdm, shape, domain)

      def copy(A):
          """
          Returns the copy of a :py:class:`~.SDM` object

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
          >>> A = SDM(elemsdict, (2, 2), QQ)
          >>> B = A.copy()
          >>> B
          {0: {1: 2}, 1: {}}

          """
          Ac = {i: Ai.copy() for i, Ai in A.items()}
          return A.new(Ac, A.shape, A.domain)

      @classmethod
      def from_list(cls, ddm, shape, domain):
          """

          Parameters
          ==========

          ddm:
              list of lists containing domain elements
          shape:
              Dimensions of :py:class:`~.SDM` matrix
          domain:
              Represents :py:class:`~.Domain` of :py:class:`~.SDM` object

          Returns
          =======

          :py:class:`~.SDM` containing elements of ddm

          Raises
          ======

          DMBadInputError
              If the list of lists does not match ``shape``.

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> ddm = [[QQ(1, 2), QQ(0)], [QQ(0), QQ(3, 4)]]
          >>> A = SDM.from_list(ddm, (2, 2), QQ)
          >>> A
          {0: {0: 1/2}, 1: {1: 3/4}}

          """
          m, n = shape
          if not (len(ddm) == m and all(len(row) == n for row in ddm)):
              raise DMBadInputError("Inconsistent row-list/shape")

          # Keep only truthy entries and skip rows that end up empty.
          sdm = {}
          for i in range(m):
              ddm_i = ddm[i]
              row = {j: ddm_i[j] for j in range(n) if ddm_i[j]}
              if row:
                  sdm[i] = row
          return cls(sdm, shape, domain)

      @classmethod
      def from_ddm(cls, ddm):
          """
          converts object of :py:class:`~.DDM` to
          :py:class:`~.SDM`

          Examples
          ========

          >>> from sympy.polys.matrices.ddm import DDM
          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> ddm = DDM( [[QQ(1, 2), 0], [0, QQ(3, 4)]], (2, 2), QQ)
          >>> A = SDM.from_ddm(ddm)
          >>> A
          {0: {0: 1/2}, 1: {1: 3/4}}

          """
          return cls.from_list(ddm, ddm.shape, ddm.domain)

      def to_list(M):
          """
          Converts a :py:class:`~.SDM` object to a list

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> elemsdict = {0:{1:QQ(2)}, 1:{}}
          >>> A = SDM(elemsdict, (2, 2), QQ)
          >>> A.to_list()
          [[0, 2], [0, 0]]

          """
          m, n = M.shape
          zero = M.domain.zero
          ddm = [[zero] * n for _ in range(m)]
          for i, row in M.items():
              for j, e in row.items():
                  ddm[i][j] = e
          return ddm

      def to_list_flat(M):
          """Return the elements as one flat list in row-major order."""
          m, n = M.shape
          zero = M.domain.zero
          flat = [zero] * (m * n)
          for i, row in M.items():
              for j, e in row.items():
                  flat[i*n + j] = e
          return flat

      def to_dok(M):
          """Return a dict mapping ``(i, j)`` to each stored element."""
          return {(i, j): e for i, row in M.items() for j, e in row.items()}

      def to_ddm(M):
          """
          Convert a :py:class:`~.SDM` object to a :py:class:`~.DDM` object

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
          >>> A.to_ddm()
          [[0, 2], [0, 0]]

          """
          return DDM(M.to_list(), M.shape, M.domain)

      def to_sdm(M):
          """Return *M* itself (it already is an SDM); no copy is made."""
          return M

      @classmethod
      def zeros(cls, shape, domain):
          r"""

          Returns a :py:class:`~.SDM` of size shape,
          belonging to the specified domain

          In the example below we declare a matrix A where,

          .. math::
              A := \left[\begin{array}{ccc}
              0 & 0 & 0 \\
              0 & 0 & 0 \end{array} \right]

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> A = SDM.zeros((2, 3), QQ)
          >>> A
          {}

          """
          return cls({}, shape, domain)

      @classmethod
      def ones(cls, shape, domain):
          """Return an SDM of the given shape with every entry ``domain.one``."""
          one = domain.one
          m, n = shape
          row = dict(zip(range(n), [one]*n))
          # Every row gets its own dict so rows never alias.
          sdm = {i: row.copy() for i in range(m)}
          return cls(sdm, shape, domain)

      @classmethod
      def eye(cls, shape, domain):
          """

          Returns a identity :py:class:`~.SDM` matrix of dimensions
          size x size, belonging to the specified domain

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> I = SDM.eye((2, 2), QQ)
          >>> I
          {0: {0: 1}, 1: {1: 1}}

          """
          rows, cols = shape
          one = domain.one
          sdm = {i: {i: one} for i in range(min(rows, cols))}
          return cls(sdm, shape, domain)

      @classmethod
      def diag(cls, diagonal, domain, shape):
          """Return an SDM whose main diagonal holds the truthy *diagonal* items.

          Note the argument order: ``domain`` comes before ``shape`` here.
          """
          sdm = {i: {i: v} for i, v in enumerate(diagonal) if v}
          return cls(sdm, shape, domain)

      def transpose(M):
          """

          Returns the transpose of a :py:class:`~.SDM` matrix

          Examples
          ========

          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy import QQ
          >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ)
          >>> A.transpose()
          {1: {0: 2}}

          """
          MT = sdm_transpose(M)
          return M.new(MT, M.shape[::-1], M.domain)

      def __add__(A, B):
          if not isinstance(B, SDM):
              return NotImplemented
          return A.add(B)

      def __sub__(A, B):
          if not isinstance(B, SDM):
              return NotImplemented
          return A.sub(B)

      def __neg__(A):
          return A.neg()

      def __mul__(A, B):
          """A * B"""
          if isinstance(B, SDM):
              return A.matmul(B)
          elif B in A.domain:
              return A.mul(B)
          else:
              return NotImplemented

      def __rmul__(a, b):
          if b in a.domain:
              return a.rmul(b)
          else:
              return NotImplemented

      def matmul(A, B):
          """
          Performs matrix multiplication of two SDM matrices

          Parameters
          ==========

          A, B: SDM to multiply

          Returns
          =======

          SDM
              SDM after multiplication

          Raises
          ======

          DMDomainError
              If domain of A does not match
              with that of B
          DMShapeError
              If the number of columns of A differs from
              the number of rows of B

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> B = SDM({0:{0:ZZ(2), 1:ZZ(3)}, 1:{0:ZZ(4)}}, (2, 2), ZZ)
          >>> A.matmul(B)
          {0: {0: 8}, 1: {0: 2, 1: 3}}

          """
          if A.domain != B.domain:
              raise DMDomainError
          m, n = A.shape
          n2, o = B.shape
          if n != n2:
              raise DMShapeError
          C = sdm_matmul(A, B, A.domain, m, o)
          return A.new(C, (m, o), A.domain)

      def mul(A, b):
          """
          Multiplies each element of A with a scalar b

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> A.mul(ZZ(3))
          {0: {1: 6}, 1: {0: 3}}

          """
          Csdm = unop_dict(A, lambda aij: aij*b)
          return A.new(Csdm, A.shape, A.domain)

      def rmul(A, b):
          """Multiply each element of A by the scalar b from the left."""
          Csdm = unop_dict(A, lambda aij: b*aij)
          return A.new(Csdm, A.shape, A.domain)

      def mul_elementwise(A, B):
          """Return the elementwise product of A and B.

          Raises :py:class:`DMDomainError` or :py:class:`DMShapeError` if the
          domains or shapes differ.
          """
          if A.domain != B.domain:
              raise DMDomainError
          if A.shape != B.shape:
              raise DMShapeError
          zero = A.domain.zero
          # An entry present in only one operand is multiplied by an implicit
          # zero, so it maps to zero and is dropped by binop_dict.
          fzero = lambda e: zero
          Csdm = binop_dict(A, B, mul, fzero, fzero)
          return A.new(Csdm, A.shape, A.domain)

      def add(A, B):
          """

          Adds two :py:class:`~.SDM` matrices

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
          >>> A.add(B)
          {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}}

          """
          # Inside a method body ``add``/``pos`` resolve to the module-level
          # operator functions, not to the methods of this class.
          Csdm = binop_dict(A, B, add, pos, pos)
          return A.new(Csdm, A.shape, A.domain)

      def sub(A, B):
          """

          Subtracts two :py:class:`~.SDM` matrices

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ)
          >>> A.sub(B)
          {0: {0: -3, 1: 2}, 1: {0: 1, 1: -4}}

          """
          Csdm = binop_dict(A, B, sub, pos, neg)
          return A.new(Csdm, A.shape, A.domain)

      def neg(A):
          """

          Returns the negative of a :py:class:`~.SDM` matrix

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> A.neg()
          {0: {1: -2}, 1: {0: -1}}

          """
          Csdm = unop_dict(A, neg)
          return A.new(Csdm, A.shape, A.domain)

      def convert_to(A, K):
          """

          Converts the :py:class:`~.Domain` of a :py:class:`~.SDM` matrix to K

          Examples
          ========

          >>> from sympy import ZZ, QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ)
          >>> A.convert_to(QQ)
          {0: {1: 2}, 1: {0: 1}}

          """
          Kold = A.domain
          if K == Kold:
              return A.copy()
          Ak = unop_dict(A, lambda e: K.convert_from(e, Kold))
          return A.new(Ak, A.shape, K)

      def scc(A):
          """Strongly connected components of a square matrix *A*.

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0: ZZ(2)}, 1:{1:ZZ(1)}}, (2, 2), ZZ)
          >>> A.scc()
          [[0], [1]]

          See also
          ========

          sympy.polys.matrices.domainmatrix.DomainMatrix.scc
          """
          rows, cols = A.shape
          assert rows == cols
          V = range(rows)
          # The stored column indices of row v are the out-edges of vertex v.
          Emap = {v: list(A.get(v, [])) for v in V}
          return _strongly_connected_components(V, Emap)

      def rref(A):
          """

          Returns reduced-row echelon form and list of pivots for the :py:class:`~.SDM`

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ)
          >>> A.rref()
          ({0: {0: 1, 1: 2}}, [0])

          """
          B, pivots, _ = sdm_irref(A)
          return A.new(B, A.shape, A.domain), pivots

      def inv(A):
          """

          Returns inverse of a matrix A

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
          >>> A.inv()
          {0: {0: -2, 1: 1}, 1: {0: 3/2, 1: -1/2}}

          """
          return A.from_ddm(A.to_ddm().inv())

      def det(A):
          """
          Returns determinant of A

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
          >>> A.det()
          -2

          """
          return A.to_ddm().det()

      def lu(A):
          """

          Returns LU decomposition for a matrix A

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
          >>> A.lu()
          ({0: {0: 1}, 1: {0: 3, 1: 1}}, {0: {0: 1, 1: 2}, 1: {1: -2}}, [])

          """
          L, U, swaps = A.to_ddm().lu()
          return A.from_ddm(L), A.from_ddm(U), swaps

      def lu_solve(A, b):
          """

          Uses LU decomposition to solve Ax = b,

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
          >>> b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ)
          >>> A.lu_solve(b)
          {1: {0: 1/2}}

          """
          return A.from_ddm(A.to_ddm().lu_solve(b.to_ddm()))

      def nullspace(A):
          """

          Returns nullspace for a :py:class:`~.SDM` matrix A

          Examples
          ========

          >>> from sympy import QQ
          >>> from sympy.polys.matrices.sdm import SDM
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ)
          >>> A.nullspace()
          ({0: {0: -2, 1: 1}}, [1])

          """
          ncols = A.shape[1]
          one = A.domain.one
          B, pivots, nzcols = sdm_irref(A)
          K, nonpivots = sdm_nullspace_from_rref(B, one, ncols, pivots, nzcols)
          K = dict(enumerate(K))
          shape = (len(K), ncols)
          return A.new(K, shape, A.domain), nonpivots

      def particular(A):
          """Return a one-row matrix built from the RREF of *A*.

          The row comes from ``sdm_particular_from_rref``. The result has
          ``ncols - 1`` columns; presumably *A* is an augmented matrix whose
          last column is the right-hand side -- confirm against callers.
          """
          ncols = A.shape[1]
          B, pivots, nzcols = sdm_irref(A)
          P = sdm_particular_from_rref(B, ncols, pivots)
          rep = {0:P} if P else {}
          return A.new(rep, (1, ncols-1), A.domain)

      def hstack(A, *B):
          """Horizontally stacks :py:class:`~.SDM` matrices.

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM

          >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
          >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
          >>> A.hstack(B)
          {0: {0: 1, 1: 2, 2: 5, 3: 6}, 1: {0: 3, 1: 4, 2: 7, 3: 8}}

          >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
          >>> A.hstack(B, C)
          {0: {0: 1, 1: 2, 2: 5, 3: 6, 4: 9, 5: 10}, 1: {0: 3, 1: 4, 2: 7, 3: 8, 4: 11, 5: 12}}
          """
          Anew = dict(A.copy())
          rows, cols = A.shape
          domain = A.domain

          for Bk in B:
              Bkrows, Bkcols = Bk.shape
              assert Bkrows == rows
              assert Bk.domain == domain

              for i, Bki in Bk.items():
                  Ai = Anew.get(i, None)
                  if Ai is None:
                      Anew[i] = Ai = {}
                  # Shift the columns of Bk past everything stacked so far.
                  for j, Bkij in Bki.items():
                      Ai[j + cols] = Bkij
              cols += Bkcols

          return A.new(Anew, (rows, cols), A.domain)

      def vstack(A, *B):
          """Vertically stacks :py:class:`~.SDM` matrices.

          The result owns its row dicts: rows are copied from every operand
          so that modifying the result never modifies an input matrix.

          Examples
          ========

          >>> from sympy import ZZ
          >>> from sympy.polys.matrices.sdm import SDM

          >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ)
          >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ)
          >>> A.vstack(B)
          {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}}

          >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ)
          >>> A.vstack(B, C)
          {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}, 4: {0: 9, 1: 10}, 5: {0: 11, 1: 12}}
          """
          Anew = dict(A.copy())
          rows, cols = A.shape
          domain = A.domain

          for Bk in B:
              Bkrows, Bkcols = Bk.shape
              assert Bkcols == cols
              assert Bk.domain == domain

              for i, Bki in Bk.items():
                  # Copy the row: sharing the dict with Bk would let a later
                  # setitem() on the result silently mutate Bk.
                  Anew[i + rows] = Bki.copy()
              rows += Bkrows

          return A.new(Anew, (rows, cols), A.domain)

      def applyfunc(self, func, domain):
          """Apply *func* to every stored element; the result is in *domain*.

          Elements that are not stored are left untouched, and the values
          returned by *func* are stored as they are.
          """
          sdm = {i: {j: func(e) for j, e in row.items()} for i, row in self.items()}
          return self.new(sdm, self.shape, domain)

      def charpoly(A):
          """
          Returns the coefficients of the characteristic polynomial
          of the :py:class:`~.SDM` matrix. These elements will be domain elements.
          The domain of the elements will be same as domain of the :py:class:`~.SDM`.

          Examples
          ========

          >>> from sympy import QQ, Symbol
          >>> from sympy.polys.matrices.sdm import SDM
          >>> from sympy.polys import Poly
          >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ)
          >>> A.charpoly()
          [1, -5, -2]

          We can create a polynomial using the
          coefficients using :py:class:`~.Poly`

          >>> x = Symbol('x')
          >>> p = Poly(A.charpoly(), x, domain=A.domain)
          >>> p
          Poly(x**2 - 5*x - 2, x, domain='QQ')

          """
          return A.to_ddm().charpoly()

      def is_zero_matrix(self):
          """
          Says whether this matrix has all zero entries.
          """
          return not self

      def is_upper(self):
          """
          Says whether this matrix is upper-triangular. True can be returned
          even if the matrix is not square.
          """
          return all(i <= j for i, row in self.items() for j in row)

      def is_lower(self):
          """
          Says whether this matrix is lower-triangular. True can be returned
          even if the matrix is not square.
          """
          return all(i >= j for i, row in self.items() for j in row)

      def lll(A, delta=QQ(3, 4)):
          """Return the result of ``ddm_lll`` on the dense form of A as an SDM."""
          return A.from_ddm(ddm_lll(A.to_ddm(), delta=delta))

      def lll_transform(A, delta=QQ(3, 4)):
          """Return both results of ``ddm_lll_transform`` on A, each as an SDM."""
          reduced, transform = ddm_lll_transform(A.to_ddm(), delta=delta)
          return A.from_ddm(reduced), A.from_ddm(transform)
  def binop_dict(A, B, fab, fa, fb):
      """Combine two dict-of-dicts matrices entry by entry.

      ``fab(a, b)`` is used where both matrices store an entry, ``fa(a)`` where
      only *A* does and ``fb(b)`` where only *B* does. Falsy results are not
      stored and rows that end up empty are omitted from the returned dict.
      """
      def _map_row(row, func):
          # Apply func to one row, keeping only the truthy images.
          mapped = {}
          for col, value in row.items():
              image = func(value)
              if image:
                  mapped[col] = image
          return mapped

      rows_a, rows_b = set(A), set(B)
      result = {}

      # Rows stored in both operands need a column-wise three-way merge.
      for r in rows_a & rows_b:
          row_a, row_b = A[r], B[r]
          cols_a, cols_b = set(row_a), set(row_b)
          merged = {}
          for c in cols_a & cols_b:
              value = fab(row_a[c], row_b[c])
              if value:
                  merged[c] = value
          for c in cols_a - cols_b:
              value = fa(row_a[c])
              if value:
                  merged[c] = value
          for c in cols_b - cols_a:
              value = fb(row_b[c])
              if value:
                  merged[c] = value
          if merged:
              result[r] = merged

      # Rows stored in only one operand are mapped through fa or fb.
      for r in rows_a - rows_b:
          mapped = _map_row(A[r], fa)
          if mapped:
              result[r] = mapped

      for r in rows_b - rows_a:
          mapped = _map_row(B[r], fb)
          if mapped:
              result[r] = mapped

      return result
  def unop_dict(A, f):
      """Apply ``f`` to every stored entry of a dict-of-dicts matrix.

      Falsy results are dropped, and rows with no surviving entries are left
      out of the returned dict.
      """
      result = {}
      for i, row in A.items():
          images = ((j, f(entry)) for j, entry in row.items())
          new_row = {j: image for j, image in images if image}
          if new_row:
              result[i] = new_row
      return result
  def sdm_transpose(M):
      """Return the transpose of a dict-of-dicts matrix as a new dict of dicts."""
      transposed = {}
      for i, row in M.items():
          for j, entry in row.items():
              # Create the target row on first use, then store the entry.
              transposed.setdefault(j, {})[i] = entry
      return transposed
  def sdm_matmul(A, B, K, m, o):
      """Multiply two dict-of-dicts matrices and return a dict of dicts.

      Domains with ``K.is_EXRAW`` set are delegated to ``sdm_matmul_exraw``.
      Entries that are (or cancel to) a falsy value are not stored.
      """
      #
      # Designed to be fast when A and B are very sparse, e.g. A = B = eye(1000).
      #
      # The product is built row by row: row i of C is the sum of Aik * Bk over
      # k, where Bk is row k of B. Only those k contribute for which Aik is
      # stored AND row Bk is stored, and that set of k is the intersection of
      # the stored columns of Ai with the stored rows of B. Intersecting two
      # sets of ints is cheap in Python.
      #
      if K.is_EXRAW:
          return sdm_matmul_exraw(A, B, K, m, o)

      product = {}
      stored_rows_b = set(B)
      for i, row_a in A.items():
          acc = {}
          for k in set(row_a) & stored_rows_b:
              a_ik = row_a[k]
              for j, b_kj in B[k].items():
                  if j in acc:
                      total = acc[j] + a_ik * b_kj
                      if total:
                          acc[j] = total
                      else:
                          # The running sum cancelled: stop storing it.
                          del acc[j]
                  else:
                      term = a_ik * b_kj
                      if term:
                          acc[j] = term
          if acc:
              product[i] = acc
      return product
  def sdm_matmul_exraw(A, B, K, m, o):
      """Multiply dict-of-dicts matrices without assuming ``0 * x == 0``.

      Differs from ``sdm_matmul`` in two ways:

      - products with an implicit zero are still formed whenever
        ``zero * x != zero`` (e.g. 0*oo -> nan), whereas ``sdm_matmul`` skips
        multiplication by zero entirely;
      - the terms of each entry are collected and added with ``K.sum``.

      *m* is the number of rows of A and *o* the number of columns of B.
      """
      zero = K.zero
      product = {}
      stored_rows_b = set(B)

      for i, row_a in A.items():
          terms = defaultdict(list)
          stored_cols_a = set(row_a)

          # Stored entry of A paired with a stored row of B.
          for k in stored_cols_a & stored_rows_b:
              a_ik = row_a[k]
              row_b = B[k]
              if zero * a_ik == zero:
                  # Ordinary element: only the stored entries of Bk matter.
                  for j, b_kj in row_b.items():
                      terms[j].append(a_ik * b_kj)
              else:
                  # a_ik does not annihilate zero: visit every column.
                  for j in range(o):
                      terms[j].append(a_ik * row_b.get(j, zero))

          # Stored entry of A paired with an all-zero row of B.
          for k in stored_cols_a - stored_rows_b:
              zero_times_a = zero * row_a[k]
              if zero_times_a != zero:
                  for j in range(o):
                      terms[j].append(zero_times_a)

          # Add the collected terms with K.sum and keep truthy totals only.
          totals = ((j, K.sum(parts)) for j, parts in terms.items())
          row_c = {j: total for j, total in totals if total}
          if row_c:
              product[i] = row_c

      # Entries of B that do not annihilate zero also contribute wherever the
      # matching entry of A is an implicit zero.
      for k, row_b in B.items():
          for j, b_kj in row_b.items():
              if zero * b_kj != zero:
                  for i in range(m):
                      a_ik = A.get(i, {}).get(k, zero)
                      # A nonzero a_ik was already handled in the loop above.
                      if not (a_ik == zero):
                          continue
                      row_c = product.get(i, {})
                      c_ij = row_c.get(j, zero) + a_ik * b_kj
                      if c_ij != zero:
                          row_c[j] = c_ij
                      else:  # pragma: no cover
                          # Not expected to be reachable; fail loudly rather
                          # than return a wrong result.
                          raise RuntimeError
                      product[i] = row_c

      return product
  821. def sdm_irref(A):
  822. """RREF and pivots of a sparse matrix *A*.
  823. Compute the reduced row echelon form (RREF) of the matrix *A* and return a
  824. list of the pivot columns. This routine does not work in place and leaves
  825. the original matrix *A* unmodified.
  826. Examples
  827. ========
  828. This routine works with a dict of dicts sparse representation of a matrix:
  829. >>> from sympy import QQ
  830. >>> from sympy.polys.matrices.sdm import sdm_irref
  831. >>> A = {0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}
  832. >>> Arref, pivots, _ = sdm_irref(A)
  833. >>> Arref
  834. {0: {0: 1}, 1: {1: 1}}
  835. >>> pivots
  836. [0, 1]
  837. The analogous calculation with :py:class:`~.Matrix` would be
  838. >>> from sympy import Matrix
  839. >>> M = Matrix([[1, 2], [3, 4]])
  840. >>> Mrref, pivots = M.rref()
  841. >>> Mrref
  842. Matrix([
  843. [1, 0],
  844. [0, 1]])
  845. >>> pivots
  846. (0, 1)
  847. Notes
  848. =====
  849. The cost of this algorithm is determined purely by the nonzero elements of
  850. the matrix. No part of the cost of any step in this algorithm depends on
  851. the number of rows or columns in the matrix. No step depends even on the
  852. number of nonzero rows apart from the primary loop over those rows. The
  853. implementation is much faster than ddm_rref for sparse matrices. In fact
  854. at the time of writing it is also (slightly) faster than the dense
  855. implementation even if the input is a fully dense matrix so it seems to be
  856. faster in all cases.
  857. The elements of the matrix should support exact division with ``/``. For
  858. example elements of any domain that is a field (e.g. ``QQ``) should be
  859. fine. No attempt is made to handle inexact arithmetic.
  860. """
  861. #
  862. # Any zeros in the matrix are not stored at all so an element is zero if
  863. # its row dict has no index at that key. A row is entirely zero if its
  864. # row index is not in the outer dict. Since rref reorders the rows and
  865. # removes zero rows we can completely discard the row indices. The first
  866. # step then copies the row dicts into a list sorted by the index of the
  867. # first nonzero column in each row.
  868. #
  869. # The algorithm then processes each row Ai one at a time. Previously seen
  870. # rows are used to cancel their pivot columns from Ai. Then a pivot from
  871. # Ai is chosen and is cancelled from all previously seen rows. At this
  872. # point Ai joins the previously seen rows. Once all rows are seen all
  873. # elimination has occurred and the rows are sorted by pivot column index.
  874. #
  875. # The previously seen rows are stored in two separate groups. The reduced
  876. # group consists of all rows that have been reduced to a single nonzero
  877. # element (the pivot). There is no need to attempt any further reduction
  878. # with these. Rows that still have other nonzeros need to be considered
  879. # when Ai is cancelled from the previously seen rows.
  880. #
  881. # A dict nonzerocolumns is used to map from a column index to a set of
  882. # previously seen rows that still have a nonzero element in that column.
  883. # This means that we can cancel the pivot from Ai into the previously seen
  884. # rows without needing to loop over each row that might have a zero in
  885. # that column.
  886. #
  887. # Row dicts sorted by index of first nonzero column
  888. # (Maybe sorting is not needed/useful.)
  889. Arows = sorted((Ai.copy() for Ai in A.values()), key=min)
  890. # Each processed row has an associated pivot column.
  891. # pivot_row_map maps from the pivot column index to the row dict.
  892. # This means that we can represent a set of rows purely as a set of their
  893. # pivot indices.
  894. pivot_row_map = {}
  895. # Set of pivot indices for rows that are fully reduced to a single nonzero.
  896. reduced_pivots = set()
  897. # Set of pivot indices for rows not fully reduced
  898. nonreduced_pivots = set()
  899. # Map from column index to a set of pivot indices representing the rows
  900. # that have a nonzero at that column.
  901. nonzero_columns = defaultdict(set)
  902. while Arows:
  903. # Select pivot element and row
  904. Ai = Arows.pop()
  905. # Nonzero columns from fully reduced pivot rows can be removed
  906. Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced_pivots}
  907. # Others require full row cancellation
  908. for j in nonreduced_pivots & set(Ai):
  909. Aj = pivot_row_map[j]
  910. Aij = Ai[j]
  911. Ainz = set(Ai)
  912. Ajnz = set(Aj)
  913. for k in Ajnz - Ainz:
  914. Ai[k] = - Aij * Aj[k]
  915. Ai.pop(j)
  916. Ainz.remove(j)
  917. for k in Ajnz & Ainz:
  918. Aik = Ai[k] - Aij * Aj[k]
  919. if Aik:
  920. Ai[k] = Aik
  921. else:
  922. Ai.pop(k)
  923. # We have now cancelled previously seen pivots from Ai.
  924. # If it is zero then discard it.
  925. if not Ai:
  926. continue
  927. # Choose a pivot from Ai:
  928. j = min(Ai)
  929. Aij = Ai[j]
  930. pivot_row_map[j] = Ai
  931. Ainz = set(Ai)
  932. # Normalise the pivot row to make the pivot 1.
  933. #
  934. # This approach is slow for some domains. Cross cancellation might be
  935. # better for e.g. QQ(x) with division delayed to the final steps.
  936. Aijinv = Aij**-1
  937. for l in Ai:
  938. Ai[l] *= Aijinv
  939. # Use Aij to cancel column j from all previously seen rows
  940. for k in nonzero_columns.pop(j, ()):
  941. Ak = pivot_row_map[k]
  942. Akj = Ak[j]
  943. Aknz = set(Ak)
  944. for l in Ainz - Aknz:
  945. Ak[l] = - Akj * Ai[l]
  946. nonzero_columns[l].add(k)
  947. Ak.pop(j)
  948. Aknz.remove(j)
  949. for l in Ainz & Aknz:
  950. Akl = Ak[l] - Akj * Ai[l]
  951. if Akl:
  952. Ak[l] = Akl
  953. else:
  954. # Drop nonzero elements
  955. Ak.pop(l)
  956. if l != j:
  957. nonzero_columns[l].remove(k)
  958. if len(Ak) == 1:
  959. reduced_pivots.add(k)
  960. nonreduced_pivots.remove(k)
  961. if len(Ai) == 1:
  962. reduced_pivots.add(j)
  963. else:
  964. nonreduced_pivots.add(j)
  965. for l in Ai:
  966. if l != j:
  967. nonzero_columns[l].add(j)
  968. # All done!
  969. pivots = sorted(reduced_pivots | nonreduced_pivots)
  970. pivot2row = {p: n for n, p in enumerate(pivots)}
  971. nonzero_columns = {c: {pivot2row[p] for p in s} for c, s in nonzero_columns.items()}
  972. rows = [pivot_row_map[i] for i in pivots]
  973. rref = dict(enumerate(rows))
  974. return rref, pivots, nonzero_columns
  975. def sdm_nullspace_from_rref(A, one, ncols, pivots, nonzero_cols):
  976. """Get nullspace from A which is in RREF"""
  977. nonpivots = sorted(set(range(ncols)) - set(pivots))
  978. K = []
  979. for j in nonpivots:
  980. Kj = {j:one}
  981. for i in nonzero_cols.get(j, ()):
  982. Kj[pivots[i]] = -A[i][j]
  983. K.append(Kj)
  984. return K, nonpivots
  985. def sdm_particular_from_rref(A, ncols, pivots):
  986. """Get a particular solution from A which is in RREF"""
  987. P = {}
  988. for i, j in enumerate(pivots):
  989. Ain = A[i].get(ncols-1, None)
  990. if Ain is not None:
  991. P[j] = Ain / A[i][j]
  992. return P