# test_applyfunc.py (3.4 KB)

  1. from sympy.core.symbol import symbols, Dummy
  2. from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  3. from sympy.core.function import Lambda
  4. from sympy.functions.elementary.exponential import exp
  5. from sympy.functions.elementary.trigonometric import sin
  6. from sympy.matrices.dense import Matrix
  7. from sympy.matrices.expressions.matexpr import MatrixSymbol
  8. from sympy.matrices.expressions.matmul import MatMul
  9. from sympy.simplify.simplify import simplify
  10. X = MatrixSymbol("X", 3, 3)
  11. Y = MatrixSymbol("Y", 3, 3)
  12. k = symbols("k")
  13. Xk = MatrixSymbol("X", k, k)
  14. Xd = X.as_explicit()
  15. x, y, z, t = symbols("x y z t")
  16. def test_applyfunc_matrix():
  17. x = Dummy('x')
  18. double = Lambda(x, x**2)
  19. expr = ElementwiseApplyFunction(double, Xd)
  20. assert isinstance(expr, ElementwiseApplyFunction)
  21. assert expr.doit() == Xd.applyfunc(lambda x: x**2)
  22. assert expr.shape == (3, 3)
  23. assert expr.func(*expr.args) == expr
  24. assert simplify(expr) == expr
  25. assert expr[0, 0] == double(Xd[0, 0])
  26. expr = ElementwiseApplyFunction(double, X)
  27. assert isinstance(expr, ElementwiseApplyFunction)
  28. assert isinstance(expr.doit(), ElementwiseApplyFunction)
  29. assert expr == X.applyfunc(double)
  30. assert expr.func(*expr.args) == expr
  31. expr = ElementwiseApplyFunction(exp, X*Y)
  32. assert expr.expr == X*Y
  33. assert expr.function.dummy_eq(Lambda(x, exp(x)))
  34. assert expr.dummy_eq((X*Y).applyfunc(exp))
  35. assert expr.func(*expr.args) == expr
  36. assert isinstance(X*expr, MatMul)
  37. assert (X*expr).shape == (3, 3)
  38. Z = MatrixSymbol("Z", 2, 3)
  39. assert (Z*expr).shape == (2, 3)
  40. expr = ElementwiseApplyFunction(exp, Z.T)*ElementwiseApplyFunction(exp, Z)
  41. assert expr.shape == (3, 3)
  42. expr = ElementwiseApplyFunction(exp, Z)*ElementwiseApplyFunction(exp, Z.T)
  43. assert expr.shape == (2, 2)
  44. M = Matrix([[x, y], [z, t]])
  45. expr = ElementwiseApplyFunction(sin, M)
  46. assert isinstance(expr, ElementwiseApplyFunction)
  47. assert expr.function.dummy_eq(Lambda(x, sin(x)))
  48. assert expr.expr == M
  49. assert expr.doit() == M.applyfunc(sin)
  50. assert expr.doit() == Matrix([[sin(x), sin(y)], [sin(z), sin(t)]])
  51. assert expr.func(*expr.args) == expr
  52. expr = ElementwiseApplyFunction(double, Xk)
  53. assert expr.doit() == expr
  54. assert expr.subs(k, 2).shape == (2, 2)
  55. assert (expr*expr).shape == (k, k)
  56. M = MatrixSymbol("M", k, t)
  57. expr2 = M.T*expr*M
  58. assert isinstance(expr2, MatMul)
  59. assert expr2.args[1] == expr
  60. assert expr2.shape == (t, t)
  61. expr3 = expr*M
  62. assert expr3.shape == (k, t)
  63. expr1 = ElementwiseApplyFunction(lambda x: x+1, Xk)
  64. expr2 = ElementwiseApplyFunction(lambda x: x, Xk)
  65. assert expr1 != expr2
  66. def test_applyfunc_entry():
  67. af = X.applyfunc(sin)
  68. assert af[0, 0] == sin(X[0, 0])
  69. af = Xd.applyfunc(sin)
  70. assert af[0, 0] == sin(X[0, 0])
  71. def test_applyfunc_as_explicit():
  72. af = X.applyfunc(sin)
  73. assert af.as_explicit() == Matrix([
  74. [sin(X[0, 0]), sin(X[0, 1]), sin(X[0, 2])],
  75. [sin(X[1, 0]), sin(X[1, 1]), sin(X[1, 2])],
  76. [sin(X[2, 0]), sin(X[2, 1]), sin(X[2, 2])],
  77. ])
  78. def test_applyfunc_transpose():
  79. af = Xk.applyfunc(sin)
  80. assert af.T.dummy_eq(Xk.T.applyfunc(sin))
  81. def test_applyfunc_shape_11_matrices():
  82. M = MatrixSymbol("M", 1, 1)
  83. double = Lambda(x, x*2)
  84. expr = M.applyfunc(sin)
  85. assert isinstance(expr, ElementwiseApplyFunction)
  86. expr = M.applyfunc(double)
  87. assert isinstance(expr, MatMul)
  88. assert expr == 2*M