reductions.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. from types import FunctionType
  2. from .utilities import _get_intermediate_simp, _iszero, _dotprodsimp, _simplify
  3. from .determinant import _find_reasonable_pivot
  4. def _row_reduce_list(mat, rows, cols, one, iszerofunc, simpfunc,
  5. normalize_last=True, normalize=True, zero_above=True):
  6. """Row reduce a flat list representation of a matrix and return a tuple
  7. (rref_matrix, pivot_cols, swaps) where ``rref_matrix`` is a flat list,
  8. ``pivot_cols`` are the pivot columns and ``swaps`` are any row swaps that
  9. were used in the process of row reduction.
  10. Parameters
  11. ==========
  12. mat : list
  13. list of matrix elements, must be ``rows`` * ``cols`` in length
  14. rows, cols : integer
  15. number of rows and columns in flat list representation
  16. one : SymPy object
  17. represents the value one, from ``Matrix.one``
  18. iszerofunc : determines if an entry can be used as a pivot
  19. simpfunc : used to simplify elements and test if they are
  20. zero if ``iszerofunc`` returns `None`
  21. normalize_last : indicates where all row reduction should
  22. happen in a fraction-free manner and then the rows are
  23. normalized (so that the pivots are 1), or whether
  24. rows should be normalized along the way (like the naive
  25. row reduction algorithm)
  26. normalize : whether pivot rows should be normalized so that
  27. the pivot value is 1
  28. zero_above : whether entries above the pivot should be zeroed.
  29. If ``zero_above=False``, an echelon matrix will be returned.
  30. """
  31. def get_col(i):
  32. return mat[i::cols]
  33. def row_swap(i, j):
  34. mat[i*cols:(i + 1)*cols], mat[j*cols:(j + 1)*cols] = \
  35. mat[j*cols:(j + 1)*cols], mat[i*cols:(i + 1)*cols]
  36. def cross_cancel(a, i, b, j):
  37. """Does the row op row[i] = a*row[i] - b*row[j]"""
  38. q = (j - i)*cols
  39. for p in range(i*cols, (i + 1)*cols):
  40. mat[p] = isimp(a*mat[p] - b*mat[p + q])
  41. isimp = _get_intermediate_simp(_dotprodsimp)
  42. piv_row, piv_col = 0, 0
  43. pivot_cols = []
  44. swaps = []
  45. # use a fraction free method to zero above and below each pivot
  46. while piv_col < cols and piv_row < rows:
  47. pivot_offset, pivot_val, \
  48. assumed_nonzero, newly_determined = _find_reasonable_pivot(
  49. get_col(piv_col)[piv_row:], iszerofunc, simpfunc)
  50. # _find_reasonable_pivot may have simplified some things
  51. # in the process. Let's not let them go to waste
  52. for (offset, val) in newly_determined:
  53. offset += piv_row
  54. mat[offset*cols + piv_col] = val
  55. if pivot_offset is None:
  56. piv_col += 1
  57. continue
  58. pivot_cols.append(piv_col)
  59. if pivot_offset != 0:
  60. row_swap(piv_row, pivot_offset + piv_row)
  61. swaps.append((piv_row, pivot_offset + piv_row))
  62. # if we aren't normalizing last, we normalize
  63. # before we zero the other rows
  64. if normalize_last is False:
  65. i, j = piv_row, piv_col
  66. mat[i*cols + j] = one
  67. for p in range(i*cols + j + 1, (i + 1)*cols):
  68. mat[p] = isimp(mat[p] / pivot_val)
  69. # after normalizing, the pivot value is 1
  70. pivot_val = one
  71. # zero above and below the pivot
  72. for row in range(rows):
  73. # don't zero our current row
  74. if row == piv_row:
  75. continue
  76. # don't zero above the pivot unless we're told.
  77. if zero_above is False and row < piv_row:
  78. continue
  79. # if we're already a zero, don't do anything
  80. val = mat[row*cols + piv_col]
  81. if iszerofunc(val):
  82. continue
  83. cross_cancel(pivot_val, row, val, piv_row)
  84. piv_row += 1
  85. # normalize each row
  86. if normalize_last is True and normalize is True:
  87. for piv_i, piv_j in enumerate(pivot_cols):
  88. pivot_val = mat[piv_i*cols + piv_j]
  89. mat[piv_i*cols + piv_j] = one
  90. for p in range(piv_i*cols + piv_j + 1, (piv_i + 1)*cols):
  91. mat[p] = isimp(mat[p] / pivot_val)
  92. return mat, tuple(pivot_cols), tuple(swaps)
# This function is a candidate for caching if it gets implemented for matrices.
  94. def _row_reduce(M, iszerofunc, simpfunc, normalize_last=True,
  95. normalize=True, zero_above=True):
  96. mat, pivot_cols, swaps = _row_reduce_list(list(M), M.rows, M.cols, M.one,
  97. iszerofunc, simpfunc, normalize_last=normalize_last,
  98. normalize=normalize, zero_above=zero_above)
  99. return M._new(M.rows, M.cols, mat), pivot_cols, swaps
  100. def _is_echelon(M, iszerofunc=_iszero):
  101. """Returns `True` if the matrix is in echelon form. That is, all rows of
  102. zeros are at the bottom, and below each leading non-zero in a row are
  103. exclusively zeros."""
  104. if M.rows <= 0 or M.cols <= 0:
  105. return True
  106. zeros_below = all(iszerofunc(t) for t in M[1:, 0])
  107. if iszerofunc(M[0, 0]):
  108. return zeros_below and _is_echelon(M[:, 1:], iszerofunc)
  109. return zeros_below and _is_echelon(M[1:, 1:], iszerofunc)
  110. def _echelon_form(M, iszerofunc=_iszero, simplify=False, with_pivots=False):
  111. """Returns a matrix row-equivalent to ``M`` that is in echelon form. Note
  112. that echelon form of a matrix is *not* unique, however, properties like the
  113. row space and the null space are preserved.
  114. Examples
  115. ========
  116. >>> from sympy import Matrix
  117. >>> M = Matrix([[1, 2], [3, 4]])
  118. >>> M.echelon_form()
  119. Matrix([
  120. [1, 2],
  121. [0, -2]])
  122. """
  123. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  124. mat, pivots, _ = _row_reduce(M, iszerofunc, simpfunc,
  125. normalize_last=True, normalize=False, zero_above=False)
  126. if with_pivots:
  127. return mat, pivots
  128. return mat
# This function is a candidate for caching if it gets implemented for matrices.
  130. def _rank(M, iszerofunc=_iszero, simplify=False):
  131. """Returns the rank of a matrix.
  132. Examples
  133. ========
  134. >>> from sympy import Matrix
  135. >>> from sympy.abc import x
  136. >>> m = Matrix([[1, 2], [x, 1 - 1/x]])
  137. >>> m.rank()
  138. 2
  139. >>> n = Matrix(3, 3, range(1, 10))
  140. >>> n.rank()
  141. 2
  142. """
  143. def _permute_complexity_right(M, iszerofunc):
  144. """Permute columns with complicated elements as
  145. far right as they can go. Since the ``sympy`` row reduction
  146. algorithms start on the left, having complexity right-shifted
  147. speeds things up.
  148. Returns a tuple (mat, perm) where perm is a permutation
  149. of the columns to perform to shift the complex columns right, and mat
  150. is the permuted matrix."""
  151. def complexity(i):
  152. # the complexity of a column will be judged by how many
  153. # element's zero-ness cannot be determined
  154. return sum(1 if iszerofunc(e) is None else 0 for e in M[:, i])
  155. complex = [(complexity(i), i) for i in range(M.cols)]
  156. perm = [j for (i, j) in sorted(complex)]
  157. return (M.permute(perm, orientation='cols'), perm)
  158. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  159. # for small matrices, we compute the rank explicitly
  160. # if is_zero on elements doesn't answer the question
  161. # for small matrices, we fall back to the full routine.
  162. if M.rows <= 0 or M.cols <= 0:
  163. return 0
  164. if M.rows <= 1 or M.cols <= 1:
  165. zeros = [iszerofunc(x) for x in M]
  166. if False in zeros:
  167. return 1
  168. if M.rows == 2 and M.cols == 2:
  169. zeros = [iszerofunc(x) for x in M]
  170. if False not in zeros and None not in zeros:
  171. return 0
  172. d = M.det()
  173. if iszerofunc(d) and False in zeros:
  174. return 1
  175. if iszerofunc(d) is False:
  176. return 2
  177. mat, _ = _permute_complexity_right(M, iszerofunc=iszerofunc)
  178. _, pivots, _ = _row_reduce(mat, iszerofunc, simpfunc, normalize_last=True,
  179. normalize=False, zero_above=False)
  180. return len(pivots)
  181. def _rref(M, iszerofunc=_iszero, simplify=False, pivots=True,
  182. normalize_last=True):
  183. """Return reduced row-echelon form of matrix and indices of pivot vars.
  184. Parameters
  185. ==========
  186. iszerofunc : Function
  187. A function used for detecting whether an element can
  188. act as a pivot. ``lambda x: x.is_zero`` is used by default.
  189. simplify : Function
  190. A function used to simplify elements when looking for a pivot.
  191. By default SymPy's ``simplify`` is used.
  192. pivots : True or False
  193. If ``True``, a tuple containing the row-reduced matrix and a tuple
  194. of pivot columns is returned. If ``False`` just the row-reduced
  195. matrix is returned.
  196. normalize_last : True or False
  197. If ``True``, no pivots are normalized to `1` until after all
  198. entries above and below each pivot are zeroed. This means the row
  199. reduction algorithm is fraction free until the very last step.
  200. If ``False``, the naive row reduction procedure is used where
  201. each pivot is normalized to be `1` before row operations are
  202. used to zero above and below the pivot.
  203. Examples
  204. ========
  205. >>> from sympy import Matrix
  206. >>> from sympy.abc import x
  207. >>> m = Matrix([[1, 2], [x, 1 - 1/x]])
  208. >>> m.rref()
  209. (Matrix([
  210. [1, 0],
  211. [0, 1]]), (0, 1))
  212. >>> rref_matrix, rref_pivots = m.rref()
  213. >>> rref_matrix
  214. Matrix([
  215. [1, 0],
  216. [0, 1]])
  217. >>> rref_pivots
  218. (0, 1)
  219. ``iszerofunc`` can correct rounding errors in matrices with float
  220. values. In the following example, calling ``rref()`` leads to
  221. floating point errors, incorrectly row reducing the matrix.
  222. ``iszerofunc= lambda x: abs(x)<1e-9`` sets sufficiently small numbers
  223. to zero, avoiding this error.
  224. >>> m = Matrix([[0.9, -0.1, -0.2, 0], [-0.8, 0.9, -0.4, 0], [-0.1, -0.8, 0.6, 0]])
  225. >>> m.rref()
  226. (Matrix([
  227. [1, 0, 0, 0],
  228. [0, 1, 0, 0],
  229. [0, 0, 1, 0]]), (0, 1, 2))
  230. >>> m.rref(iszerofunc=lambda x:abs(x)<1e-9)
  231. (Matrix([
  232. [1, 0, -0.301369863013699, 0],
  233. [0, 1, -0.712328767123288, 0],
  234. [0, 0, 0, 0]]), (0, 1))
  235. Notes
  236. =====
  237. The default value of ``normalize_last=True`` can provide significant
  238. speedup to row reduction, especially on matrices with symbols. However,
  239. if you depend on the form row reduction algorithm leaves entries
  240. of the matrix, set ``noramlize_last=False``
  241. """
  242. simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify
  243. mat, pivot_cols, _ = _row_reduce(M, iszerofunc, simpfunc,
  244. normalize_last, normalize=True, zero_above=True)
  245. if pivots:
  246. mat = (mat, pivot_cols)
  247. return mat