# test_noncommutative.py
  1. """Tests for noncommutative symbols and expressions."""
  2. from sympy.core.function import expand
  3. from sympy.core.numbers import I
  4. from sympy.core.symbol import symbols
  5. from sympy.functions.elementary.complexes import (adjoint, conjugate, transpose)
  6. from sympy.functions.elementary.trigonometric import (cos, sin)
  7. from sympy.polys.polytools import (cancel, factor)
  8. from sympy.simplify.combsimp import combsimp
  9. from sympy.simplify.gammasimp import gammasimp
  10. from sympy.simplify.radsimp import (collect, radsimp, rcollect)
  11. from sympy.simplify.ratsimp import ratsimp
  12. from sympy.simplify.simplify import (posify, simplify)
  13. from sympy.simplify.trigsimp import trigsimp
  14. from sympy.abc import x, y, z
  15. from sympy.testing.pytest import XFAIL
  16. A, B, C = symbols("A B C", commutative=False)
  17. X = symbols("X", commutative=False, hermitian=True)
  18. Y = symbols("Y", commutative=False, antihermitian=True)
  19. def test_adjoint():
  20. assert adjoint(A).is_commutative is False
  21. assert adjoint(A*A) == adjoint(A)**2
  22. assert adjoint(A*B) == adjoint(B)*adjoint(A)
  23. assert adjoint(A*B**2) == adjoint(B)**2*adjoint(A)
  24. assert adjoint(A*B - B*A) == adjoint(B)*adjoint(A) - adjoint(A)*adjoint(B)
  25. assert adjoint(A + I*B) == adjoint(A) - I*adjoint(B)
  26. assert adjoint(X) == X
  27. assert adjoint(-I*X) == I*X
  28. assert adjoint(Y) == -Y
  29. assert adjoint(-I*Y) == -I*Y
  30. assert adjoint(X) == conjugate(transpose(X))
  31. assert adjoint(Y) == conjugate(transpose(Y))
  32. assert adjoint(X) == transpose(conjugate(X))
  33. assert adjoint(Y) == transpose(conjugate(Y))
  34. def test_cancel():
  35. assert cancel(A*B - B*A) == A*B - B*A
  36. assert cancel(A*B*(x - 1)) == A*B*(x - 1)
  37. assert cancel(A*B*(x**2 - 1)/(x + 1)) == A*B*(x - 1)
  38. assert cancel(A*B*(x**2 - 1)/(x + 1) - B*A*(x - 1)) == A*B*(x - 1) + (1 - x)*B*A
  39. @XFAIL
  40. def test_collect():
  41. assert collect(A*B - B*A, A) == A*B - B*A
  42. assert collect(A*B - B*A, B) == A*B - B*A
  43. assert collect(A*B - B*A, x) == A*B - B*A
  44. def test_combsimp():
  45. assert combsimp(A*B - B*A) == A*B - B*A
  46. def test_gammasimp():
  47. assert gammasimp(A*B - B*A) == A*B - B*A
  48. def test_conjugate():
  49. assert conjugate(A).is_commutative is False
  50. assert (A*A).conjugate() == conjugate(A)**2
  51. assert (A*B).conjugate() == conjugate(A)*conjugate(B)
  52. assert (A*B**2).conjugate() == conjugate(A)*conjugate(B)**2
  53. assert (A*B - B*A).conjugate() == \
  54. conjugate(A)*conjugate(B) - conjugate(B)*conjugate(A)
  55. assert (A*B).conjugate() - (B*A).conjugate() == \
  56. conjugate(A)*conjugate(B) - conjugate(B)*conjugate(A)
  57. assert (A + I*B).conjugate() == conjugate(A) - I*conjugate(B)
  58. def test_expand():
  59. assert expand((A*B)**2) == A*B*A*B
  60. assert expand(A*B - B*A) == A*B - B*A
  61. assert expand((A*B/A)**2) == A*B*B/A
  62. assert expand(B*A*(A + B)*B) == B*A**2*B + B*A*B**2
  63. assert expand(B*A*(A + C)*B) == B*A**2*B + B*A*C*B
  64. def test_factor():
  65. assert factor(A*B - B*A) == A*B - B*A
  66. def test_posify():
  67. assert posify(A)[0].is_commutative is False
  68. for q in (A*B/A, (A*B/A)**2, (A*B)**2, A*B - B*A):
  69. p = posify(q)
  70. assert p[0].subs(p[1]) == q
  71. def test_radsimp():
  72. assert radsimp(A*B - B*A) == A*B - B*A
  73. @XFAIL
  74. def test_ratsimp():
  75. assert ratsimp(A*B - B*A) == A*B - B*A
  76. @XFAIL
  77. def test_rcollect():
  78. assert rcollect(A*B - B*A, A) == A*B - B*A
  79. assert rcollect(A*B - B*A, B) == A*B - B*A
  80. assert rcollect(A*B - B*A, x) == A*B - B*A
  81. def test_simplify():
  82. assert simplify(A*B - B*A) == A*B - B*A
  83. def test_subs():
  84. assert (x*y*A).subs(x*y, z) == A*z
  85. assert (x*A*B).subs(x*A, C) == C*B
  86. assert (x*A*x*x).subs(x**2*A, C) == x*C
  87. assert (x*A*x*B).subs(x**2*A, C) == C*B
  88. assert (A**2*B**2).subs(A*B**2, C) == A*C
  89. assert (A*A*A + A*B*A).subs(A*A*A, C) == C + A*B*A
  90. def test_transpose():
  91. assert transpose(A).is_commutative is False
  92. assert transpose(A*A) == transpose(A)**2
  93. assert transpose(A*B) == transpose(B)*transpose(A)
  94. assert transpose(A*B**2) == transpose(B)**2*transpose(A)
  95. assert transpose(A*B - B*A) == \
  96. transpose(B)*transpose(A) - transpose(A)*transpose(B)
  97. assert transpose(A + I*B) == transpose(A) + I*transpose(B)
  98. assert transpose(X) == conjugate(X)
  99. assert transpose(-I*X) == -I*conjugate(X)
  100. assert transpose(Y) == -conjugate(Y)
  101. assert transpose(-I*Y) == I*conjugate(Y)
  102. def test_trigsimp():
  103. assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A