traversal.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from .basic import Basic
  2. from .sorting import ordered
  3. from .sympify import sympify
  4. from sympy.utilities.iterables import iterable
  def iterargs(expr):
      """Yield the args of a Basic object in a breadth-first traversal.

      Descent below a node stops when its ``.args`` is empty or is not
      iterable. Shared subexpressions are yielded once per occurrence.

      Examples
      ========

      >>> from sympy import Integral, Function
      >>> from sympy.abc import x
      >>> f = Function('f')
      >>> from sympy.core.traversal import iterargs
      >>> list(iterargs(Integral(f(x), (f(x), 1))))
      [Integral(f(x), (f(x), 1)), f(x), (f(x), 1), x, f(x), 1, x]

      See Also
      ========

      iterfreeargs, preorder_traversal
      """
      # ``queue`` grows as children are discovered; ``pos`` walks it in
      # insertion order, which is what makes the traversal breadth-first.
      queue = [expr]
      pos = 0
      while pos < len(queue):
          node = queue[pos]
          pos += 1
          yield node
          try:
              queue.extend(node.args)
          except TypeError:
              # ``.args`` is not iterable; treat the node as a leaf
              # (for cases like ``f`` being an arg).
              pass
  def iterfreeargs(expr, _first=True):
      """Yield the args of a Basic object in a breadth-first traversal.

      Descent below a node stops when its ``.args`` is empty or is not
      iterable. An object that binds symbols (one with a ``bound_symbols``
      attribute) is yielded itself, but its original args are not descended
      into. Instead, its ``as_dummy()`` form is traversed, and only the
      subexpressions that contain none of the canonical bound variables
      are yielded.

      Parameters
      ==========

      expr : Basic
          Root of the traversal; it is always yielded first.
      _first : bool, private
          Internal flag. The recursive call on a dummified expression passes
          False so that bound objects found there are not expanded again.

      Examples
      ========

      >>> from sympy import Integral, Function
      >>> from sympy.abc import x
      >>> f = Function('f')
      >>> from sympy.core.traversal import iterfreeargs
      >>> list(iterfreeargs(Integral(f(x), (f(x), 1))))
      [Integral(f(x), (f(x), 1)), 1]

      See Also
      ========

      iterargs, preorder_traversal
      """
      queue = [expr]
      for node in queue:
          yield node
          if _first and hasattr(node, 'bound_symbols'):
              # Unpack the canonical bound variables once, not per sub-node.
              void = tuple(node.canonical_variables.values())
              for sub in iterfreeargs(node.as_dummy(), _first=False):
                  if not sub.has(*void):
                      yield sub
              # The dummified traversal above already produced this node's
              # free content; do not also descend into its original args.
              # (The previous version reached the same result only because a
              # rebound loop variable happened to point at a leaf here.)
              continue
          try:
              queue.extend(node.args)
          except TypeError:
              pass  # for cases like f being an arg
  class preorder_traversal:
      """
      Do a pre-order traversal of a tree.

      This iterator recursively yields nodes that it has visited in a
      pre-order fashion. That is, it yields the current node and then
      descends depth-first into each of the node's children in turn,
      yielding a child's complete pre-order traversal before moving on to
      the next child.

      Args of Basic objects are traversed; other objects that SymPy's
      ``iterable`` accepts are traversed item by item; anything else is a
      leaf. Calling :meth:`skip` prevents descent below the node most
      recently yielded.

      For an expression, the order of the traversal depends on the order of
      .args, which in many cases can be arbitrary.

      Parameters
      ==========

      node : SymPy expression
          The expression to traverse.
      keys : (default None) sort key(s)
          The key(s) used to sort args of Basic objects. When None, args of
          Basic objects are processed in arbitrary order. If ``keys`` is
          defined, it will be passed along to ordered() as the only key(s) to
          use to sort the arguments; if ``keys`` is simply True then the
          default keys of ordered will be used.

      Yields
      ======

      subtree : SymPy expression
          All of the subtrees in the tree.

      Examples
      ========

      >>> from sympy import preorder_traversal, symbols
      >>> x, y, z = symbols('x y z')

      The nodes are returned in the order that they are encountered unless
      ``keys`` is given; simply passing keys=True will guarantee that the
      traversal is unique.

      >>> list(preorder_traversal((x + y)*z, keys=None)) # doctest: +SKIP
      [z*(x + y), z, x + y, y, x]
      >>> list(preorder_traversal((x + y)*z, keys=True))
      [z*(x + y), z, x + y, x, y]

      """

      def __init__(self, node, keys=None):
          # Set by skip(); checked and cleared by the generator frame that
          # resumes right after yielding the node whose subtree is skipped.
          self._skip_flag = False
          self._pt = self._preorder_traversal(node, keys)

      def _preorder_traversal(self, node, keys):
          """Recursive generator behind the public iterator."""
          yield node
          if self._skip_flag:
              self._skip_flag = False
              return
          if isinstance(node, Basic):
              if not keys and hasattr(node, '_argset'):
                  # LatticeOp keeps args as a set. We should use this if we
                  # don't care about the order, to prevent unnecessary sorting.
                  args = node._argset
              else:
                  args = node.args
              if keys:
                  # ``!=`` (not ``is not``): any value equal to True selects
                  # ordered()'s default keys.
                  if keys != True:
                      args = ordered(args, keys, default=False)
                  else:
                      args = ordered(args)
              for arg in args:
                  yield from self._preorder_traversal(arg, keys)
          elif iterable(node):
              for item in node:
                  yield from self._preorder_traversal(item, keys)

      def skip(self):
          """
          Skip yielding current node's (last yielded node's) subtrees.

          Examples
          ========

          >>> from sympy import preorder_traversal, symbols
          >>> x, y, z = symbols('x y z')
          >>> pt = preorder_traversal((x + y*z)*z)
          >>> for i in pt:
          ...     print(i)
          ...     if i == x + y*z:
          ...             pt.skip()
          z*(x + y*z)
          z
          x + y*z
          """
          self._skip_flag = True

      def __next__(self):
          return next(self._pt)

      def __iter__(self):
          return self
  def use(expr, func, level=0, args=(), kwargs=None):
      """
      Use ``func`` to transform ``expr`` at the given level.

      Parameters
      ==========

      expr : sympifiable
          The expression to transform; it is passed through ``sympify`` first.
      func : callable
          Called as ``func(subexpr, *args, **kwargs)`` on every subexpression
          found exactly ``level`` levels below ``expr``.
      level : int, optional
          Depth at which ``func`` is applied; 0 (the default) applies it to
          ``expr`` itself. Atoms reached before that depth are returned
          unchanged. A negative level never reaches 0, so ``func`` is never
          applied.
      args : tuple, optional
          Extra positional arguments forwarded to ``func``.
      kwargs : dict or None, optional
          Extra keyword arguments forwarded to ``func``. None (the default)
          means no extra keyword arguments.

      Examples
      ========

      >>> from sympy import use, expand
      >>> from sympy.abc import x, y

      >>> f = (x + y)**2*x + 1

      >>> use(f, expand, level=2)
      x*(x**2 + 2*x*y + y**2) + 1
      >>> expand(f)
      x**3 + 2*x**2*y + x*y**2 + 1

      """
      # Avoid a shared mutable default dict.
      if kwargs is None:
          kwargs = {}

      def _use(expr, level):
          if not level:
              return func(expr, *args, **kwargs)
          if expr.is_Atom:
              return expr
          new_args = [_use(arg, level - 1) for arg in expr.args]
          # Rebuild with ``.func``, per SymPy's invariant
          # ``e.func(*e.args) == e`` (as bottom_up does).
          return expr.func(*new_args)

      return _use(sympify(expr), level)
  def walk(e, *target):
      """Yield, in pre-order, ``e`` and its nested args that are instances
      of the given types (``target``).

      Only instances of the target types are yielded and descended into;
      an argument that is not an instance of one of them is neither yielded
      nor traversed, so nothing beneath it is visited. If ``e`` itself is
      not a target instance, nothing is yielded.

      Examples
      ========

      >>> from sympy.core.traversal import walk
      >>> from sympy import Min, Max
      >>> from sympy.abc import x, y, z
      >>> list(walk(Min(x, Max(y, Min(1, z))), Min))
      [Min(x, Max(y, Min(1, z)))]
      >>> list(walk(Min(x, Max(y, Min(1, z))), Min, Max))
      [Min(x, Max(y, Min(1, z))), Max(y, Min(1, z)), Min(1, z)]

      See Also
      ========

      bottom_up
      """
      if not isinstance(e, target):
          return
      yield e
      for sub in e.args:
          yield from walk(sub, *target)
  def bottom_up(rv, F, atoms=False, nonbasic=False):
      """Apply ``F`` to all expressions in an expression tree from the
      bottom up.

      Children are transformed before their parent. A node is rebuilt with
      ``rv.func(*new_args)`` only when the transformed children differ from
      its original args; ``F`` is then applied to the (possibly rebuilt)
      node.

      Parameters
      ==========

      rv : object
          Root of the tree. Objects with an ``args`` attribute are treated as
          nodes; anything else is a non-Basic leaf.
      F : callable
          Transformation whose return value replaces each node it visits.
      atoms : bool, optional
          If True, apply ``F`` even to nodes whose ``args`` is empty.
      nonbasic : bool, optional
          If True, try to apply ``F`` to objects without an ``args``
          attribute; a ``TypeError`` from ``F`` leaves the object unchanged.
      """
      kids = getattr(rv, 'args', None)
      if kids is None:
          # Non-Basic object: optionally try F, tolerating TypeError.
          if nonbasic:
              try:
                  rv = F(rv)
              except TypeError:
                  pass
          return rv
      if not kids:
          return F(rv) if atoms else rv
      new_kids = tuple(bottom_up(kid, F, atoms, nonbasic) for kid in kids)
      if new_kids != rv.args:
          rv = rv.func(*new_kids)
      return F(rv)
  def postorder_traversal(node, keys=None):
      """
      Do a postorder traversal of a tree.

      This generator recursively yields nodes that it has visited in a
      postorder fashion. That is, it descends through the tree depth-first to
      yield all of a node's children's postorder traversal before yielding
      the node itself. Args of Basic objects are traversed; other objects
      that SymPy's ``iterable`` accepts are traversed item by item; anything
      else is a leaf.

      Parameters
      ==========

      node : SymPy expression
          The expression to traverse.
      keys : (default None) sort key(s)
          The key(s) used to sort args of Basic objects. When None, args of
          Basic objects are processed in arbitrary order. If ``keys`` is
          defined, it will be passed along to ordered() as the only key(s) to
          use to sort the arguments; if ``keys`` is simply True then the
          default keys of ``ordered`` will be used (node count and
          default_sort_key).

      Yields
      ======

      subtree : SymPy expression
          All of the subtrees in the tree.

      Examples
      ========

      >>> from sympy import postorder_traversal
      >>> from sympy.abc import w, x, y, z

      The nodes are returned in the order that they are encountered unless
      ``keys`` is given; simply passing keys=True will guarantee that the
      traversal is unique.

      >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
      [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
      >>> list(postorder_traversal(w + (x + y)*z, keys=True))
      [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]

      """
      # Decide which children (if any) to visit before the node itself.
      if isinstance(node, Basic):
          children = node.args
          if keys:
              children = (ordered(children, keys, default=False)
                          if keys != True else ordered(children))
      elif iterable(node):
          children = node
      else:
          children = ()
      for child in children:
          yield from postorder_traversal(child, keys)
      yield node