index_methods.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. """Module with functions operating on IndexedBase, Indexed and Idx objects
  2. - Check shape conformance
  3. - Determine indices in resulting expression
  4. etc.
  5. Methods in this module could be implemented by calling methods on Expr
  6. objects instead. When things stabilize this could be a useful
  7. refactoring.
  8. """
  9. from functools import reduce
  10. from sympy.core.function import Function
  11. from sympy.functions import exp, Piecewise
  12. from sympy.tensor.indexed import Idx, Indexed
  13. from sympy.utilities import sift
  14. from collections import OrderedDict
  15. class IndexConformanceException(Exception):
  16. pass
  17. def _unique_and_repeated(inds):
  18. """
  19. Returns the unique and repeated indices. Also note, from the examples given below
  20. that the order of indices is maintained as given in the input.
  21. Examples
  22. ========
  23. >>> from sympy.tensor.index_methods import _unique_and_repeated
  24. >>> _unique_and_repeated([2, 3, 1, 3, 0, 4, 0])
  25. ([2, 1, 4], [3, 0])
  26. """
  27. uniq = OrderedDict()
  28. for i in inds:
  29. if i in uniq:
  30. uniq[i] = 0
  31. else:
  32. uniq[i] = 1
  33. return sift(uniq, lambda x: uniq[x], binary=True)
  34. def _remove_repeated(inds):
  35. """
  36. Removes repeated objects from sequences
  37. Returns a set of the unique objects and a tuple of all that have been
  38. removed.
  39. Examples
  40. ========
  41. >>> from sympy.tensor.index_methods import _remove_repeated
  42. >>> l1 = [1, 2, 3, 2]
  43. >>> _remove_repeated(l1)
  44. ({1, 3}, (2,))
  45. """
  46. u, r = _unique_and_repeated(inds)
  47. return set(u), tuple(r)
  48. def _get_indices_Mul(expr, return_dummies=False):
  49. """Determine the outer indices of a Mul object.
  50. Examples
  51. ========
  52. >>> from sympy.tensor.index_methods import _get_indices_Mul
  53. >>> from sympy.tensor.indexed import IndexedBase, Idx
  54. >>> i, j, k = map(Idx, ['i', 'j', 'k'])
  55. >>> x = IndexedBase('x')
  56. >>> y = IndexedBase('y')
  57. >>> _get_indices_Mul(x[i, k]*y[j, k])
  58. ({i, j}, {})
  59. >>> _get_indices_Mul(x[i, k]*y[j, k], return_dummies=True)
  60. ({i, j}, {}, (k,))
  61. """
  62. inds = list(map(get_indices, expr.args))
  63. inds, syms = list(zip(*inds))
  64. inds = list(map(list, inds))
  65. inds = list(reduce(lambda x, y: x + y, inds))
  66. inds, dummies = _remove_repeated(inds)
  67. symmetry = {}
  68. for s in syms:
  69. for pair in s:
  70. if pair in symmetry:
  71. symmetry[pair] *= s[pair]
  72. else:
  73. symmetry[pair] = s[pair]
  74. if return_dummies:
  75. return inds, symmetry, dummies
  76. else:
  77. return inds, symmetry
  78. def _get_indices_Pow(expr):
  79. """Determine outer indices of a power or an exponential.
  80. A power is considered a universal function, so that the indices of a Pow is
  81. just the collection of indices present in the expression. This may be
  82. viewed as a bit inconsistent in the special case:
  83. x[i]**2 = x[i]*x[i] (1)
  84. The above expression could have been interpreted as the contraction of x[i]
  85. with itself, but we choose instead to interpret it as a function
  86. lambda y: y**2
  87. applied to each element of x (a universal function in numpy terms). In
  88. order to allow an interpretation of (1) as a contraction, we need
  89. contravariant and covariant Idx subclasses. (FIXME: this is not yet
  90. implemented)
  91. Expressions in the base or exponent are subject to contraction as usual,
  92. but an index that is present in the exponent, will not be considered
  93. contractable with its own base. Note however, that indices in the same
  94. exponent can be contracted with each other.
  95. Examples
  96. ========
  97. >>> from sympy.tensor.index_methods import _get_indices_Pow
  98. >>> from sympy import Pow, exp, IndexedBase, Idx
  99. >>> A = IndexedBase('A')
  100. >>> x = IndexedBase('x')
  101. >>> i, j, k = map(Idx, ['i', 'j', 'k'])
  102. >>> _get_indices_Pow(exp(A[i, j]*x[j]))
  103. ({i}, {})
  104. >>> _get_indices_Pow(Pow(x[i], x[i]))
  105. ({i}, {})
  106. >>> _get_indices_Pow(Pow(A[i, j]*x[j], x[i]))
  107. ({i}, {})
  108. """
  109. base, exp = expr.as_base_exp()
  110. binds, bsyms = get_indices(base)
  111. einds, esyms = get_indices(exp)
  112. inds = binds | einds
  113. # FIXME: symmetries from power needs to check special cases, else nothing
  114. symmetries = {}
  115. return inds, symmetries
  116. def _get_indices_Add(expr):
  117. """Determine outer indices of an Add object.
  118. In a sum, each term must have the same set of outer indices. A valid
  119. expression could be
  120. x(i)*y(j) - x(j)*y(i)
  121. But we do not allow expressions like:
  122. x(i)*y(j) - z(j)*z(j)
  123. FIXME: Add support for Numpy broadcasting
  124. Examples
  125. ========
  126. >>> from sympy.tensor.index_methods import _get_indices_Add
  127. >>> from sympy.tensor.indexed import IndexedBase, Idx
  128. >>> i, j, k = map(Idx, ['i', 'j', 'k'])
  129. >>> x = IndexedBase('x')
  130. >>> y = IndexedBase('y')
  131. >>> _get_indices_Add(x[i] + x[k]*y[i, k])
  132. ({i}, {})
  133. """
  134. inds = list(map(get_indices, expr.args))
  135. inds, syms = list(zip(*inds))
  136. # allow broadcast of scalars
  137. non_scalars = [x for x in inds if x != set()]
  138. if not non_scalars:
  139. return set(), {}
  140. if not all(x == non_scalars[0] for x in non_scalars[1:]):
  141. raise IndexConformanceException("Indices are not consistent: %s" % expr)
  142. if not reduce(lambda x, y: x != y or y, syms):
  143. symmetries = syms[0]
  144. else:
  145. # FIXME: search for symmetries
  146. symmetries = {}
  147. return non_scalars[0], symmetries
  148. def get_indices(expr):
  149. """Determine the outer indices of expression ``expr``
  150. By *outer* we mean indices that are not summation indices. Returns a set
  151. and a dict. The set contains outer indices and the dict contains
  152. information about index symmetries.
  153. Examples
  154. ========
  155. >>> from sympy.tensor.index_methods import get_indices
  156. >>> from sympy import symbols
  157. >>> from sympy.tensor import IndexedBase
  158. >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
  159. >>> i, j, a, z = symbols('i j a z', integer=True)
  160. The indices of the total expression is determined, Repeated indices imply a
  161. summation, for instance the trace of a matrix A:
  162. >>> get_indices(A[i, i])
  163. (set(), {})
  164. In the case of many terms, the terms are required to have identical
  165. outer indices. Else an IndexConformanceException is raised.
  166. >>> get_indices(x[i] + A[i, j]*y[j])
  167. ({i}, {})
  168. :Exceptions:
  169. An IndexConformanceException means that the terms ar not compatible, e.g.
  170. >>> get_indices(x[i] + y[j]) #doctest: +SKIP
  171. (...)
  172. IndexConformanceException: Indices are not consistent: x(i) + y(j)
  173. .. warning::
  174. The concept of *outer* indices applies recursively, starting on the deepest
  175. level. This implies that dummies inside parenthesis are assumed to be
  176. summed first, so that the following expression is handled gracefully:
  177. >>> get_indices((x[i] + A[i, j]*y[j])*x[j])
  178. ({i, j}, {})
  179. This is correct and may appear convenient, but you need to be careful
  180. with this as SymPy will happily .expand() the product, if requested. The
  181. resulting expression would mix the outer ``j`` with the dummies inside
  182. the parenthesis, which makes it a different expression. To be on the
  183. safe side, it is best to avoid such ambiguities by using unique indices
  184. for all contractions that should be held separate.
  185. """
  186. # We call ourself recursively to determine indices of sub expressions.
  187. # break recursion
  188. if isinstance(expr, Indexed):
  189. c = expr.indices
  190. inds, dummies = _remove_repeated(c)
  191. return inds, {}
  192. elif expr is None:
  193. return set(), {}
  194. elif isinstance(expr, Idx):
  195. return {expr}, {}
  196. elif expr.is_Atom:
  197. return set(), {}
  198. # recurse via specialized functions
  199. else:
  200. if expr.is_Mul:
  201. return _get_indices_Mul(expr)
  202. elif expr.is_Add:
  203. return _get_indices_Add(expr)
  204. elif expr.is_Pow or isinstance(expr, exp):
  205. return _get_indices_Pow(expr)
  206. elif isinstance(expr, Piecewise):
  207. # FIXME: No support for Piecewise yet
  208. return set(), {}
  209. elif isinstance(expr, Function):
  210. # Support ufunc like behaviour by returning indices from arguments.
  211. # Functions do not interpret repeated indices across arguments
  212. # as summation
  213. ind0 = set()
  214. for arg in expr.args:
  215. ind, sym = get_indices(arg)
  216. ind0 |= ind
  217. return ind0, sym
  218. # this test is expensive, so it should be at the end
  219. elif not expr.has(Indexed):
  220. return set(), {}
  221. raise NotImplementedError(
  222. "FIXME: No specialized handling of type %s" % type(expr))
  223. def get_contraction_structure(expr):
  224. """Determine dummy indices of ``expr`` and describe its structure
  225. By *dummy* we mean indices that are summation indices.
  226. The structure of the expression is determined and described as follows:
  227. 1) A conforming summation of Indexed objects is described with a dict where
  228. the keys are summation indices and the corresponding values are sets
  229. containing all terms for which the summation applies. All Add objects
  230. in the SymPy expression tree are described like this.
  231. 2) For all nodes in the SymPy expression tree that are *not* of type Add, the
  232. following applies:
  233. If a node discovers contractions in one of its arguments, the node
  234. itself will be stored as a key in the dict. For that key, the
  235. corresponding value is a list of dicts, each of which is the result of a
  236. recursive call to get_contraction_structure(). The list contains only
  237. dicts for the non-trivial deeper contractions, omitting dicts with None
  238. as the one and only key.
  239. .. Note:: The presence of expressions among the dictionary keys indicates
  240. multiple levels of index contractions. A nested dict displays nested
  241. contractions and may itself contain dicts from a deeper level. In
  242. practical calculations the summation in the deepest nested level must be
  243. calculated first so that the outer expression can access the resulting
  244. indexed object.
  245. Examples
  246. ========
  247. >>> from sympy.tensor.index_methods import get_contraction_structure
  248. >>> from sympy import default_sort_key
  249. >>> from sympy.tensor import IndexedBase, Idx
  250. >>> x, y, A = map(IndexedBase, ['x', 'y', 'A'])
  251. >>> i, j, k, l = map(Idx, ['i', 'j', 'k', 'l'])
  252. >>> get_contraction_structure(x[i]*y[i] + A[j, j])
  253. {(i,): {x[i]*y[i]}, (j,): {A[j, j]}}
  254. >>> get_contraction_structure(x[i]*y[j])
  255. {None: {x[i]*y[j]}}
  256. A multiplication of contracted factors results in nested dicts representing
  257. the internal contractions.
  258. >>> d = get_contraction_structure(x[i, i]*y[j, j])
  259. >>> sorted(d.keys(), key=default_sort_key)
  260. [None, x[i, i]*y[j, j]]
  261. In this case, the product has no contractions:
  262. >>> d[None]
  263. {x[i, i]*y[j, j]}
  264. Factors are contracted "first":
  265. >>> sorted(d[x[i, i]*y[j, j]], key=default_sort_key)
  266. [{(i,): {x[i, i]}}, {(j,): {y[j, j]}}]
  267. A parenthesized Add object is also returned as a nested dictionary. The
  268. term containing the parenthesis is a Mul with a contraction among the
  269. arguments, so it will be found as a key in the result. It stores the
  270. dictionary resulting from a recursive call on the Add expression.
  271. >>> d = get_contraction_structure(x[i]*(y[i] + A[i, j]*x[j]))
  272. >>> sorted(d.keys(), key=default_sort_key)
  273. [(A[i, j]*x[j] + y[i])*x[i], (i,)]
  274. >>> d[(i,)]
  275. {(A[i, j]*x[j] + y[i])*x[i]}
  276. >>> d[x[i]*(A[i, j]*x[j] + y[i])]
  277. [{None: {y[i]}, (j,): {A[i, j]*x[j]}}]
  278. Powers with contractions in either base or exponent will also be found as
  279. keys in the dictionary, mapping to a list of results from recursive calls:
  280. >>> d = get_contraction_structure(A[j, j]**A[i, i])
  281. >>> d[None]
  282. {A[j, j]**A[i, i]}
  283. >>> nested_contractions = d[A[j, j]**A[i, i]]
  284. >>> nested_contractions[0]
  285. {(j,): {A[j, j]}}
  286. >>> nested_contractions[1]
  287. {(i,): {A[i, i]}}
  288. The description of the contraction structure may appear complicated when
  289. represented with a string in the above examples, but it is easy to iterate
  290. over:
  291. >>> from sympy import Expr
  292. >>> for key in d:
  293. ... if isinstance(key, Expr):
  294. ... continue
  295. ... for term in d[key]:
  296. ... if term in d:
  297. ... # treat deepest contraction first
  298. ... pass
  299. ... # treat outermost contactions here
  300. """
  301. # We call ourself recursively to inspect sub expressions.
  302. if isinstance(expr, Indexed):
  303. junk, key = _remove_repeated(expr.indices)
  304. return {key or None: {expr}}
  305. elif expr.is_Atom:
  306. return {None: {expr}}
  307. elif expr.is_Mul:
  308. junk, junk, key = _get_indices_Mul(expr, return_dummies=True)
  309. result = {key or None: {expr}}
  310. # recurse on every factor
  311. nested = []
  312. for fac in expr.args:
  313. facd = get_contraction_structure(fac)
  314. if not (None in facd and len(facd) == 1):
  315. nested.append(facd)
  316. if nested:
  317. result[expr] = nested
  318. return result
  319. elif expr.is_Pow or isinstance(expr, exp):
  320. # recurse in base and exp separately. If either has internal
  321. # contractions we must include ourselves as a key in the returned dict
  322. b, e = expr.as_base_exp()
  323. dbase = get_contraction_structure(b)
  324. dexp = get_contraction_structure(e)
  325. dicts = []
  326. for d in dbase, dexp:
  327. if not (None in d and len(d) == 1):
  328. dicts.append(d)
  329. result = {None: {expr}}
  330. if dicts:
  331. result[expr] = dicts
  332. return result
  333. elif expr.is_Add:
  334. # Note: we just collect all terms with identical summation indices, We
  335. # do nothing to identify equivalent terms here, as this would require
  336. # substitutions or pattern matching in expressions of unknown
  337. # complexity.
  338. result = {}
  339. for term in expr.args:
  340. # recurse on every term
  341. d = get_contraction_structure(term)
  342. for key in d:
  343. if key in result:
  344. result[key] |= d[key]
  345. else:
  346. result[key] = d[key]
  347. return result
  348. elif isinstance(expr, Piecewise):
  349. # FIXME: No support for Piecewise yet
  350. return {None: expr}
  351. elif isinstance(expr, Function):
  352. # Collect non-trivial contraction structures in each argument
  353. # We do not report repeated indices in separate arguments as a
  354. # contraction
  355. deeplist = []
  356. for arg in expr.args:
  357. deep = get_contraction_structure(arg)
  358. if not (None in deep and len(deep) == 1):
  359. deeplist.append(deep)
  360. d = {None: {expr}}
  361. if deeplist:
  362. d[expr] = deeplist
  363. return d
  364. # this test is expensive, so it should be at the end
  365. elif not expr.has(Indexed):
  366. return {None: {expr}}
  367. raise NotImplementedError(
  368. "FIXME: No specialized handling of type %s" % type(expr))