test_operator.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. from sympy.core.function import (Derivative, Function, diff)
  2. from sympy.core.mul import Mul
  3. from sympy.core.numbers import (Integer, pi)
  4. from sympy.core.symbol import (Symbol, symbols)
  5. from sympy.functions.elementary.trigonometric import sin
  6. from sympy.physics.quantum.qexpr import QExpr
  7. from sympy.physics.quantum.dagger import Dagger
  8. from sympy.physics.quantum.hilbert import HilbertSpace
  9. from sympy.physics.quantum.operator import (Operator, UnitaryOperator,
  10. HermitianOperator, OuterProduct,
  11. DifferentialOperator,
  12. IdentityOperator)
  13. from sympy.physics.quantum.state import Ket, Bra, Wavefunction
  14. from sympy.physics.quantum.qapply import qapply
  15. from sympy.physics.quantum.represent import represent
  16. from sympy.physics.quantum.spin import JzKet, JzBra
  17. from sympy.physics.quantum.trace import Tr
  18. from sympy.matrices import eye
  class CustomKet(Ket):
      """Ket subclass that overrides ``default_args`` to supply ``"t"``."""

      @classmethod
      def default_args(cls):
          """Return the one-element tuple ``("t",)``.

          NOTE(review): presumably used as the label when the ket is built
          without arguments -- confirm against QExpr.
          """
          return ("t",)
  class CustomOp(HermitianOperator):
      """Hermitian operator subclass that overrides ``default_args``."""

      @classmethod
      def default_args(cls):
          """Return the one-element tuple ``("T",)``.

          ``test_operator`` checks that an instance built with no arguments
          carries the matching ``Symbol`` as its first label entry.
          """
          return ("T",)
  # Module-level instances created without explicit labels; test_operator
  # compares t_op's label against CustomOp.default_args().
  t_ket, t_op = CustomKet(), CustomOp()
  def test_operator():
      """Check basic attributes and non-commutative algebra of Operator."""
      op_a, op_b, op_c = (Operator(name) for name in ('A', 'B', 'C'))

      # Type hierarchy and basic attributes.
      for base in (Operator, QExpr):
          assert isinstance(op_a, base)
      assert op_a.label == (Symbol('A'),)
      assert op_a.is_commutative is False
      assert op_a.hilbert_space == HilbertSpace()

      # Products do not commute, but they distribute over sums on expand().
      assert op_a*op_b != op_b*op_a
      distributed = (op_a*(op_b + op_c)).expand()
      assert distributed == op_a*op_b + op_a*op_c
      squared = ((op_a + op_b)**2).expand()
      assert squared == op_a**2 + op_a*op_b + op_b*op_a + op_b**2

      # Labels fall back to default_args() when none are given.
      default_name = t_op.default_args()[0]
      assert t_op.label[0] == Symbol(default_name)
      assert Operator() == Operator("O")

      # Right-multiplying by the identity leaves the operator unchanged.
      assert op_a*IdentityOperator() == op_a
  def test_operator_inv():
      """An operator times its inverse, on either side, compares equal to 1."""
      op = Operator('A')
      inverse = op.inv()
      assert op*inverse == 1
      assert inverse*op == 1
  def test_hermitian():
      """Check type, self-adjointness and non-commutativity of HermitianOperator."""
      herm = HermitianOperator('H')

      for base in (HermitianOperator, Operator):
          assert isinstance(herm, base)

      adjoint = Dagger(herm)
      assert adjoint == herm
      # The inverse must not collapse to the operator itself.
      assert herm.inv() != herm

      for expr in (herm, adjoint):
          assert expr.is_commutative is False
  def test_unitary():
      """Check that a UnitaryOperator's inverse is its adjoint."""
      unitary = UnitaryOperator('U')

      for base in (UnitaryOperator, Operator):
          assert isinstance(unitary, base)

      adjoint = Dagger(unitary)
      assert unitary.inv() == adjoint
      # U * U^dagger and U^dagger * U both reduce to 1.
      assert unitary*adjoint == 1
      assert adjoint*unitary == 1

      for expr in (unitary, adjoint):
          assert expr.is_commutative is False
  def test_identity():
      """Check IdentityOperator under products, inverse, adjoint and represent."""
      identity = IdentityOperator()
      op = Operator('O')
      sym = Symbol("x")

      for base in (IdentityOperator, Operator):
          assert isinstance(identity, base)

      # The identity is neutral on both sides, for an operator and its adjoint.
      for operand in (op, Dagger(op)):
          assert identity * operand == operand
          assert operand * identity == operand

      assert isinstance(identity * identity, IdentityOperator)
      # Scaling by a number or a symbol yields a Mul.
      for scaled in (3 * identity, identity * sym):
          assert isinstance(scaled, Mul)

      assert identity.inv() == identity
      assert Dagger(identity) == identity

      for product in (identity * op, op * identity):
          assert qapply(product) == op

      # An n-dimensional identity is represented by the n x n identity matrix.
      for dim in (2, 3, 5):
          assert represent(IdentityOperator(dim)) == eye(dim)
  def test_outer_product():
      """Check OuterProduct construction, scaling, adjoint, trace and sums."""
      ket = Ket('k')
      bra = Bra('b')

      # The explicit constructor and the ket*bra product are checked against
      # the same set of expectations.
      for outer in (OuterProduct(ket, bra), ket*bra):
          assert isinstance(outer, OuterProduct)
          assert isinstance(outer, Operator)
          assert outer.ket == ket
          assert outer.bra == bra
          assert outer.label == (ket, bra)
          assert outer.is_commutative is False

      # A leading scalar gives a Mul; grouping decides whether the ket and
      # bra are fused into an OuterProduct.
      two = Integer(2)
      assert 2*ket*bra == Mul(two, ket, bra)
      assert 2*(ket*bra) == Mul(two, OuterProduct(ket, bra))

      adjoint = Dagger(ket*bra)
      assert adjoint == OuterProduct(Dagger(bra), Dagger(ket))
      assert adjoint.is_commutative is False

      # Trace evaluation (_eval_trace).
      spin_outer = OuterProduct(JzKet(1, 1), JzBra(1, 1))
      assert Tr(spin_outer).doit() == 1

      # Scalar factors on the ket or the bra are pulled out in front.
      plain = OuterProduct(ket, bra)
      assert OuterProduct(2 * ket, bra) == 2 * plain
      assert OuterProduct(ket, 2 * bra) == 2 * plain

      # Sums of kets and bras distribute into sums of outer products.
      k1, k2 = Ket('k1'), Ket('k2')
      b1, b2 = Bra('b1'), Bra('b2')
      o11, o12 = OuterProduct(k1, b1), OuterProduct(k1, b2)
      o21, o22 = OuterProduct(k2, b1), OuterProduct(k2, b2)
      assert OuterProduct(k1 + k2, b1) == o11 + o21
      assert OuterProduct(k1, b1 + b2) == o11 + o12
      mixed = OuterProduct(1 * k1 + 2 * k2, 3 * b1 + 4 * b2)
      assert mixed == 3 * o11 + 4 * o12 + 6 * o21 + 8 * o22
  def test_operator_dagger():
      """Check how Dagger acts on products, sums and powers of operators."""
      left = Operator('A')
      right = Operator('B')
      left_adj, right_adj = Dagger(left), Dagger(right)

      # The adjoint of a product reverses the factor order.
      assert Dagger(left*right) == right_adj*left_adj
      assert Dagger(left + right) == left_adj + right_adj
      assert Dagger(left**2) == left_adj**2
  def test_differential_operator():
      """Check DifferentialOperator attributes, diff() and qapply() results."""
      x = Symbol('x')
      f = Function('f')
      fx = f(x)

      # First derivative d/dx applied to x**2.
      first = Derivative(fx, x)
      op = DifferentialOperator(first, fx)
      wf = Wavefunction(x**2, x)
      assert qapply(op*wf) == Wavefunction(2*x, x)
      assert op.expr == first
      assert op.function == fx
      assert op.variables == (x,)
      assert diff(op, x) == DifferentialOperator(Derivative(fx, x, 2), fx)

      # Second derivative d^2/dx^2 applied to x**3.
      second = Derivative(fx, x, 2)
      op = DifferentialOperator(second, fx)
      wf = Wavefunction(x**3, x)
      assert qapply(op*wf) == Wavefunction(6*x, x)
      assert op.expr == second
      assert op.function == fx
      assert op.variables == (x,)
      assert diff(op, x) == DifferentialOperator(Derivative(fx, x, 3), fx)

      # (1/x) d/dx, applied to the same x**3 wavefunction as above.
      weighted = 1/x*first
      op = DifferentialOperator(weighted, fx)
      assert op.expr == weighted
      assert op.function == fx
      assert op.variables == (x,)
      assert diff(op, x) == DifferentialOperator(Derivative(weighted, x), fx)
      assert qapply(op*wf) == Wavefunction(3*x, x)

      # 2D cartesian Laplacian
      y = Symbol('y')
      fxy = f(x, y)
      cartesian = Derivative(fxy, x, 2) + Derivative(fxy, y, 2)
      op = DifferentialOperator(cartesian, fxy)
      wf_xy = Wavefunction(x**3*y**2 + y**3*x**2, x, y)
      assert op.expr == cartesian
      assert op.function == fxy
      assert op.variables == (x, y)
      for var in (x, y):
          expected = DifferentialOperator(Derivative(op.expr, var), fxy)
          assert diff(op, var) == expected
      applied = 2*x**3 + 6*x*y**2 + 6*x**2*y + 2*y**3
      assert qapply(op*wf_xy) == Wavefunction(applied, x, y)

      # 2D polar Laplacian (th = theta)
      r, th = symbols('r th')
      frt = f(r, th)
      radial = 1/r*Derivative(r*Derivative(frt, r), r)
      angular = 1/(r**2)*Derivative(frt, th, 2)
      polar = radial + angular
      op = DifferentialOperator(polar, frt)
      wf_polar = Wavefunction(r**2*sin(th), r, (th, 0, pi))
      assert op.expr == polar
      assert op.function == frt
      assert op.variables == (r, th)
      for var in (r, th):
          expected = DifferentialOperator(Derivative(op.expr, var), frt)
          assert diff(op, var) == expected
      assert qapply(op*wf_polar) == Wavefunction(3*sin(th), r, (th, 0, pi))
  def test_eval_power():
      """Check how powers of quantum operators evaluate or stay unevaluated."""
      from sympy.core import Pow
      from sympy.core.expr import unchanged
      from sympy.physics.quantum.gate import XGate

      generic = Operator('O')
      unitary = UnitaryOperator('U')
      hermitian = HermitianOperator('H')

      # The power -1 agrees with inv() (mirrors the doctest).
      for op in (generic, unitary, hermitian):
          assert op**-1 == op.inv()

      # A Hermitian operator raised to a symbol stays as Pow(H, x).
      x = symbols("x", commutative=True)
      assert unchanged(Pow, hermitian, x)
      assert hermitian**x == Pow(hermitian, x)
      assert Pow(hermitian, x) == Pow(hermitian, x, evaluate=False)

      # XGate is both Hermitian and unitary; a symbolic power stays Pow(X, x).
      gate = XGate(0)
      assert unchanged(Pow, gate, x)
      assert gate**x == Pow(gate, x)
      assert Pow(gate, x, evaluate=False) == Pow(gate, x)

      # Known parity of the exponent collapses the power.
      even = symbols("n", integer=True, even=True)
      assert gate**even == 1
      odd = symbols("n", integer=True, odd=True)
      assert gate**odd == gate

      # An integer of unknown parity leaves Pow(X, n) untouched.
      n = symbols("n", integer=True)
      assert unchanged(Pow, gate, n)
      assert gate**n == Pow(gate, n)
      assert Pow(gate, n, evaluate=False) == Pow(gate, n)

      # Concrete integer exponents reduce by parity as well.
      assert gate**4 == 1
      assert gate**7 == gate