linsolve.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. #
  2. # sympy.polys.matrices.linsolve module
  3. #
  4. # This module defines the _linsolve function which is the internal workhorse
  5. # used by linsolve. This computes the solution of a system of linear equations
  6. # using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This
  7. # is a replacement for solve_lin_sys in sympy.polys.solvers which is
  8. # inefficient for large sparse systems due to the use of a PolyRing with many
  9. # generators:
  10. #
  11. # https://github.com/sympy/sympy/issues/20857
  12. #
  13. # The implementation of _linsolve here handles:
  14. #
  15. # - Extracting the coefficients from the Expr/Eq input equations.
  16. # - Constructing a domain and converting the coefficients to
  17. # that domain.
  18. # - Using the SDM.rref, SDM.nullspace etc methods to generate the full
  19. # solution working with arithmetic only in the domain of the coefficients.
  20. #
  21. # The routines here are designed to be efficient for large sparse systems
  22. # of linear equations, although they also work for dense systems. For some
  23. # small dense systems, solve_lin_sys (which uses the dense matrix
  24. # implementation DDM) may be more efficient. For smaller systems, though,
  25. # most of the time is spent just preprocessing the inputs, and the relative
  26. # time spent in rref is too small to be noticeable.
  27. #
  28. from collections import defaultdict
  29. from sympy.core.add import Add
  30. from sympy.core.mul import Mul
  31. from sympy.core.singleton import S
  32. from sympy.polys.constructor import construct_domain
  33. from sympy.polys.solvers import PolyNonlinearError
  34. from .sdm import (
  35. SDM,
  36. sdm_irref,
  37. sdm_particular_from_rref,
  38. sdm_nullspace_from_rref
  39. )
  40. from sympy.utilities.misc import filldedent
  41. def _linsolve(eqs, syms):
  42. """Solve a linear system of equations.
  43. Examples
  44. ========
  45. Solve a linear system with a unique solution:
  46. >>> from sympy import symbols, Eq
  47. >>> from sympy.polys.matrices.linsolve import _linsolve
  48. >>> x, y = symbols('x, y')
  49. >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)]
  50. >>> _linsolve(eqs, [x, y])
  51. {x: 3/2, y: -1/2}
  52. In the case of underdetermined systems the solution will be expressed in
  53. terms of the unknown symbols that are unconstrained:
  54. >>> _linsolve([Eq(x + y, 0)], [x, y])
  55. {x: -y, y: y}
  56. """
  57. # Number of unknowns (columns in the non-augmented matrix)
  58. nsyms = len(syms)
  59. # Convert to sparse augmented matrix (len(eqs) x (nsyms+1))
  60. eqsdict, const = _linear_eq_to_dict(eqs, syms)
  61. Aaug = sympy_dict_to_dm(eqsdict, const, syms)
  62. K = Aaug.domain
  63. # sdm_irref has issues with float matrices. This uses the ddm_rref()
  64. # function. When sdm_rref() can handle float matrices reasonably this
  65. # should be removed...
  66. if K.is_RealField or K.is_ComplexField:
  67. Aaug = Aaug.to_ddm().rref()[0].to_sdm()
  68. # Compute reduced-row echelon form (RREF)
  69. Arref, pivots, nzcols = sdm_irref(Aaug)
  70. # No solution:
  71. if pivots and pivots[-1] == nsyms:
  72. return None
  73. # Particular solution for non-homogeneous system:
  74. P = sdm_particular_from_rref(Arref, nsyms+1, pivots)
  75. # Nullspace - general solution to homogeneous system
  76. # Note: using nsyms not nsyms+1 to ignore last column
  77. V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols)
  78. # Collect together terms from particular and nullspace:
  79. sol = defaultdict(list)
  80. for i, v in P.items():
  81. sol[syms[i]].append(K.to_sympy(v))
  82. for npi, Vi in zip(nonpivots, V):
  83. sym = syms[npi]
  84. for i, v in Vi.items():
  85. sol[syms[i]].append(sym * K.to_sympy(v))
  86. # Use a single call to Add for each term:
  87. sol = {s: Add(*terms) for s, terms in sol.items()}
  88. # Fill in the zeros:
  89. zero = S.Zero
  90. for s in set(syms) - set(sol):
  91. sol[s] = zero
  92. # All done!
  93. return sol
  94. def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms):
  95. """Convert a system of dict equations to a sparse augmented matrix"""
  96. elems = set(eqs_rhs).union(*(e.values() for e in eqs_coeffs))
  97. K, elems_K = construct_domain(elems, field=True, extension=True)
  98. elem_map = dict(zip(elems, elems_K))
  99. neqs = len(eqs_coeffs)
  100. nsyms = len(syms)
  101. sym2index = dict(zip(syms, range(nsyms)))
  102. eqsdict = []
  103. for eq, rhs in zip(eqs_coeffs, eqs_rhs):
  104. eqdict = {sym2index[s]: elem_map[c] for s, c in eq.items()}
  105. if rhs:
  106. eqdict[nsyms] = -elem_map[rhs]
  107. if eqdict:
  108. eqsdict.append(eqdict)
  109. sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms + 1), K)
  110. return sdm_aug
  111. def _linear_eq_to_dict(eqs, syms):
  112. """Convert a system Expr/Eq equations into dict form, returning
  113. the coefficient dictionaries and a list of syms-independent terms
  114. from each expression in ``eqs```.
  115. Examples
  116. ========
  117. >>> from sympy.polys.matrices.linsolve import _linear_eq_to_dict
  118. >>> from sympy.abc import x
  119. >>> _linear_eq_to_dict([2*x + 3], {x})
  120. ([{x: 2}], [3])
  121. """
  122. coeffs = []
  123. ind = []
  124. symset = set(syms)
  125. for i, e in enumerate(eqs):
  126. if e.is_Equality:
  127. coeff, terms = _lin_eq2dict(e.lhs, symset)
  128. cR, tR = _lin_eq2dict(e.rhs, symset)
  129. # there were no nonlinear errors so now
  130. # cancellation is allowed
  131. coeff -= cR
  132. for k, v in tR.items():
  133. if k in terms:
  134. terms[k] -= v
  135. else:
  136. terms[k] = -v
  137. # don't store coefficients of 0, however
  138. terms = {k: v for k, v in terms.items() if v}
  139. c, d = coeff, terms
  140. else:
  141. c, d = _lin_eq2dict(e, symset)
  142. coeffs.append(d)
  143. ind.append(c)
  144. return coeffs, ind
  145. def _lin_eq2dict(a, symset):
  146. """return (c, d) where c is the sym-independent part of ``a`` and
  147. ``d`` is an efficiently calculated dictionary mapping symbols to
  148. their coefficients. A PolyNonlinearError is raised if non-linearity
  149. is detected.
  150. The values in the dictionary will be non-zero.
  151. Examples
  152. ========
  153. >>> from sympy.polys.matrices.linsolve import _lin_eq2dict
  154. >>> from sympy.abc import x, y
  155. >>> _lin_eq2dict(x + 2*y + 3, {x, y})
  156. (3, {x: 1, y: 2})
  157. """
  158. if a in symset:
  159. return S.Zero, {a: S.One}
  160. elif a.is_Add:
  161. terms_list = defaultdict(list)
  162. coeff_list = []
  163. for ai in a.args:
  164. ci, ti = _lin_eq2dict(ai, symset)
  165. coeff_list.append(ci)
  166. for mij, cij in ti.items():
  167. terms_list[mij].append(cij)
  168. coeff = Add(*coeff_list)
  169. terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()}
  170. return coeff, terms
  171. elif a.is_Mul:
  172. terms = terms_coeff = None
  173. coeff_list = []
  174. for ai in a.args:
  175. ci, ti = _lin_eq2dict(ai, symset)
  176. if not ti:
  177. coeff_list.append(ci)
  178. elif terms is None:
  179. terms = ti
  180. terms_coeff = ci
  181. else:
  182. # since ti is not null and we already have
  183. # a term, this is a cross term
  184. raise PolyNonlinearError(filldedent('''
  185. nonlinear cross-term: %s''' % a))
  186. coeff = Mul._from_args(coeff_list)
  187. if terms is None:
  188. return coeff, {}
  189. else:
  190. terms = {sym: coeff * c for sym, c in terms.items()}
  191. return coeff * terms_coeff, terms
  192. elif not a.has_xfree(symset):
  193. return a, {}
  194. else:
  195. raise PolyNonlinearError('nonlinear term: %s' % a)