123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- """Module with functions operating on IndexedBase, Indexed and Idx objects
- - Check shape conformance
- - Determine indices in resulting expression
- etc.
- Methods in this module could be implemented by calling methods on Expr
- objects instead. When things stabilize this could be a useful
- refactoring.
- """
- from functools import reduce
- from sympy.core.function import Function
- from sympy.functions import exp, Piecewise
- from sympy.tensor.indexed import Idx, Indexed
- from sympy.utilities import sift
- from collections import OrderedDict
class IndexConformanceException(Exception):
    """Signal that the terms of an indexed expression do not conform.

    Raised by ``_get_indices_Add`` (and therefore by ``get_indices``) when
    the non-scalar terms of a sum do not all carry the same outer indices.
    """
- def _unique_and_repeated(inds):
- """
- Returns the unique and repeated indices. Also note, from the examples given below
- that the order of indices is maintained as given in the input.
- Examples
- ========
- >>> from sympy.tensor.index_methods import _unique_and_repeated
- >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0])
- ([2, 1, 4], [3, 0])
- """
- uniq = OrderedDict()
- for i in inds:
- if i in uniq:
- uniq[i] = 0
- else:
- uniq[i] = 1
- return sift(uniq, lambda x: uniq[x], binary=True)
def _remove_repeated(inds):
    """
    Separate the repeated objects of a sequence from the unique ones.

    Returns a set of the objects that occur exactly once and a tuple of the
    objects that occur more than once (each listed a single time, in order
    of first appearance).

    Examples
    ========

    >>> from sympy.tensor.index_methods import _remove_repeated
    >>> l1 = [1, 2, 3, 2]
    >>> _remove_repeated(l1)
    ({1, 3}, (2,))
    """
    unique_part, repeated_part = _unique_and_repeated(inds)
    return (set(unique_part), tuple(repeated_part))
def _get_indices_Mul(expr, return_dummies=False):
    """Determine the outer indices of a Mul object.

    The outer indices of every factor are gathered.  An index that shows up
    in more than one factor is contracted (summed over), so it is not an
    outer index of the product.  The symmetry dicts of the factors are
    merged, multiplying the values of keys reported by several factors.

    When ``return_dummies`` is true, the contracted indices are returned as
    a third element ``(outer, symmetry, dummies)``; otherwise the result is
    ``(outer, symmetry)``.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Mul
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Mul(x[i, k]*y[j, k])
    ({i, j}, {})
    >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True)
    ({i, j}, {}, (k,))
    """
    factor_results = [get_indices(arg) for arg in expr.args]
    factor_inds, factor_syms = zip(*factor_results)

    # Chain the per-factor index sets; an index met in two or more factors
    # is a dummy and drops out of the outer set.
    combined = [idx for inds in factor_inds for idx in inds]
    outer, dummies = _remove_repeated(combined)

    symmetry = {}
    for sym in factor_syms:
        for pair, value in sym.items():
            if pair in symmetry:
                symmetry[pair] *= value
            else:
                symmetry[pair] = value

    if not return_dummies:
        return outer, symmetry
    return outer, symmetry, dummies
def _get_indices_Pow(expr):
    """Determine outer indices of a power or an exponential.

    A power is treated as a universal function, so the indices of a Pow are
    simply all indices present in the expression.  This may be viewed as a
    bit inconsistent in the special case::

        x[i]**2 = x[i]*x[i]                                          (1)

    The expression above could have been read as the contraction of x[i]
    with itself, but it is instead interpreted as the function

        lambda y: y**2

    applied to each element of x (a universal function in numpy terms).  To
    allow interpreting (1) as a contraction, contravariant and covariant Idx
    subclasses would be needed.  (FIXME: this is not yet implemented)

    Indices in the base or in the exponent are contracted as usual within
    each of them, but an index present in the exponent is never contracted
    with its own base.  Indices within the same exponent can still be
    contracted with each other.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Pow
    >>> from sympy import Pow, exp, IndexedBase, Idx
    >>> A = IndexedBase('A')
    >>> x = IndexedBase('x')
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> _get_indices_Pow(exp(A[i, j]*x[j]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(x[i], x[i]))
    ({i}, {})
    >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i]))
    ({i}, {})
    """
    # ``exponent`` rather than ``exp`` so the module-level ``exp`` function
    # is not shadowed inside this function.
    base, exponent = expr.as_base_exp()
    base_inds, _ = get_indices(base)
    exponent_inds, _ = get_indices(exponent)
    # The union (not a contraction) implements "no contraction between the
    # base and its exponent".
    # FIXME: symmetries from power needs to check special cases, else nothing
    return base_inds | exponent_inds, {}
def _get_indices_Add(expr):
    """Determine outer indices of an Add object.

    In a sum, each term must have the same set of outer indices.  A valid
    expression could be

        x(i)*y(j) - x(j)*y(i)

    But we do not allow expressions like:

        x(i)*y(j) - z(j)*z(j)

    Terms without any outer index (scalars) are broadcast, i.e. they are
    left out of the conformance check.

    FIXME: Add support for Numpy broadcasting

    Returns ``(outer_indices, symmetries)``.  The symmetries are those of
    the first term when every term reports identical symmetries, else ``{}``.

    Raises
    ======

    IndexConformanceException
        If the non-scalar terms do not all have the same outer indices.

    Examples
    ========

    >>> from sympy.tensor.index_methods import _get_indices_Add
    >>> from sympy.tensor.indexed import IndexedBase, Idx
    >>> i, j, k = map(Idx, ['i', 'j', 'k'])
    >>> x = IndexedBase('x')
    >>> y = IndexedBase('y')
    >>> _get_indices_Add(x[i] + x[k]*y[i, k])
    ({i}, {})
    """
    inds, syms = zip(*[get_indices(term) for term in expr.args])

    # allow broadcast of scalars
    non_scalars = [x for x in inds if x != set()]
    if not non_scalars:
        return set(), {}

    if not all(x == non_scalars[0] for x in non_scalars[1:]):
        raise IndexConformanceException("Indices are not consistent: %s" % expr)

    # Symmetries carry over to the sum only when every term agrees on them.
    # An explicit pairwise comparison is used because a truthiness-based
    # reduction would drop symmetries that all terms share.
    if all(s == syms[0] for s in syms[1:]):
        symmetries = syms[0]
    else:
        # FIXME: search for symmetries
        symmetries = {}
    return non_scalars[0], symmetries
def get_indices(expr):
    """Determine the outer indices of expression ``expr``

    By *outer* we mean indices that are not summation indices.  Returns a set
    and a dict.  The set contains outer indices and the dict contains
    information about index symmetries.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_indices
    >>> from sympy import symbols
    >>> from sympy.tensor import IndexedBase
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, a, z = symbols('i j a z', integer=True)

    The indices of the total expression are determined.  Repeated indices
    imply a summation, for instance the trace of a matrix A:

    >>> get_indices(A[i, i])
    (set(), {})

    In the case of many terms, the terms are required to have identical
    outer indices.  Else an IndexConformanceException is raised.

    >>> get_indices(x[i] + A[i, j]*y[j])
    ({i}, {})

    :Exceptions:

    An IndexConformanceException means that the terms are not compatible, e.g.

    >>> get_indices(x[i] + y[j])                #doctest: +SKIP
            (...)
    IndexConformanceException: Indices are not consistent: x[i] + y[j]

    .. warning::
       The concept of *outer* indices applies recursively, starting on the
       deepest level.  This implies that dummies inside parenthesis are
       assumed to be summed first, so that the following expression is
       handled gracefully:

       >>> get_indices((x[i] + A[i, j]*y[j])*x[j])
       ({i, j}, {})

       This is correct and may appear convenient, but you need to be careful
       with this as SymPy will happily .expand() the product, if requested.
       The resulting expression would mix the outer ``j`` with the dummies
       inside the parenthesis, which makes it a different expression.  To be
       on the safe side, it is best to avoid such ambiguities by using unique
       indices for all contractions that should be held separate.

    Raises NotImplementedError for expression types that contain Indexed
    objects but have no specialized handling.
    """
    # We call ourself recursively to determine indices of sub expressions.

    # break recursion
    if isinstance(expr, Indexed):
        # Indices repeated inside one Indexed (e.g. A[i, i]) are contracted,
        # so only the non-repeated ones are outer indices.
        inds, _ = _remove_repeated(expr.indices)
        return inds, {}
    elif expr is None:
        return set(), {}
    elif isinstance(expr, Idx):
        return {expr}, {}
    elif expr.is_Atom:
        return set(), {}

    # recurse via specialized functions
    else:
        if expr.is_Mul:
            return _get_indices_Mul(expr)
        elif expr.is_Add:
            return _get_indices_Add(expr)
        elif expr.is_Pow or isinstance(expr, exp):
            return _get_indices_Pow(expr)

        elif isinstance(expr, Piecewise):
            # FIXME: No support for Piecewise yet
            return set(), {}
        elif isinstance(expr, Function):
            # Support ufunc like behaviour by returning indices from arguments.
            # Functions do not interpret repeated indices across arguments
            # as summation
            ind0 = set()
            # Initialised up front: a function applied to no arguments
            # (e.g. ``f()``) never enters the loop and would otherwise
            # reference an unbound local.
            sym = {}
            for arg in expr.args:
                ind, sym = get_indices(arg)
                ind0 |= ind
            # NOTE(review): only the last argument's symmetries are kept.
            return ind0, sym

        # this test is expensive, so it should be at the end
        elif not expr.has(Indexed):
            return set(), {}
        raise NotImplementedError(
            "FIXME: No specialized handling of type %s" % type(expr))
def get_contraction_structure(expr):
    """Determine dummy indices of ``expr`` and describe its structure

    By *dummy* we mean indices that are summation indices.

    The structure of the expression is determined and described as follows:

    1) A conforming summation of Indexed objects is described with a dict where
       the keys are summation indices and the corresponding values are sets
       containing all terms for which the summation applies.  All Add objects
       in the SymPy expression tree are described like this.

    2) For all nodes in the SymPy expression tree that are *not* of type Add, the
       following applies:

       If a node discovers contractions in one of its arguments, the node
       itself will be stored as a key in the dict.  For that key, the
       corresponding value is a list of dicts, each of which is the result of a
       recursive call to get_contraction_structure().  The list contains only
       dicts for the non-trivial deeper contractions, omitting dicts with None
       as the one and only key.

    .. Note:: The presence of expressions among the dictionary keys indicates
       multiple levels of index contractions.  A nested dict displays nested
       contractions and may itself contain dicts from a deeper level.  In
       practical calculations the summation in the deepest nested level must be
       calculated first so that the outer expression can access the resulting
       indexed object.

    Examples
    ========

    >>> from sympy.tensor.index_methods import get_contraction_structure
    >>> from sympy import default_sort_key
    >>> from sympy.tensor import IndexedBase, Idx
    >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
    >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l'])
    >>> get_contraction_structure(x[i]*y[i] + A[j, j])
    {(i,): {x[i]*y[i]}, (j,): {A[j, j]}}
    >>> get_contraction_structure(x[i]*y[j])
    {None: {x[i]*y[j]}}

    A multiplication of contracted factors results in nested dicts representing
    the internal contractions.

    >>> d = get_contraction_structure(x[i, i]*y[j, j])
    >>> sorted(d.keys(), key=default_sort_key)
    [None, x[i, i]*y[j, j]]

    In this case, the product has no contractions:

    >>> d[None]
    {x[i, i]*y[j, j]}

    Factors are contracted "first":

    >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key)
    [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}]

    A parenthesized Add object is also returned as a nested dictionary.  The
    term containing the parenthesis is a Mul with a contraction among the
    arguments, so it will be found as a key in the result.  It stores the
    dictionary resulting from a recursive call on the Add expression.

    >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j]))
    >>> sorted(d.keys(), key=default_sort_key)
    [(A[i, j]*x[j] + y[i])*x[i], (i,)]
    >>> d[(i,)]
    {(A[i, j]*x[j] + y[i])*x[i]}
    >>> d[x[i]*(A[i, j]*x[j] + y[i])]
    [{None: {y[i]}, (j,): {A[i, j]*x[j]}}]

    Powers with contractions in either base or exponent will also be found as
    keys in the dictionary, mapping to a list of results from recursive calls:

    >>> d = get_contraction_structure(A[j, j]**A[i, i])
    >>> d[None]
    {A[j, j]**A[i, i]}
    >>> nested_contractions = d[A[j, j]**A[i, i]]
    >>> nested_contractions[0]
    {(j,): {A[j, j]}}
    >>> nested_contractions[1]
    {(i,): {A[i, i]}}

    The description of the contraction structure may appear complicated when
    represented with a string in the above examples, but it is easy to iterate
    over:

    >>> from sympy import Expr
    >>> for key in d:
    ...     if isinstance(key, Expr):
    ...         continue
    ...     for term in d[key]:
    ...         if term in d:
    ...             # treat deepest contraction first
    ...             pass
    ...     # treat outermost contractions here

    Raises NotImplementedError for expression types that contain Indexed
    objects but have no specialized handling.
    """
    # We call ourself recursively to inspect sub expressions.

    if isinstance(expr, Indexed):
        _, key = _remove_repeated(expr.indices)
        return {key or None: {expr}}
    elif expr.is_Atom:
        return {None: {expr}}
    elif expr.is_Mul:
        _, _, key = _get_indices_Mul(expr, return_dummies=True)
        result = {key or None: {expr}}
        # recurse on every factor
        nested = []
        for fac in expr.args:
            facd = get_contraction_structure(fac)
            if not (None in facd and len(facd) == 1):
                nested.append(facd)
        if nested:
            result[expr] = nested
        return result
    elif expr.is_Pow or isinstance(expr, exp):
        # recurse in base and exp separately.  If either has internal
        # contractions we must include ourselves as a key in the returned dict
        b, e = expr.as_base_exp()
        dbase = get_contraction_structure(b)
        dexp = get_contraction_structure(e)

        dicts = []
        for d in dbase, dexp:
            if not (None in d and len(d) == 1):
                dicts.append(d)
        result = {None: {expr}}
        if dicts:
            result[expr] = dicts
        return result
    elif expr.is_Add:
        # Note: we just collect all terms with identical summation indices, We
        # do nothing to identify equivalent terms here, as this would require
        # substitutions or pattern matching in expressions of unknown
        # complexity.
        result = {}
        for term in expr.args:
            # recurse on every term
            d = get_contraction_structure(term)
            for key in d:
                if key in result:
                    result[key] |= d[key]
                else:
                    result[key] = d[key]
        return result

    elif isinstance(expr, Piecewise):
        # FIXME: No support for Piecewise yet
        # The value must be a set like in every other branch: the Add branch
        # above merges values with ``|=``, which fails on a bare expression.
        return {None: {expr}}
    elif isinstance(expr, Function):
        # Collect non-trivial contraction structures in each argument
        # We do not report repeated indices in separate arguments as a
        # contraction
        deeplist = []
        for arg in expr.args:
            deep = get_contraction_structure(arg)
            if not (None in deep and len(deep) == 1):
                deeplist.append(deep)
        d = {None: {expr}}
        if deeplist:
            d[expr] = deeplist
        return d

    # this test is expensive, so it should be at the end
    elif not expr.has(Indexed):
        return {None: {expr}}
    raise NotImplementedError(
        "FIXME: No specialized handling of type %s" % type(expr))
|