123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306 |
- from .basic import Basic
- from .sorting import ordered
- from .sympify import sympify
- from sympy.utilities.iterables import iterable
def iterargs(expr):
    """Generate ``expr`` followed by all of its nested args, breadth-first.

    A node contributes children only when its ``args`` attribute can be
    iterated: a node whose ``args`` is empty or is not an iterable is still
    yielded, but the traversal does not descend into it.

    Examples
    ========

    >>> from sympy import Integral, Function
    >>> from sympy.abc import x
    >>> f = Function('f')
    >>> from sympy.core.traversal import iterargs
    >>> list(iterargs(Integral(f(x), (f(x), 1))))
    [Integral(f(x), (f(x), 1)), f(x), (f(x), 1), x, f(x), 1, x]

    See Also
    ========
    iterfreeargs, preorder_traversal
    """
    # The list doubles as a FIFO queue: it keeps growing at the tail while
    # ``position`` advances from the head.
    queue = [expr]
    position = 0
    while position < len(queue):
        current = queue[position]
        position += 1
        yield current
        try:
            queue.extend(current.args)
        except TypeError:
            # ``args`` could not be iterated (for cases like ``f`` being an
            # arg), so the node is treated as a leaf.
            continue
def iterfreeargs(expr, _first=True):
    """Yield the args of a Basic object in a breadth-first traversal.

    Depth-traversal stops if `arg.args` is either empty or is not
    an iterable.

    An object that has a ``bound_symbols`` attribute is yielded itself, but
    its own args are not traversed. Instead, its ``as_dummy()`` form is
    walked and only those sub-objects that do not contain any of the
    object's canonical variables are yielded, so nothing that depends on a
    bound symbol is reported.

    Parameters
    ==========

    expr : object
        Root of the traversal.
    _first : bool
        Internal flag; it is False while the ``as_dummy()`` form of a
        bound object is being walked so that bound objects nested inside
        it get no special handling.

    Examples
    ========

    >>> from sympy import Integral, Function
    >>> from sympy.abc import x
    >>> f = Function('f')
    >>> from sympy.core.traversal import iterfreeargs
    >>> list(iterfreeargs(Integral(f(x), (f(x), 1))))
    [Integral(f(x), (f(x), 1)), 1]

    See Also
    ========
    iterargs, preorder_traversal
    """
    queue = [expr]  # extended while it is being iterated (FIFO order)
    for node in queue:
        yield node
        if _first and hasattr(node, 'bound_symbols'):
            void = node.canonical_variables.values()
            for sub in iterfreeargs(node.as_dummy(), _first=False):
                if not sub.has(*void):
                    yield sub
            # The free parts have just been reported from the dummy form,
            # so the bound object's own args must not be queued. This used
            # to happen only implicitly, through the inner loop rebinding
            # the outer loop variable to its last (leaf) item.
            continue
        try:
            queue.extend(node.args)
        except TypeError:
            pass  # for cases like f being an arg
class preorder_traversal:
    """
    Iterate over the nodes of a tree in pre-order.

    Every node is yielded before any of its children; each child's subtree
    is then traversed completely (depth-first) before the next child is
    started.

    For an expression, the order of the traversal depends on the order of
    .args, which in many cases can be arbitrary.

    Parameters
    ==========

    node : SymPy expression
        The expression to traverse.
    keys : (default None) sort key(s)
        The key(s) used to sort args of Basic objects. When None, args of
        Basic objects are processed in arbitrary order. If ``keys`` is
        given, it is passed along to ordered() as the only key(s) to use
        to sort the arguments; if ``keys`` is simply True then the default
        keys of ordered are used.

    Yields
    ======

    subtree : SymPy expression
        All of the subtrees in the tree.

    Examples
    ========

    >>> from sympy import preorder_traversal, symbols
    >>> x, y, z = symbols('x y z')

    The nodes are returned in the order that they are encountered unless
    keys is given; simply passing keys=True will guarantee that the
    traversal is unique.

    >>> list(preorder_traversal((x + y)*z, keys=None)) # doctest: +SKIP
    [z*(x + y), z, x + y, y, x]
    >>> list(preorder_traversal((x + y)*z, keys=True))
    [z*(x + y), z, x + y, x, y]

    """

    def __init__(self, node, keys=None):
        # Set by skip(); consumed by the generator right after a yield.
        self._skip_flag = False
        self._pt = self._preorder_traversal(node, keys)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._pt)

    def skip(self):
        """
        Skip yielding current node's (last yielded node's) subtrees.

        Examples
        ========

        >>> from sympy import preorder_traversal, symbols
        >>> x, y, z = symbols('x y z')
        >>> pt = preorder_traversal((x + y*z)*z)
        >>> for i in pt:
        ...     print(i)
        ...     if i == x + y*z:
        ...         pt.skip()
        z*(x + y*z)
        z
        x + y*z
        """
        self._skip_flag = True

    def _preorder_traversal(self, node, keys):
        """Generator behind the iterator: yield ``node``, then its subtrees."""
        yield node
        if self._skip_flag:
            # skip() was called while ``node`` was the current item: clear
            # the request and do not descend.
            self._skip_flag = False
            return

        if isinstance(node, Basic):
            if keys:
                children = node.args
                if keys != True:
                    children = ordered(children, keys, default=False)
                else:
                    # keys == True selects the default keys of ordered()
                    children = ordered(children)
            elif hasattr(node, '_argset'):
                # LatticeOp keeps args as a set. We should use this if we
                # don't care about the order, to prevent unnecessary sorting.
                children = node._argset
            else:
                children = node.args
        elif iterable(node):
            # non-Basic containers are traversed item by item
            children = node
        else:
            return

        for child in children:
            yield from self._preorder_traversal(child, keys)
def use(expr, func, level=0, args=(), kwargs=None):
    """
    Use ``func`` to transform ``expr`` at the given level.

    Parameters
    ==========

    expr : object
        The expression to transform; it is passed through ``sympify``
        first.
    func : callable
        Called as ``func(subexpr, *args, **kwargs)`` on every
        subexpression found ``level`` steps below ``expr``.
    level : int
        How many levels of ``.args`` to descend before applying ``func``;
        0 applies ``func`` to ``expr`` itself.
    args : tuple
        Extra positional arguments for ``func``.
    kwargs : dict, optional
        Extra keyword arguments for ``func``; ``None`` means none.

    Returns
    =======

    The result of ``func`` when ``level`` is 0; otherwise an object built
    with ``expr.__class__`` from the transformed args. Atoms that are
    reached before ``level`` is exhausted are returned unchanged.

    Examples
    ========

    >>> from sympy import use, expand
    >>> from sympy.abc import x, y

    >>> f = (x + y)**2*x + 1

    >>> use(f, expand, level=2)
    x*(x**2 + 2*x*y + y**2) + 1
    >>> expand(f)
    x**3 + 2*x**2*y + x*y**2 + 1

    """
    # A fresh dict per call replaces the former shared mutable default.
    if kwargs is None:
        kwargs = {}

    def _use(expr, level):
        if not level:
            return func(expr, *args, **kwargs)
        if expr.is_Atom:
            return expr
        rebuilt = [_use(arg, level - 1) for arg in expr.args]
        return expr.__class__(*rebuilt)

    return _use(sympify(expr), level)
def walk(e, *target):
    """Generate ``e`` and the nested args that are instances of ``target``.

    The traversal is pre-order and only descends through objects of the
    given types: an object that is not an instance of any type in
    ``target`` is neither yielded nor searched for further matches.

    Examples
    ========

    >>> from sympy.core.traversal import walk
    >>> from sympy import Min, Max
    >>> from sympy.abc import x, y, z
    >>> list(walk(Min(x, Max(y, Min(1, z))), Min))
    [Min(x, Max(y, Min(1, z)))]
    >>> list(walk(Min(x, Max(y, Min(1, z))), Min, Max))
    [Min(x, Max(y, Min(1, z))), Max(y, Min(1, z)), Min(1, z)]

    See Also
    ========

    bottom_up
    """
    if not isinstance(e, target):
        # wrong type: prune this whole subtree
        return
    yield e
    for child in e.args:
        yield from walk(child, *target)
def bottom_up(rv, F, atoms=False, nonbasic=False):
    """Rebuild ``rv`` from its leaves upward, applying ``F`` on the way.

    Objects with non-empty ``args`` have their args processed first; the
    object is recreated with ``rv.func`` only if some arg changed, and
    ``F`` is then applied to it. Objects with empty ``args`` are passed to
    ``F`` only when ``atoms`` is True. Objects without an ``args``
    attribute are passed to ``F`` only when ``nonbasic`` is True, and are
    left as they are if ``F`` raises ``TypeError``.
    """
    children = getattr(rv, 'args', None)

    if children is None:
        # not a Basic-like object
        if not nonbasic:
            return rv
        try:
            return F(rv)
        except TypeError:
            return rv

    if not children:
        # an atom: only touched on request
        return F(rv) if atoms else rv

    rebuilt = tuple([bottom_up(c, F, atoms, nonbasic) for c in children])
    if rebuilt != rv.args:
        rv = rv.func(*rebuilt)
    return F(rv)
def postorder_traversal(node, keys=None):
    """
    Generate the nodes of a tree in post-order.

    The traversal goes depth-first: the complete post-order traversal of
    each child of a node is yielded before the node itself.

    Parameters
    ==========

    node : SymPy expression
        The expression to traverse.
    keys : (default None) sort key(s)
        The key(s) used to sort args of Basic objects. When None, args of
        Basic objects are processed in arbitrary order. If ``keys`` is
        given, it is passed along to ordered() as the only key(s) to use
        to sort the arguments; if ``keys`` is simply True then the default
        keys of ``ordered`` will be used (node count and
        default_sort_key).

    Yields
    ======
    subtree : SymPy expression
        All of the subtrees in the tree.

    Examples
    ========

    >>> from sympy import postorder_traversal
    >>> from sympy.abc import w, x, y, z

    The nodes are returned in the order that they are encountered unless
    keys is given; simply passing keys=True will guarantee that the
    traversal is unique.

    >>> list(postorder_traversal(w + (x + y)*z)) # doctest: +SKIP
    [z, y, x, x + y, z*(x + y), w, w + z*(x + y)]
    >>> list(postorder_traversal(w + (x + y)*z, keys=True))
    [w, z, x, y, x + y, z*(x + y), w + z*(x + y)]


    """
    if isinstance(node, Basic):
        children = node.args
        if keys:
            # keys == True selects the default keys of ordered()
            children = (ordered(children, keys, default=False)
                        if keys != True else ordered(children))
    elif iterable(node):
        # non-Basic containers are traversed item by item
        children = node
    else:
        children = ()

    for child in children:
        yield from postorder_traversal(child, keys)
    yield node
|