# test_index_methods.py -- tests for get_indices and get_contraction_structure
# from sympy.tensor.index_methods.
  1. from sympy.core import symbols, S, Pow, Function
  2. from sympy.functions import exp
  3. from sympy.testing.pytest import raises
  4. from sympy.tensor.indexed import Idx, IndexedBase
  5. from sympy.tensor.index_methods import IndexConformanceException
  6. from sympy.tensor.index_methods import (get_contraction_structure, get_indices)
  7. def test_trivial_indices():
  8. x, y = symbols('x y')
  9. assert get_indices(x) == (set(), {})
  10. assert get_indices(x*y) == (set(), {})
  11. assert get_indices(x + y) == (set(), {})
  12. assert get_indices(x**y) == (set(), {})
  13. def test_get_indices_Indexed():
  14. x = IndexedBase('x')
  15. i, j = Idx('i'), Idx('j')
  16. assert get_indices(x[i, j]) == ({i, j}, {})
  17. assert get_indices(x[j, i]) == ({j, i}, {})
  18. def test_get_indices_Idx():
  19. f = Function('f')
  20. i, j = Idx('i'), Idx('j')
  21. assert get_indices(f(i)*j) == ({i, j}, {})
  22. assert get_indices(f(j, i)) == ({j, i}, {})
  23. assert get_indices(f(i)*i) == (set(), {})
  24. def test_get_indices_mul():
  25. x = IndexedBase('x')
  26. y = IndexedBase('y')
  27. i, j = Idx('i'), Idx('j')
  28. assert get_indices(x[j]*y[i]) == ({i, j}, {})
  29. assert get_indices(x[i]*y[j]) == ({i, j}, {})
  30. def test_get_indices_exceptions():
  31. x = IndexedBase('x')
  32. y = IndexedBase('y')
  33. i, j = Idx('i'), Idx('j')
  34. raises(IndexConformanceException, lambda: get_indices(x[i] + y[j]))
  35. def test_scalar_broadcast():
  36. x = IndexedBase('x')
  37. y = IndexedBase('y')
  38. i, j = Idx('i'), Idx('j')
  39. assert get_indices(x[i] + y[i, i]) == ({i}, {})
  40. assert get_indices(x[i] + y[j, j]) == ({i}, {})
  41. def test_get_indices_add():
  42. x = IndexedBase('x')
  43. y = IndexedBase('y')
  44. A = IndexedBase('A')
  45. i, j, k = Idx('i'), Idx('j'), Idx('k')
  46. assert get_indices(x[i] + 2*y[i]) == ({i}, {})
  47. assert get_indices(y[i] + 2*A[i, j]*x[j]) == ({i}, {})
  48. assert get_indices(y[i] + 2*(x[i] + A[i, j]*x[j])) == ({i}, {})
  49. assert get_indices(y[i] + x[i]*(A[j, j] + 1)) == ({i}, {})
  50. assert get_indices(
  51. y[i] + x[i]*x[j]*(y[j] + A[j, k]*x[k])) == ({i}, {})
  52. def test_get_indices_Pow():
  53. x = IndexedBase('x')
  54. y = IndexedBase('y')
  55. A = IndexedBase('A')
  56. i, j, k = Idx('i'), Idx('j'), Idx('k')
  57. assert get_indices(Pow(x[i], y[j])) == ({i, j}, {})
  58. assert get_indices(Pow(x[i, k], y[j, k])) == ({i, j, k}, {})
  59. assert get_indices(Pow(A[i, k], y[k] + A[k, j]*x[j])) == ({i, k}, {})
  60. assert get_indices(Pow(2, x[i])) == get_indices(exp(x[i]))
  61. # test of a design decision, this may change:
  62. assert get_indices(Pow(x[i], 2)) == ({i}, {})
  63. def test_get_contraction_structure_basic():
  64. x = IndexedBase('x')
  65. y = IndexedBase('y')
  66. i, j = Idx('i'), Idx('j')
  67. assert get_contraction_structure(x[i]*y[j]) == {None: {x[i]*y[j]}}
  68. assert get_contraction_structure(x[i] + y[j]) == {None: {x[i], y[j]}}
  69. assert get_contraction_structure(x[i]*y[i]) == {(i,): {x[i]*y[i]}}
  70. assert get_contraction_structure(
  71. 1 + x[i]*y[i]) == {None: {S.One}, (i,): {x[i]*y[i]}}
  72. assert get_contraction_structure(x[i]**y[i]) == {None: {x[i]**y[i]}}
  73. def test_get_contraction_structure_complex():
  74. x = IndexedBase('x')
  75. y = IndexedBase('y')
  76. A = IndexedBase('A')
  77. i, j, k = Idx('i'), Idx('j'), Idx('k')
  78. expr1 = y[i] + A[i, j]*x[j]
  79. d1 = {None: {y[i]}, (j,): {A[i, j]*x[j]}}
  80. assert get_contraction_structure(expr1) == d1
  81. expr2 = expr1*A[k, i] + x[k]
  82. d2 = {None: {x[k]}, (i,): {expr1*A[k, i]}, expr1*A[k, i]: [d1]}
  83. assert get_contraction_structure(expr2) == d2
  84. def test_contraction_structure_simple_Pow():
  85. x = IndexedBase('x')
  86. y = IndexedBase('y')
  87. i, j, k = Idx('i'), Idx('j'), Idx('k')
  88. ii_jj = x[i, i]**y[j, j]
  89. assert get_contraction_structure(ii_jj) == {
  90. None: {ii_jj},
  91. ii_jj: [
  92. {(i,): {x[i, i]}},
  93. {(j,): {y[j, j]}}
  94. ]
  95. }
  96. ii_jk = x[i, i]**y[j, k]
  97. assert get_contraction_structure(ii_jk) == {
  98. None: {x[i, i]**y[j, k]},
  99. x[i, i]**y[j, k]: [
  100. {(i,): {x[i, i]}}
  101. ]
  102. }
  103. def test_contraction_structure_Mul_and_Pow():
  104. x = IndexedBase('x')
  105. y = IndexedBase('y')
  106. i, j, k = Idx('i'), Idx('j'), Idx('k')
  107. i_ji = x[i]**(y[j]*x[i])
  108. assert get_contraction_structure(i_ji) == {None: {i_ji}}
  109. ij_i = (x[i]*y[j])**(y[i])
  110. assert get_contraction_structure(ij_i) == {None: {ij_i}}
  111. j_ij_i = x[j]*(x[i]*y[j])**(y[i])
  112. assert get_contraction_structure(j_ij_i) == {(j,): {j_ij_i}}
  113. j_i_ji = x[j]*x[i]**(y[j]*x[i])
  114. assert get_contraction_structure(j_i_ji) == {(j,): {j_i_ji}}
  115. ij_exp_kki = x[i]*y[j]*exp(y[i]*y[k, k])
  116. result = get_contraction_structure(ij_exp_kki)
  117. expected = {
  118. (i,): {ij_exp_kki},
  119. ij_exp_kki: [{
  120. None: {exp(y[i]*y[k, k])},
  121. exp(y[i]*y[k, k]): [{
  122. None: {y[i]*y[k, k]},
  123. y[i]*y[k, k]: [{(k,): {y[k, k]}}]
  124. }]}
  125. ]
  126. }
  127. assert result == expected
  128. def test_contraction_structure_Add_in_Pow():
  129. x = IndexedBase('x')
  130. y = IndexedBase('y')
  131. i, j, k = Idx('i'), Idx('j'), Idx('k')
  132. s_ii_jj_s = (1 + x[i, i])**(1 + y[j, j])
  133. expected = {
  134. None: {s_ii_jj_s},
  135. s_ii_jj_s: [
  136. {None: {S.One}, (i,): {x[i, i]}},
  137. {None: {S.One}, (j,): {y[j, j]}}
  138. ]
  139. }
  140. result = get_contraction_structure(s_ii_jj_s)
  141. assert result == expected
  142. s_ii_jk_s = (1 + x[i, i]) ** (1 + y[j, k])
  143. expected_2 = {
  144. None: {(x[i, i] + 1)**(y[j, k] + 1)},
  145. s_ii_jk_s: [
  146. {None: {S.One}, (i,): {x[i, i]}}
  147. ]
  148. }
  149. result_2 = get_contraction_structure(s_ii_jk_s)
  150. assert result_2 == expected_2
  151. def test_contraction_structure_Pow_in_Pow():
  152. x = IndexedBase('x')
  153. y = IndexedBase('y')
  154. z = IndexedBase('z')
  155. i, j, k = Idx('i'), Idx('j'), Idx('k')
  156. ii_jj_kk = x[i, i]**y[j, j]**z[k, k]
  157. expected = {
  158. None: {ii_jj_kk},
  159. ii_jj_kk: [
  160. {(i,): {x[i, i]}},
  161. {
  162. None: {y[j, j]**z[k, k]},
  163. y[j, j]**z[k, k]: [
  164. {(j,): {y[j, j]}},
  165. {(k,): {z[k, k]}}
  166. ]
  167. }
  168. ]
  169. }
  170. assert get_contraction_structure(ii_jj_kk) == expected
  171. def test_ufunc_support():
  172. f = Function('f')
  173. g = Function('g')
  174. x = IndexedBase('x')
  175. y = IndexedBase('y')
  176. i, j = Idx('i'), Idx('j')
  177. a = symbols('a')
  178. assert get_indices(f(x[i])) == ({i}, {})
  179. assert get_indices(f(x[i], y[j])) == ({i, j}, {})
  180. assert get_indices(f(y[i])*g(x[i])) == (set(), {})
  181. assert get_indices(f(a, x[i])) == ({i}, {})
  182. assert get_indices(f(a, y[i], x[j])*g(x[i])) == ({j}, {})
  183. assert get_indices(g(f(x[i]))) == ({i}, {})
  184. assert get_contraction_structure(f(x[i])) == {None: {f(x[i])}}
  185. assert get_contraction_structure(
  186. f(y[i])*g(x[i])) == {(i,): {f(y[i])*g(x[i])}}
  187. assert get_contraction_structure(
  188. f(y[i])*g(f(x[i]))) == {(i,): {f(y[i])*g(f(x[i]))}}
  189. assert get_contraction_structure(
  190. f(x[j], y[i])*g(x[i])) == {(i,): {f(x[j], y[i])*g(x[i])}}