toperators.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from sympy import permutedims
  2. from sympy.core.numbers import Number
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import Symbol
  5. from sympy.core.sympify import sympify
  6. from sympy.tensor.tensor import Tensor, TensExpr, TensAdd, TensMul
  7. class PartialDerivative(TensExpr):
  8. """
  9. Partial derivative for tensor expressions.
  10. Examples
  11. ========
  12. >>> from sympy.tensor.tensor import TensorIndexType, TensorHead
  13. >>> from sympy.tensor.toperators import PartialDerivative
  14. >>> from sympy import symbols
  15. >>> L = TensorIndexType("L")
  16. >>> A = TensorHead("A", [L])
  17. >>> B = TensorHead("B", [L])
  18. >>> i, j, k = symbols("i j k")
  19. >>> expr = PartialDerivative(A(i), A(j))
  20. >>> expr
  21. PartialDerivative(A(i), A(j))
  22. The ``PartialDerivative`` object behaves like a tensorial expression:
  23. >>> expr.get_indices()
  24. [i, -j]
  25. Notice that the deriving variables have opposite valence than the
  26. printed one: ``A(j)`` is printed as covariant, but the index of the
  27. derivative is actually contravariant, i.e. ``-j``.
  28. Indices can be contracted:
  29. >>> expr = PartialDerivative(A(i), A(i))
  30. >>> expr
  31. PartialDerivative(A(L_0), A(L_0))
  32. >>> expr.get_indices()
  33. [L_0, -L_0]
  34. The method ``.get_indices()`` always returns all indices (even the
  35. contracted ones). If only uncontracted indices are needed, call
  36. ``.get_free_indices()``:
  37. >>> expr.get_free_indices()
  38. []
  39. Nested partial derivatives are flattened:
  40. >>> expr = PartialDerivative(PartialDerivative(A(i), A(j)), A(k))
  41. >>> expr
  42. PartialDerivative(A(i), A(j), A(k))
  43. >>> expr.get_indices()
  44. [i, -j, -k]
  45. Replace a derivative with array values:
  46. >>> from sympy.abc import x, y
  47. >>> from sympy import sin, log
  48. >>> compA = [sin(x), log(x)*y**3]
  49. >>> compB = [x, y]
  50. >>> expr = PartialDerivative(A(i), B(j))
  51. >>> expr.replace_with_arrays({A(i): compA, B(i): compB})
  52. [[cos(x), 0], [y**3/x, 3*y**2*log(x)]]
  53. The returned array is indexed by `(i, -j)`.
  54. Be careful that other SymPy modules put the indices of the deriving
  55. variables before the indices of the derivand in the derivative result.
  56. For example:
  57. >>> expr.get_free_indices()
  58. [i, -j]
  59. >>> from sympy import Matrix, Array
  60. >>> Matrix(compA).diff(Matrix(compB)).reshape(2, 2)
  61. [[cos(x), y**3/x], [0, 3*y**2*log(x)]]
  62. >>> Array(compA).diff(Array(compB))
  63. [[cos(x), y**3/x], [0, 3*y**2*log(x)]]
  64. These are the transpose of the result of ``PartialDerivative``,
  65. as the matrix and the array modules put the index `-j` before `i` in the
  66. derivative result. An array read with index order `(-j, i)` is indeed the
  67. transpose of the same array read with index order `(i, -j)`. By specifying
  68. the index order to ``.replace_with_arrays`` one can get a compatible
  69. expression:
  70. >>> expr.replace_with_arrays({A(i): compA, B(i): compB}, [-j, i])
  71. [[cos(x), y**3/x], [0, 3*y**2*log(x)]]
  72. """
  73. def __new__(cls, expr, *variables):
  74. # Flatten:
  75. if isinstance(expr, PartialDerivative):
  76. variables = expr.variables + variables
  77. expr = expr.expr
  78. args, indices, free, dum = cls._contract_indices_for_derivative(
  79. S(expr), variables)
  80. obj = TensExpr.__new__(cls, *args)
  81. obj._indices = indices
  82. obj._free = free
  83. obj._dum = dum
  84. return obj
  85. @property
  86. def coeff(self):
  87. return S.One
  88. @property
  89. def nocoeff(self):
  90. return self
  91. @classmethod
  92. def _contract_indices_for_derivative(cls, expr, variables):
  93. variables_opposite_valence = []
  94. for i in variables:
  95. if isinstance(i, Tensor):
  96. i_free_indices = i.get_free_indices()
  97. variables_opposite_valence.append(
  98. i.xreplace({k: -k for k in i_free_indices}))
  99. elif isinstance(i, Symbol):
  100. variables_opposite_valence.append(i)
  101. args, indices, free, dum = TensMul._tensMul_contract_indices(
  102. [expr] + variables_opposite_valence, replace_indices=True)
  103. for i in range(1, len(args)):
  104. args_i = args[i]
  105. if isinstance(args_i, Tensor):
  106. i_indices = args[i].get_free_indices()
  107. args[i] = args[i].xreplace({k: -k for k in i_indices})
  108. return args, indices, free, dum
  109. def doit(self, **hints):
  110. args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables)
  111. obj = self.func(*args)
  112. obj._indices = indices
  113. obj._free = free
  114. obj._dum = dum
  115. return obj
  116. def _expand_partial_derivative(self):
  117. args, indices, free, dum = self._contract_indices_for_derivative(self.expr, self.variables)
  118. obj = self.func(*args)
  119. obj._indices = indices
  120. obj._free = free
  121. obj._dum = dum
  122. result = obj
  123. if not args[0].free_symbols:
  124. return S.Zero
  125. elif isinstance(obj.expr, TensAdd):
  126. # take care of sums of multi PDs
  127. result = obj.expr.func(*[
  128. self.func(a, *obj.variables)._expand_partial_derivative()
  129. for a in result.expr.args])
  130. elif isinstance(obj.expr, TensMul):
  131. # take care of products of multi PDs
  132. if len(obj.variables) == 1:
  133. # derivative with respect to single variable
  134. terms = []
  135. mulargs = list(obj.expr.args)
  136. for ind in range(len(mulargs)):
  137. if not isinstance(sympify(mulargs[ind]), Number):
  138. # a number coefficient is not considered for
  139. # expansion of PartialDerivative
  140. d = self.func(mulargs[ind], *obj.variables)._expand_partial_derivative()
  141. terms.append(TensMul(*(mulargs[:ind]
  142. + [d]
  143. + mulargs[(ind + 1):])))
  144. result = TensAdd.fromiter(terms)
  145. else:
  146. # derivative with respect to multiple variables
  147. # decompose:
  148. # partial(expr, (u, v))
  149. # = partial(partial(expr, u).doit(), v).doit()
  150. result = obj.expr # init with expr
  151. for v in obj.variables:
  152. result = self.func(result, v)._expand_partial_derivative()
  153. # then throw PD on it
  154. return result
  155. def _perform_derivative(self):
  156. result = self.expr
  157. for v in self.variables:
  158. if isinstance(result, TensExpr):
  159. result = result._eval_partial_derivative(v)
  160. else:
  161. if v._diff_wrt:
  162. result = result._eval_derivative(v)
  163. else:
  164. result = S.Zero
  165. return result
  166. def get_indices(self):
  167. return self._indices
  168. def get_free_indices(self):
  169. free = sorted(self._free, key=lambda x: x[1])
  170. return [i[0] for i in free]
  171. def _replace_indices(self, repl):
  172. expr = self.expr.xreplace(repl)
  173. mirrored = {-k: -v for k, v in repl.items()}
  174. variables = [i.xreplace(mirrored) for i in self.variables]
  175. return self.func(expr, *variables)
  176. @property
  177. def expr(self):
  178. return self.args[0]
  179. @property
  180. def variables(self):
  181. return self.args[1:]
  182. def _extract_data(self, replacement_dict):
  183. from .array import derive_by_array, tensorcontraction
  184. indices, array = self.expr._extract_data(replacement_dict)
  185. for variable in self.variables:
  186. var_indices, var_array = variable._extract_data(replacement_dict)
  187. var_indices = [-i for i in var_indices]
  188. coeff_array, var_array = zip(*[i.as_coeff_Mul() for i in var_array])
  189. dim_before = len(array.shape)
  190. array = derive_by_array(array, var_array)
  191. dim_after = len(array.shape)
  192. dim_increase = dim_after - dim_before
  193. array = permutedims(array, [i + dim_increase for i in range(dim_before)] + list(range(dim_increase)))
  194. array = array.as_mutable()
  195. varindex = var_indices[0]
  196. # Remove coefficients of base vector:
  197. coeff_index = [0] + [slice(None) for i in range(len(indices))]
  198. for i, coeff in enumerate(coeff_array):
  199. coeff_index[0] = i
  200. array[tuple(coeff_index)] /= coeff
  201. if -varindex in indices:
  202. pos = indices.index(-varindex)
  203. array = tensorcontraction(array, (0, pos+1))
  204. indices.pop(pos)
  205. else:
  206. indices.append(varindex)
  207. return indices, array