arrayexpr_derivatives.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import operator
  2. from functools import reduce, singledispatch
  3. from sympy.core.expr import Expr
  4. from sympy.core.singleton import S
  5. from sympy.matrices.expressions.hadamard import HadamardProduct
  6. from sympy.matrices.expressions.inverse import Inverse
  7. from sympy.matrices.expressions.matexpr import (MatrixExpr, MatrixSymbol)
  8. from sympy.matrices.expressions.special import Identity, OneMatrix
  9. from sympy.matrices.expressions.transpose import Transpose
  10. from sympy.combinatorics.permutations import _af_invert
  11. from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  12. from sympy.tensor.array.expressions.array_expressions import (
  13. _ArrayExpr, ZeroArray, ArraySymbol, ArrayTensorProduct, ArrayAdd,
  14. PermuteDims, ArrayDiagonal, ArrayElementwiseApplyFunc, get_rank,
  15. get_shape, ArrayContraction, _array_tensor_product, _array_contraction,
  16. _array_diagonal, _array_add, _permute_dims, Reshape)
  17. from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
  18. @singledispatch
  19. def array_derive(expr, x):
  20. """
  21. Derivatives (gradients) for array expressions.
  22. """
  23. raise NotImplementedError(f"not implemented for type {type(expr)}")
  24. @array_derive.register(Expr)
  25. def _(expr: Expr, x: _ArrayExpr):
  26. return ZeroArray(*x.shape)
  27. @array_derive.register(ArrayTensorProduct)
  28. def _(expr: ArrayTensorProduct, x: Expr):
  29. args = expr.args
  30. addend_list = []
  31. for i, arg in enumerate(expr.args):
  32. darg = array_derive(arg, x)
  33. if darg == 0:
  34. continue
  35. args_prev = args[:i]
  36. args_succ = args[i+1:]
  37. shape_prev = reduce(operator.add, map(get_shape, args_prev), ())
  38. shape_succ = reduce(operator.add, map(get_shape, args_succ), ())
  39. addend = _array_tensor_product(*args_prev, darg, *args_succ)
  40. tot1 = len(get_shape(x))
  41. tot2 = tot1 + len(shape_prev)
  42. tot3 = tot2 + len(get_shape(arg))
  43. tot4 = tot3 + len(shape_succ)
  44. perm = list(range(tot1, tot2)) + \
  45. list(range(tot1)) + list(range(tot2, tot3)) + \
  46. list(range(tot3, tot4))
  47. addend = _permute_dims(addend, _af_invert(perm))
  48. addend_list.append(addend)
  49. if len(addend_list) == 1:
  50. return addend_list[0]
  51. elif len(addend_list) == 0:
  52. return S.Zero
  53. else:
  54. return _array_add(*addend_list)
  55. @array_derive.register(ArraySymbol)
  56. def _(expr: ArraySymbol, x: _ArrayExpr):
  57. if expr == x:
  58. return _permute_dims(
  59. ArrayTensorProduct.fromiter(Identity(i) for i in expr.shape),
  60. [2*i for i in range(len(expr.shape))] + [2*i+1 for i in range(len(expr.shape))]
  61. )
  62. return ZeroArray(*(x.shape + expr.shape))
  63. @array_derive.register(MatrixSymbol)
  64. def _(expr: MatrixSymbol, x: _ArrayExpr):
  65. m, n = expr.shape
  66. if expr == x:
  67. return _permute_dims(
  68. _array_tensor_product(Identity(m), Identity(n)),
  69. [0, 2, 1, 3]
  70. )
  71. return ZeroArray(*(x.shape + expr.shape))
  72. @array_derive.register(Identity)
  73. def _(expr: Identity, x: _ArrayExpr):
  74. return ZeroArray(*(x.shape + expr.shape))
  75. @array_derive.register(OneMatrix)
  76. def _(expr: OneMatrix, x: _ArrayExpr):
  77. return ZeroArray(*(x.shape + expr.shape))
  78. @array_derive.register(Transpose)
  79. def _(expr: Transpose, x: Expr):
  80. # D(A.T, A) ==> (m,n,i,j) ==> D(A_ji, A_mn) = d_mj d_ni
  81. # D(B.T, A) ==> (m,n,i,j) ==> D(B_ji, A_mn)
  82. fd = array_derive(expr.arg, x)
  83. return _permute_dims(fd, [0, 1, 3, 2])
  84. @array_derive.register(Inverse)
  85. def _(expr: Inverse, x: Expr):
  86. mat = expr.I
  87. dexpr = array_derive(mat, x)
  88. tp = _array_tensor_product(-expr, dexpr, expr)
  89. mp = _array_contraction(tp, (1, 4), (5, 6))
  90. pp = _permute_dims(mp, [1, 2, 0, 3])
  91. return pp
  92. @array_derive.register(ElementwiseApplyFunction)
  93. def _(expr: ElementwiseApplyFunction, x: Expr):
  94. assert get_rank(expr) == 2
  95. assert get_rank(x) == 2
  96. fdiff = expr._get_function_fdiff()
  97. dexpr = array_derive(expr.expr, x)
  98. tp = _array_tensor_product(
  99. ElementwiseApplyFunction(fdiff, expr.expr),
  100. dexpr
  101. )
  102. td = _array_diagonal(
  103. tp, (0, 4), (1, 5)
  104. )
  105. return td
  106. @array_derive.register(ArrayElementwiseApplyFunc)
  107. def _(expr: ArrayElementwiseApplyFunc, x: Expr):
  108. fdiff = expr._get_function_fdiff()
  109. subexpr = expr.expr
  110. dsubexpr = array_derive(subexpr, x)
  111. tp = _array_tensor_product(
  112. dsubexpr,
  113. ArrayElementwiseApplyFunc(fdiff, subexpr)
  114. )
  115. b = get_rank(x)
  116. c = get_rank(expr)
  117. diag_indices = [(b + i, b + c + i) for i in range(c)]
  118. return _array_diagonal(tp, *diag_indices)
  119. @array_derive.register(MatrixExpr)
  120. def _(expr: MatrixExpr, x: Expr):
  121. cg = convert_matrix_to_array(expr)
  122. return array_derive(cg, x)
  123. @array_derive.register(HadamardProduct)
  124. def _(expr: HadamardProduct, x: Expr):
  125. raise NotImplementedError()
  126. @array_derive.register(ArrayContraction)
  127. def _(expr: ArrayContraction, x: Expr):
  128. fd = array_derive(expr.expr, x)
  129. rank_x = len(get_shape(x))
  130. contraction_indices = expr.contraction_indices
  131. new_contraction_indices = [tuple(j + rank_x for j in i) for i in contraction_indices]
  132. return _array_contraction(fd, *new_contraction_indices)
  133. @array_derive.register(ArrayDiagonal)
  134. def _(expr: ArrayDiagonal, x: Expr):
  135. dsubexpr = array_derive(expr.expr, x)
  136. rank_x = len(get_shape(x))
  137. diag_indices = [[j + rank_x for j in i] for i in expr.diagonal_indices]
  138. return _array_diagonal(dsubexpr, *diag_indices)
  139. @array_derive.register(ArrayAdd)
  140. def _(expr: ArrayAdd, x: Expr):
  141. return _array_add(*[array_derive(arg, x) for arg in expr.args])
  142. @array_derive.register(PermuteDims)
  143. def _(expr: PermuteDims, x: Expr):
  144. de = array_derive(expr.expr, x)
  145. perm = [0, 1] + [i + 2 for i in expr.permutation.array_form]
  146. return _permute_dims(de, perm)
  147. @array_derive.register(Reshape)
  148. def _(expr: Reshape, x: Expr):
  149. de = array_derive(expr.expr, x)
  150. return Reshape(de, get_shape(x) + expr.shape)
  151. def matrix_derive(expr, x):
  152. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  153. ce = convert_matrix_to_array(expr)
  154. dce = array_derive(ce, x)
  155. return convert_array_to_matrix(dce).doit()