test_trace.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from sympy.core.containers import Tuple
  2. from sympy.core.symbol import symbols
  3. from sympy.matrices.dense import Matrix
  4. from sympy.physics.quantum.trace import Tr
  5. from sympy.testing.pytest import raises, warns_deprecated_sympy
  6. def test_trace_new():
  7. a, b, c, d, Y = symbols('a b c d Y')
  8. A, B, C, D = symbols('A B C D', commutative=False)
  9. assert Tr(a + b) == a + b
  10. assert Tr(A + B) == Tr(A) + Tr(B)
  11. #check trace args not implicitly permuted
  12. assert Tr(C*D*A*B).args[0].args == (C, D, A, B)
  13. # check for mul and adds
  14. assert Tr((a*b) + ( c*d)) == (a*b) + (c*d)
  15. # Tr(scalar*A) = scalar*Tr(A)
  16. assert Tr(a*A) == a*Tr(A)
  17. assert Tr(a*A*B*b) == a*b*Tr(A*B)
  18. # since A is symbol and not commutative
  19. assert isinstance(Tr(A), Tr)
  20. #POW
  21. assert Tr(pow(a, b)) == a**b
  22. assert isinstance(Tr(pow(A, a)), Tr)
  23. #Matrix
  24. M = Matrix([[1, 1], [2, 2]])
  25. assert Tr(M) == 3
  26. ##test indices in different forms
  27. #no index
  28. t = Tr(A)
  29. assert t.args[1] == Tuple()
  30. #single index
  31. t = Tr(A, 0)
  32. assert t.args[1] == Tuple(0)
  33. #index in a list
  34. t = Tr(A, [0])
  35. assert t.args[1] == Tuple(0)
  36. t = Tr(A, [0, 1, 2])
  37. assert t.args[1] == Tuple(0, 1, 2)
  38. #index is tuple
  39. t = Tr(A, (0))
  40. assert t.args[1] == Tuple(0)
  41. t = Tr(A, (1, 2))
  42. assert t.args[1] == Tuple(1, 2)
  43. #trace indices test
  44. t = Tr((A + B), [2])
  45. assert t.args[0].args[1] == Tuple(2) and t.args[1].args[1] == Tuple(2)
  46. t = Tr(a*A, [2, 3])
  47. assert t.args[1].args[1] == Tuple(2, 3)
  48. #class with trace method defined
  49. #to simulate numpy objects
  50. class Foo:
  51. def trace(self):
  52. return 1
  53. assert Tr(Foo()) == 1
  54. #argument test
  55. # check for value error, when either/both arguments are not provided
  56. raises(ValueError, lambda: Tr())
  57. raises(ValueError, lambda: Tr(A, 1, 2))
  58. def test_trace_doit():
  59. a, b, c, d = symbols('a b c d')
  60. A, B, C, D = symbols('A B C D', commutative=False)
  61. #TODO: needed while testing reduced density operations, etc.
  62. def test_permute():
  63. A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False)
  64. t = Tr(A*B*C*D*E*F*G)
  65. assert t.permute(0).args[0].args == (A, B, C, D, E, F, G)
  66. assert t.permute(2).args[0].args == (F, G, A, B, C, D, E)
  67. assert t.permute(4).args[0].args == (D, E, F, G, A, B, C)
  68. assert t.permute(6).args[0].args == (B, C, D, E, F, G, A)
  69. assert t.permute(8).args[0].args == t.permute(1).args[0].args
  70. assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A)
  71. assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C)
  72. assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E)
  73. assert t.permute(-8).args[0].args == t.permute(-1).args[0].args
  74. t = Tr((A + B)*(B*B)*C*D)
  75. assert t.permute(2).args[0].args == (C, D, (A + B), (B**2))
  76. t1 = Tr(A*B)
  77. t2 = t1.permute(1)
  78. assert id(t1) != id(t2) and t1 == t2
  79. def test_deprecated_core_trace():
  80. with warns_deprecated_sympy():
  81. from sympy.core.trace import Tr # noqa:F401