test_symbol.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import threading
  2. from sympy.core.function import Function, UndefinedFunction
  3. from sympy.core.numbers import (I, Rational, pi)
  4. from sympy.core.relational import (GreaterThan, LessThan, StrictGreaterThan, StrictLessThan)
  5. from sympy.core.symbol import (Dummy, Symbol, Wild, symbols)
  6. from sympy.core.sympify import sympify # can't import as S yet
  7. from sympy.core.symbol import uniquely_named_symbol, _symbol, Str
  8. from sympy.testing.pytest import raises, skip_under_pyodide
  9. from sympy.core.symbol import disambiguate
  def test_Str():
      """Str objects compare equal exactly when they wrap the same text."""
      first, same_text, other_text = Str('a'), Str('a'), Str('b')
      assert first == same_text
      assert same_text != other_text
      # constructing a Str without its text argument is a TypeError
      raises(TypeError, lambda: Str())
  def test_Symbol():
      """Symbols are equal by name; Dummy instances are always distinct."""
      a = Symbol("a")
      x_one, x_two = Symbol("x"), Symbol("x")
      dummy_one, dummy_two = Dummy("x"), Dummy("x")

      # different names differ, identical names coincide
      for x_sym in (x_one, x_two):
          assert a != x_sym
      assert x_one == x_two

      # a Dummy equals neither a Symbol nor another Dummy of the same name
      assert x_one != dummy_one
      assert dummy_one != dummy_two
      assert Symbol("x") == Symbol("x")
      assert Dummy("x") != Dummy("x")

      # symbols(..., cls=Dummy) produces Dummy instances
      lone = symbols('d', cls=Dummy)
      assert isinstance(lone, Dummy)
      first, second = symbols('c,d', cls=Dummy)
      assert isinstance(first, Dummy)
      assert isinstance(second, Dummy)

      # a Symbol cannot be created without a name
      raises(TypeError, lambda: Symbol())
  def test_Dummy():
      """Two nameless Dummy symbols never compare equal."""
      first = Dummy()
      second = Dummy()
      assert first != second
  def test_Dummy_force_dummy_index():
      """Dummy(name, dummy_index=...) makes the dummy's identity reproducible."""
      # a dummy_index given without a name is rejected
      raises(AssertionError, lambda: Dummy(dummy_index=1))

      # same name and same forced index compare equal; other names do not
      assert Dummy('d', dummy_index=2) == Dummy('d', dummy_index=2)
      assert Dummy('d1', dummy_index=2) != Dummy('d2', dummy_index=2)

      pinned = Dummy('d', dummy_index=3)
      automatic = Dummy('d')
      # might fail if pinned were created with dummy_index >= 10**6
      assert pinned != automatic
      assert pinned == Dummy('d', dummy_index=3)

      # a Dummy with a forced index must report the same _count as the
      # ordinary Dummy created immediately before it
      assert Dummy()._count == Dummy('d', dummy_index=3)._count
  def test_lt_gt():
      """Comparison operators build the expected Relational objects.

      Each pair below is (expression built with a Python operator,
      explicitly constructed Relational); both must compare equal.  A
      plain Python number on the left is reflected, so ``0 <= x`` becomes
      ``GreaterThan(x, 0)``, whereas a sympified zero keeps its position.
      """
      S = sympify
      x, y = Symbol('x'), Symbol('y')
      zero = S(0)
      e = x**2 + 4*x + 1
      expected_pairs = (
          # symbol against symbol or number
          (x >= y, GreaterThan(x, y)),
          (x >= 0, GreaterThan(x, 0)),
          (x <= y, LessThan(x, y)),
          (x <= 0, LessThan(x, 0)),
          (0 <= x, GreaterThan(x, 0)),
          (0 >= x, LessThan(x, 0)),
          (zero >= x, GreaterThan(0, x)),
          (zero <= x, LessThan(0, x)),
          (x > y, StrictGreaterThan(x, y)),
          (x > 0, StrictGreaterThan(x, 0)),
          (x < y, StrictLessThan(x, y)),
          (x < 0, StrictLessThan(x, 0)),
          (0 < x, StrictGreaterThan(x, 0)),
          (0 > x, StrictLessThan(x, 0)),
          (zero > x, StrictGreaterThan(0, x)),
          (zero < x, StrictLessThan(0, x)),
          # compound expression against number
          (e >= 0, GreaterThan(e, 0)),
          (0 <= e, GreaterThan(e, 0)),
          (e > 0, StrictGreaterThan(e, 0)),
          (0 < e, StrictGreaterThan(e, 0)),
          (e <= 0, LessThan(e, 0)),
          (0 >= e, LessThan(e, 0)),
          (e < 0, StrictLessThan(e, 0)),
          (0 > e, StrictLessThan(e, 0)),
          (zero >= e, GreaterThan(0, e)),
          (zero <= e, LessThan(0, e)),
          (zero < e, StrictLessThan(0, e)),
          (zero > e, StrictGreaterThan(0, e)),
      )
      for built, expected in expected_pairs:
          assert built == expected
  def test_no_len():
      """len() is not defined for a Symbol and must raise TypeError."""
      sym = Symbol('x')
      with raises(TypeError):
          len(sym)
  def test_ineq_unequal():
      """Structurally different inequalities never compare equal.

      ``e`` contains the following entries:

      * inequalities between the symbols x, y, z and several constants, with
        the constant on either side (sympified on the left, plain Python
        numbers on the right);
      * inequalities between the symbols themselves;
      * inequalities involving ``pi``;
      * the booleans True and False.

      Every unordered pair of distinct entries must compare unequal.
      """
      from itertools import combinations

      S = sympify
      x, y, z = symbols('x,y,z')
      e = (
          S(-1) >= x, S(-1) >= y, S(-1) >= z,
          S(-1) > x, S(-1) > y, S(-1) > z,
          S(-1) <= x, S(-1) <= y, S(-1) <= z,
          S(-1) < x, S(-1) < y, S(-1) < z,
          S(0) >= x, S(0) >= y, S(0) >= z,
          S(0) > x, S(0) > y, S(0) > z,
          S(0) <= x, S(0) <= y, S(0) <= z,
          S(0) < x, S(0) < y, S(0) < z,
          S('3/7') >= x, S('3/7') >= y, S('3/7') >= z,
          S('3/7') > x, S('3/7') > y, S('3/7') > z,
          S('3/7') <= x, S('3/7') <= y, S('3/7') <= z,
          S('3/7') < x, S('3/7') < y, S('3/7') < z,
          S(1.5) >= x, S(1.5) >= y, S(1.5) >= z,
          S(1.5) > x, S(1.5) > y, S(1.5) > z,
          S(1.5) <= x, S(1.5) <= y, S(1.5) <= z,
          S(1.5) < x, S(1.5) < y, S(1.5) < z,
          S(2) >= x, S(2) >= y, S(2) >= z,
          S(2) > x, S(2) > y, S(2) > z,
          S(2) <= x, S(2) <= y, S(2) <= z,
          S(2) < x, S(2) < y, S(2) < z,
          x >= -1, y >= -1, z >= -1,
          x > -1, y > -1, z > -1,
          x <= -1, y <= -1, z <= -1,
          x < -1, y < -1, z < -1,
          x >= 0, y >= 0, z >= 0,
          x > 0, y > 0, z > 0,
          x <= 0, y <= 0, z <= 0,
          x < 0, y < 0, z < 0,
          x >= 1.5, y >= 1.5, z >= 1.5,
          x > 1.5, y > 1.5, z > 1.5,
          x <= 1.5, y <= 1.5, z <= 1.5,
          x < 1.5, y < 1.5, z < 1.5,
          x >= 2, y >= 2, z >= 2,
          x > 2, y > 2, z > 2,
          x <= 2, y <= 2, z <= 2,
          x < 2, y < 2, z < 2,
          x >= y, x >= z, y >= x, y >= z, z >= x, z >= y,
          x > y, x > z, y > x, y > z, z > x, z > y,
          x <= y, x <= z, y <= x, y <= z, z <= x, z <= y,
          x < y, x < z, y < x, y < z, z < x, z < y,
          x - pi >= y + z, y - pi >= x + z, z - pi >= x + y,
          x - pi > y + z, y - pi > x + z, z - pi > x + y,
          x - pi <= y + z, y - pi <= x + z, z - pi <= x + y,
          x - pi < y + z, y - pi < x + z, z - pi < x + y,
          True, False
      )
      # combinations() yields each pair (e[i], e[j]) with i < j exactly once,
      # in the same order as a manual "for i ...: for e2 in e[i + 1:]" loop,
      # so no entry is ever compared with itself.
      for e1, e2 in combinations(e, 2):
          assert e1 != e2
  def test_Wild_properties():
      """Wild ``properties``/``exclude`` decide which expressions a Wild matches.

      Every pattern in ``given_patterns`` is matched against every wildcard.
      A pattern listed in ``goodmatch`` for that wildcard must match and bind
      the wildcard to the pattern itself; any other pattern must not match.
      """
      # the patterns are atoms plus the negations -k and -n (which are Muls)
      x = Symbol("x")
      y = Symbol("y")
      p = Symbol("p", positive=True)
      k = Symbol("k", integer=True)
      n = Symbol("n", integer=True, positive=True)
      # sympify is called directly: the name ``S`` is used for a Wild below
      given_patterns = [x, y, p, k, -k, n, -n, sympify(-3), sympify(3),
                        pi, Rational(3, 2), I]

      integerp = lambda expr: expr.is_integer
      positivep = lambda expr: expr.is_positive
      symbolp = lambda expr: expr.is_Symbol
      realp = lambda expr: expr.is_extended_real

      S = Wild("S", properties=[symbolp])
      R = Wild("R", properties=[realp])
      Y = Wild("Y", exclude=[x, p, k, n])
      P = Wild("P", properties=[positivep])
      K = Wild("K", properties=[integerp])
      N = Wild("N", properties=[positivep, integerp])
      given_wildcards = [S, R, Y, P, K, N]

      goodmatch = {
          S: (x, y, p, k, n),
          R: (p, k, -k, n, -n, -3, 3, pi, Rational(3, 2)),
          Y: (y, -3, 3, pi, Rational(3, 2), I),
          P: (p, n, 3, pi, Rational(3, 2)),
          K: (k, -k, n, -n, -3, 3),
          N: (n, 3)}

      for A in given_wildcards:
          for pat in given_patterns:
              d = pat.match(A)
              if pat in goodmatch[A]:
                  # a bare Wild must bind to exactly the expression it
                  # matched, not merely to some other acceptable value
                  assert d is not None and d[A] == pat
              else:
                  assert d is None
  def test_symbols():
      """symbols() parses separators, containers and range syntax.

      Most checks are table driven: each table pairs an argument for
      symbols() with the value (or printed form) it must produce.
      """
      x, y, z = Symbol('x'), Symbol('y'), Symbol('z')
      xyz, abc = Symbol('xyz'), Symbol('abc')
      x0, x1, x2 = Symbol('x0'), Symbol('x1'), Symbol('x2')
      y0, y1 = Symbol('y0'), Symbol('y1')
      a, b, c, d = Symbol('a'), Symbol('b'), Symbol('c'), Symbol('d')
      aa, ab, ac, ad = Symbol('aa'), Symbol('ab'), Symbol('ac'), Symbol('ad')

      # single strings: surrounding blanks are ignored, a trailing comma
      # forces a tuple, and commas or spaces separate names
      string_cases = [
          ('x', x),
          ('x ', x),
          (' x ', x),
          ('x,', (x,)),
          ('x, ', (x,)),
          ('x ,', (x,)),
          ('x , y', (x, y)),
          ('x,y,z', (x, y, z)),
          ('x y z', (x, y, z)),
          ('x,y,z,', (x, y, z)),
          ('x y z ', (x, y, z)),
          ('xyz', xyz),
          ('xyz,', (xyz,)),
          ('xyz,abc', (xyz, abc)),
      ]
      # non-string containers: each item is parsed on its own and the
      # container type of the result follows the argument
      container_cases = [
          (('xyz',), (xyz,)),
          (('xyz,',), ((xyz,),)),
          (('x,y,z,',), ((x, y, z),)),
          (('xyz', 'abc'), (xyz, abc)),
          (('xyz,abc',), ((xyz, abc),)),
          (('xyz,abc', 'x,y,z'), ((xyz, abc), (x, y, z))),
          (('x', 'y', 'z'), (x, y, z)),
          (['x', 'y', 'z'], [x, y, z]),
          ({'x', 'y', 'z'}, {x, y, z}),
      ]
      # 'start:end' ranges over digits or letters
      range_cases = [
          ('x0:0', ()),
          ('x0:1', (x0,)),
          ('x0:2', (x0, x1)),
          ('x0:3', (x0, x1, x2)),
          ('x:0', ()),
          ('x:1', (x0,)),
          ('x:2', (x0, x1)),
          ('x:3', (x0, x1, x2)),
          ('x1:1', ()),
          ('x1:2', (x1,)),
          ('x1:3', (x1, x2)),
          ('x1:3,x,y,z', (x1, x2, x, y, z)),
          ('x:3,y:2', (x0, x1, x2, y0, y1)),
          (('x:3', 'y:2'), ((x0, x1, x2), (y0, y1))),
          ('x:z', (x, y, z)),
          ('a:d,x:z', (a, b, c, d, x, y, z)),
          (('a:d', 'x:z'), ((a, b, c, d), (x, y, z))),
          ('aa:d', (aa, ab, ac, ad)),
          ('aa:d,x:z', (aa, ab, ac, ad, x, y, z)),
          (('aa:d', 'x:z'), ((aa, ab, ac, ad), (x, y, z))),
      ]
      for spec, expected in string_cases + container_cases + range_cases:
          assert symbols(spec) == expected

      # empty names are rejected
      for bad in ('', ',', 'x,,y,,z', ('x', '', 'y', '', 'z')):
          with raises(ValueError):
              symbols(bad)

      # keyword assumptions reach every created symbol
      real_x, real_y = symbols('x,y', real=True)
      assert real_x.is_real and real_y.is_real

      assert type(symbols(('q:2', 'u:2'), cls=Function)[0][0]) == UndefinedFunction # issue 23532

      # issue 6675: compare the printed result of more involved patterns
      printed_cases = [
          ('a0:4', '(a0, a1, a2, a3)'),
          ('a2:4,b1:3', '(a2, a3, b1, b2)'),
          ('a1(2:4)', '(a12, a13)'),
          ('a0:2.0:2', '(a0.0, a0.1, a1.0, a1.1)'),
          ('aa:cz', '(aaz, abz, acz)'),
          ('aa:c0:2', '(aa0, aa1, ab0, ab1, ac0, ac1)'),
          ('aa:ba:b', '(aaa, aab, aba, abb)'),
          ('a:3b', '(a0b, a1b, a2b)'),
          ('a-1:3b', '(a-1b, a-2b)'),
          (r'a:2\,:2' + chr(0),
           '(a0,0%s, a0,1%s, a1,0%s, a1,1%s)' % ((chr(0),)*4)),
          ('x(:a:3)', '(x(a0), x(a1), x(a2))'),
          ('x(:c):1', '(xa0, xb0, xc0)'),
          ('x((:a)):3', '(x(a)0, x(a)1, x(a)2)'),
          ('x(:a:3', '(x(a0, x(a1, x(a2)'),
          (':2', '(0, 1)'),
          (':b', '(a, b)'),
          (':b:2', '(a0, a1, b0, b1)'),
          (':2:2', '(00, 01, 10, 11)'),
          (':b:b', '(aa, ab, ba, bb)'),
      ]
      for spec, text in printed_cases:
          assert str(symbols(spec)) == text

      # malformed ranges are rejected
      for bad in (':', 'a:', '::', 'a::', ':a:', '::a'):
          with raises(ValueError):
              symbols(bad)
  def test_symbols_become_functions_issue_3539():
      """Calling the sympy.abc names alpha/phi/beta raises TypeError (issue 3539)."""
      from sympy.abc import alpha, phi, beta, t
      calls = ((beta, 2), (beta, 2.5), (phi, 2.5), (alpha, 2.5), (phi, t))
      for sym, arg in calls:
          with raises(TypeError):
              sym(arg)
  def test_unicode():
      """Symbol names may contain non-ASCII characters.

      Previously both symbols here were built from the same ASCII string
      ``'x'``, so nothing Unicode-specific was exercised; a Greek letter is
      used now.
      """
      xi = Symbol('\N{GREEK SMALL LETTER XI}')
      assert xi.name == '\u03be'
      assert xi == Symbol('\u03be')
      # str() reproduces the Unicode name unchanged
      assert str(xi) == '\u03be'
      # a Unicode name is distinct from its ASCII spelling
      assert xi != Symbol('xi')

      # ASCII names keep working as before
      x = Symbol('x')
      assert x == Symbol('x')

      # the name must be a string
      raises(TypeError, lambda: Symbol(1))
  def test_uniquely_named_symbol_and_Symbol():
      """uniquely_named_symbol picks a non-clashing name and handles assumptions."""
      unique = uniquely_named_symbol
      x = Symbol('x')
      x_real = Symbol('x', real=True)
      y = Symbol('y')

      # nothing to clash with: the requested name is used unchanged
      assert unique(x) == x
      assert unique('x') == x

      # 'x' already occurs in the given expressions, so 'x0' is produced
      assert str(unique('x', x)) == 'x0'
      assert str(unique('x', (x + 1, 1/x))) == 'x0'

      # a (name, symbol) pair with a matching name yields that symbol
      assert unique(('x', x_real)) == x_real
      assert unique((x, x_real)) == x_real

      # keyword assumptions apply when a new symbol is created ...
      assert unique('x', real=True).is_real
      assert unique(('x', y), real=True).is_real
      # ... while a matching paired symbol keeps its own assumptions
      assert unique(('x', x_real)).is_real
      assert unique(('x', x_real), real=False).is_real

      # custom compare/modify hooks control clash detection and renaming
      assert unique('x1', Symbol('x1'),
                    compare=lambda i: str(i).rstrip('1')).name == 'x0'
      assert unique('x1', Symbol('x1'),
                    modify=lambda i: i + '_').name == 'x1_'

      assert _symbol(x, x_real) == x
  def test_disambiguate():
      """disambiguate() turns clashing Symbols/Dummies into distinct Symbols."""
      x, y, y_1, _x, x_1, x_2 = symbols('x y y_1 _x x_1 x_2')
      x_int = Symbol('x', integer=True)

      # (arguments, expected result) pairs; argument order matters
      cases = [
          ((Dummy('y'), _x, Dummy('x'), Dummy('x')), (y, x_2, x, x_1)),
          ((Dummy('x'), Dummy('x')), (x, x_1)),
          ((Dummy('x'), Dummy('y')), (x, y)),
          ((x, Dummy('x')), (x_1, x)),
          ((x_int, x, Symbol('x_1')), (x_int, x_2, x_1)),
      ]
      for args, expected in cases:
          assert disambiguate(*args) == expected

      # assumptions are retained
      assert disambiguate(x_int, x, Symbol('x_1'))[0] != x

      # clashes inside larger expressions are renamed too
      assert disambiguate(_x, Dummy('x')/y) == (x_1, x/y)
      assert disambiguate(y*Dummy('y'), y) == (y*y_1, y_1)
      assert disambiguate(Dummy('x_1'), Dummy('x_1')
          ) == (x_1, Symbol('x_1_1'))
  @skip_under_pyodide("Cannot create threads under pyodide.")
  def test_issue_gh_16734():
      """Comparing Symbols must not fail while another thread queries them.

      Regression test for https://github.com/sympy/sympy/issues/16734.
      A worker thread keeps creating fresh symbols and querying an
      assumption; meanwhile the main thread repeatedly compares the
      current symbols.  ``syms[0] = None`` is the stop signal for the main
      thread.  Any exception raised in the worker is collected and reported
      as a test failure after the thread has been joined.
      """
      syms = list(symbols('x, y'))
      errors = []

      def thread1():
          try:
              for n in range(1000):
                  syms[0], syms[1] = symbols(f'x{n}, y{n}')
                  syms[0].is_positive  # Check an assumption in this thread.
          except Exception as exc:  # re-reported in the main thread below
              errors.append(exc)
          finally:
              # Always send the stop signal; otherwise an error above would
              # leave thread2 spinning forever and the test would hang.
              syms[0] = None

      def thread2():
          while syms[0] is not None:
              # Compare the symbol in this thread.
              result = (syms[0] == syms[1]) # noqa

      # Previously this would be very likely to raise an exception:
      thread = threading.Thread(target=thread1)
      thread.start()
      try:
          thread2()
      finally:
          thread.join()
      # exceptions in a worker thread do not propagate on their own
      assert not errors, errors