# test_state.py (6.6 KB)
  1. from sympy.core.add import Add
  2. from sympy.core.function import diff
  3. from sympy.core.mul import Mul
  4. from sympy.core.numbers import (I, Integer, Rational, oo, pi)
  5. from sympy.core.power import Pow
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import (Symbol, symbols)
  8. from sympy.core.sympify import sympify
  9. from sympy.functions.elementary.complexes import conjugate
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.trigonometric import sin
  12. from sympy.testing.pytest import raises
  13. from sympy.physics.quantum.dagger import Dagger
  14. from sympy.physics.quantum.qexpr import QExpr
  15. from sympy.physics.quantum.state import (
  16. Ket, Bra, TimeDepKet, TimeDepBra,
  17. KetBase, BraBase, StateBase, Wavefunction,
  18. OrthogonalKet, OrthogonalBra
  19. )
  20. from sympy.physics.quantum.hilbert import HilbertSpace
  21. x, y, t = symbols('x,y,t')
  22. class CustomKet(Ket):
  23. @classmethod
  24. def default_args(self):
  25. return ("test",)
  26. class CustomKetMultipleLabels(Ket):
  27. @classmethod
  28. def default_args(self):
  29. return ("r", "theta", "phi")
  30. class CustomTimeDepKet(TimeDepKet):
  31. @classmethod
  32. def default_args(self):
  33. return ("test", "t")
  34. class CustomTimeDepKetMultipleLabels(TimeDepKet):
  35. @classmethod
  36. def default_args(self):
  37. return ("r", "theta", "phi", "t")
  38. def test_ket():
  39. k = Ket('0')
  40. assert isinstance(k, Ket)
  41. assert isinstance(k, KetBase)
  42. assert isinstance(k, StateBase)
  43. assert isinstance(k, QExpr)
  44. assert k.label == (Symbol('0'),)
  45. assert k.hilbert_space == HilbertSpace()
  46. assert k.is_commutative is False
  47. # Make sure this doesn't get converted to the number pi.
  48. k = Ket('pi')
  49. assert k.label == (Symbol('pi'),)
  50. k = Ket(x, y)
  51. assert k.label == (x, y)
  52. assert k.hilbert_space == HilbertSpace()
  53. assert k.is_commutative is False
  54. assert k.dual_class() == Bra
  55. assert k.dual == Bra(x, y)
  56. assert k.subs(x, y) == Ket(y, y)
  57. k = CustomKet()
  58. assert k == CustomKet("test")
  59. k = CustomKetMultipleLabels()
  60. assert k == CustomKetMultipleLabels("r", "theta", "phi")
  61. assert Ket() == Ket('psi')
  62. def test_bra():
  63. b = Bra('0')
  64. assert isinstance(b, Bra)
  65. assert isinstance(b, BraBase)
  66. assert isinstance(b, StateBase)
  67. assert isinstance(b, QExpr)
  68. assert b.label == (Symbol('0'),)
  69. assert b.hilbert_space == HilbertSpace()
  70. assert b.is_commutative is False
  71. # Make sure this doesn't get converted to the number pi.
  72. b = Bra('pi')
  73. assert b.label == (Symbol('pi'),)
  74. b = Bra(x, y)
  75. assert b.label == (x, y)
  76. assert b.hilbert_space == HilbertSpace()
  77. assert b.is_commutative is False
  78. assert b.dual_class() == Ket
  79. assert b.dual == Ket(x, y)
  80. assert b.subs(x, y) == Bra(y, y)
  81. assert Bra() == Bra('psi')
  82. def test_ops():
  83. k0 = Ket(0)
  84. k1 = Ket(1)
  85. k = 2*I*k0 - (x/sqrt(2))*k1
  86. assert k == Add(Mul(2, I, k0),
  87. Mul(Rational(-1, 2), x, Pow(2, S.Half), k1))
  88. def test_time_dep_ket():
  89. k = TimeDepKet(0, t)
  90. assert isinstance(k, TimeDepKet)
  91. assert isinstance(k, KetBase)
  92. assert isinstance(k, StateBase)
  93. assert isinstance(k, QExpr)
  94. assert k.label == (Integer(0),)
  95. assert k.args == (Integer(0), t)
  96. assert k.time == t
  97. assert k.dual_class() == TimeDepBra
  98. assert k.dual == TimeDepBra(0, t)
  99. assert k.subs(t, 2) == TimeDepKet(0, 2)
  100. k = TimeDepKet(x, 0.5)
  101. assert k.label == (x,)
  102. assert k.args == (x, sympify(0.5))
  103. k = CustomTimeDepKet()
  104. assert k.label == (Symbol("test"),)
  105. assert k.time == Symbol("t")
  106. assert k == CustomTimeDepKet("test", "t")
  107. k = CustomTimeDepKetMultipleLabels()
  108. assert k.label == (Symbol("r"), Symbol("theta"), Symbol("phi"))
  109. assert k.time == Symbol("t")
  110. assert k == CustomTimeDepKetMultipleLabels("r", "theta", "phi", "t")
  111. assert TimeDepKet() == TimeDepKet("psi", "t")
  112. def test_time_dep_bra():
  113. b = TimeDepBra(0, t)
  114. assert isinstance(b, TimeDepBra)
  115. assert isinstance(b, BraBase)
  116. assert isinstance(b, StateBase)
  117. assert isinstance(b, QExpr)
  118. assert b.label == (Integer(0),)
  119. assert b.args == (Integer(0), t)
  120. assert b.time == t
  121. assert b.dual_class() == TimeDepKet
  122. assert b.dual == TimeDepKet(0, t)
  123. k = TimeDepBra(x, 0.5)
  124. assert k.label == (x,)
  125. assert k.args == (x, sympify(0.5))
  126. assert TimeDepBra() == TimeDepBra("psi", "t")
  127. def test_bra_ket_dagger():
  128. x = symbols('x', complex=True)
  129. k = Ket('k')
  130. b = Bra('b')
  131. assert Dagger(k) == Bra('k')
  132. assert Dagger(b) == Ket('b')
  133. assert Dagger(k).is_commutative is False
  134. k2 = Ket('k2')
  135. e = 2*I*k + x*k2
  136. assert Dagger(e) == conjugate(x)*Dagger(k2) - 2*I*Dagger(k)
  137. def test_wavefunction():
  138. x, y = symbols('x y', real=True)
  139. L = symbols('L', positive=True)
  140. n = symbols('n', integer=True, positive=True)
  141. f = Wavefunction(x**2, x)
  142. p = f.prob()
  143. lims = f.limits
  144. assert f.is_normalized is False
  145. assert f.norm is oo
  146. assert f(10) == 100
  147. assert p(10) == 10000
  148. assert lims[x] == (-oo, oo)
  149. assert diff(f, x) == Wavefunction(2*x, x)
  150. raises(NotImplementedError, lambda: f.normalize())
  151. assert conjugate(f) == Wavefunction(conjugate(f.expr), x)
  152. assert conjugate(f) == Dagger(f)
  153. g = Wavefunction(x**2*y + y**2*x, (x, 0, 1), (y, 0, 2))
  154. lims_g = g.limits
  155. assert lims_g[x] == (0, 1)
  156. assert lims_g[y] == (0, 2)
  157. assert g.is_normalized is False
  158. assert g.norm == sqrt(42)/3
  159. assert g(2, 4) == 0
  160. assert g(1, 1) == 2
  161. assert diff(diff(g, x), y) == Wavefunction(2*x + 2*y, (x, 0, 1), (y, 0, 2))
  162. assert conjugate(g) == Wavefunction(conjugate(g.expr), *g.args[1:])
  163. assert conjugate(g) == Dagger(g)
  164. h = Wavefunction(sqrt(5)*x**2, (x, 0, 1))
  165. assert h.is_normalized is True
  166. assert h.normalize() == h
  167. assert conjugate(h) == Wavefunction(conjugate(h.expr), (x, 0, 1))
  168. assert conjugate(h) == Dagger(h)
  169. piab = Wavefunction(sin(n*pi*x/L), (x, 0, L))
  170. assert piab.norm == sqrt(L/2)
  171. assert piab(L + 1) == 0
  172. assert piab(0.5) == sin(0.5*n*pi/L)
  173. assert piab(0.5, n=1, L=1) == sin(0.5*pi)
  174. assert piab.normalize() == \
  175. Wavefunction(sqrt(2)/sqrt(L)*sin(n*pi*x/L), (x, 0, L))
  176. assert conjugate(piab) == Wavefunction(conjugate(piab.expr), (x, 0, L))
  177. assert conjugate(piab) == Dagger(piab)
  178. k = Wavefunction(x**2, 'x')
  179. assert type(k.variables[0]) == Symbol
  180. def test_orthogonal_states():
  181. braket = OrthogonalBra(x) * OrthogonalKet(x)
  182. assert braket.doit() == 1
  183. braket = OrthogonalBra(x) * OrthogonalKet(x+1)
  184. assert braket.doit() == 0
  185. braket = OrthogonalBra(x) * OrthogonalKet(y)
  186. assert braket.doit() == braket