test_implicit_multiplication_application.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. import sympy
  2. from sympy.parsing.sympy_parser import (
  3. parse_expr,
  4. standard_transformations,
  5. convert_xor,
  6. implicit_multiplication_application,
  7. implicit_multiplication,
  8. implicit_application,
  9. function_exponentiation,
  10. split_symbols,
  11. split_symbols_custom,
  12. _token_splittable
  13. )
  14. from sympy.testing.pytest import raises
  15. def test_implicit_multiplication():
  16. cases = {
  17. '5x': '5*x',
  18. 'abc': 'a*b*c',
  19. '3sin(x)': '3*sin(x)',
  20. '(x+1)(x+2)': '(x+1)*(x+2)',
  21. '(5 x**2)sin(x)': '(5*x**2)*sin(x)',
  22. '2 sin(x) cos(x)': '2*sin(x)*cos(x)',
  23. 'pi x': 'pi*x',
  24. 'x pi': 'x*pi',
  25. 'E x': 'E*x',
  26. 'EulerGamma y': 'EulerGamma*y',
  27. 'E pi': 'E*pi',
  28. 'pi (x + 2)': 'pi*(x+2)',
  29. '(x + 2) pi': '(x+2)*pi',
  30. 'pi sin(x)': 'pi*sin(x)',
  31. }
  32. transformations = standard_transformations + (convert_xor,)
  33. transformations2 = transformations + (split_symbols,
  34. implicit_multiplication)
  35. for case in cases:
  36. implicit = parse_expr(case, transformations=transformations2)
  37. normal = parse_expr(cases[case], transformations=transformations)
  38. assert(implicit == normal)
  39. application = ['sin x', 'cos 2*x', 'sin cos x']
  40. for case in application:
  41. raises(SyntaxError,
  42. lambda: parse_expr(case, transformations=transformations2))
  43. raises(TypeError,
  44. lambda: parse_expr('sin**2(x)', transformations=transformations2))
  45. def test_implicit_application():
  46. cases = {
  47. 'factorial': 'factorial',
  48. 'sin x': 'sin(x)',
  49. 'tan y**3': 'tan(y**3)',
  50. 'cos 2*x': 'cos(2*x)',
  51. '(cot)': 'cot',
  52. 'sin cos tan x': 'sin(cos(tan(x)))'
  53. }
  54. transformations = standard_transformations + (convert_xor,)
  55. transformations2 = transformations + (implicit_application,)
  56. for case in cases:
  57. implicit = parse_expr(case, transformations=transformations2)
  58. normal = parse_expr(cases[case], transformations=transformations)
  59. assert(implicit == normal), (implicit, normal)
  60. multiplication = ['x y', 'x sin x', '2x']
  61. for case in multiplication:
  62. raises(SyntaxError,
  63. lambda: parse_expr(case, transformations=transformations2))
  64. raises(TypeError,
  65. lambda: parse_expr('sin**2(x)', transformations=transformations2))
  66. def test_function_exponentiation():
  67. cases = {
  68. 'sin**2(x)': 'sin(x)**2',
  69. 'exp^y(z)': 'exp(z)^y',
  70. 'sin**2(E^(x))': 'sin(E^(x))**2'
  71. }
  72. transformations = standard_transformations + (convert_xor,)
  73. transformations2 = transformations + (function_exponentiation,)
  74. for case in cases:
  75. implicit = parse_expr(case, transformations=transformations2)
  76. normal = parse_expr(cases[case], transformations=transformations)
  77. assert(implicit == normal)
  78. other_implicit = ['x y', 'x sin x', '2x', 'sin x',
  79. 'cos 2*x', 'sin cos x']
  80. for case in other_implicit:
  81. raises(SyntaxError,
  82. lambda: parse_expr(case, transformations=transformations2))
  83. assert parse_expr('x**2', local_dict={ 'x': sympy.Symbol('x') },
  84. transformations=transformations2) == parse_expr('x**2')
  85. def test_symbol_splitting():
  86. # By default Greek letter names should not be split (lambda is a keyword
  87. # so skip it)
  88. transformations = standard_transformations + (split_symbols,)
  89. greek_letters = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta',
  90. 'eta', 'theta', 'iota', 'kappa', 'mu', 'nu', 'xi',
  91. 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon',
  92. 'phi', 'chi', 'psi', 'omega')
  93. for letter in greek_letters:
  94. assert(parse_expr(letter, transformations=transformations) ==
  95. parse_expr(letter))
  96. # Make sure symbol splitting resolves names
  97. transformations += (implicit_multiplication,)
  98. local_dict = { 'e': sympy.E }
  99. cases = {
  100. 'xe': 'E*x',
  101. 'Iy': 'I*y',
  102. 'ee': 'E*E',
  103. }
  104. for case, expected in cases.items():
  105. assert(parse_expr(case, local_dict=local_dict,
  106. transformations=transformations) ==
  107. parse_expr(expected))
  108. # Make sure custom splitting works
  109. def can_split(symbol):
  110. if symbol not in ('unsplittable', 'names'):
  111. return _token_splittable(symbol)
  112. return False
  113. transformations = standard_transformations
  114. transformations += (split_symbols_custom(can_split),
  115. implicit_multiplication)
  116. assert(parse_expr('unsplittable', transformations=transformations) ==
  117. parse_expr('unsplittable'))
  118. assert(parse_expr('names', transformations=transformations) ==
  119. parse_expr('names'))
  120. assert(parse_expr('xy', transformations=transformations) ==
  121. parse_expr('x*y'))
  122. for letter in greek_letters:
  123. assert(parse_expr(letter, transformations=transformations) ==
  124. parse_expr(letter))
  125. def test_all_implicit_steps():
  126. cases = {
  127. '2x': '2*x', # implicit multiplication
  128. 'x y': 'x*y',
  129. 'xy': 'x*y',
  130. 'sin x': 'sin(x)', # add parentheses
  131. '2sin x': '2*sin(x)',
  132. 'x y z': 'x*y*z',
  133. 'sin(2 * 3x)': 'sin(2 * 3 * x)',
  134. 'sin(x) (1 + cos(x))': 'sin(x) * (1 + cos(x))',
  135. '(x + 2) sin(x)': '(x + 2) * sin(x)',
  136. '(x + 2) sin x': '(x + 2) * sin(x)',
  137. 'sin(sin x)': 'sin(sin(x))',
  138. 'sin x!': 'sin(factorial(x))',
  139. 'sin x!!': 'sin(factorial2(x))',
  140. 'factorial': 'factorial', # don't apply a bare function
  141. 'x sin x': 'x * sin(x)', # both application and multiplication
  142. 'xy sin x': 'x * y * sin(x)',
  143. '(x+2)(x+3)': '(x + 2) * (x+3)',
  144. 'x**2 + 2xy + y**2': 'x**2 + 2 * x * y + y**2', # split the xy
  145. 'pi': 'pi', # don't mess with constants
  146. 'None': 'None',
  147. 'ln sin x': 'ln(sin(x))', # multiple implicit function applications
  148. 'factorial': 'factorial', # don't add parentheses
  149. 'sin x**2': 'sin(x**2)', # implicit application to an exponential
  150. 'alpha': 'Symbol("alpha")', # don't split Greek letters/subscripts
  151. 'x_2': 'Symbol("x_2")',
  152. 'sin^2 x**2': 'sin(x**2)**2', # function raised to a power
  153. 'sin**3(x)': 'sin(x)**3',
  154. '(factorial)': 'factorial',
  155. 'tan 3x': 'tan(3*x)',
  156. 'sin^2(3*E^(x))': 'sin(3*E**(x))**2',
  157. 'sin**2(E^(3x))': 'sin(E**(3*x))**2',
  158. 'sin^2 (3x*E^(x))': 'sin(3*x*E^x)**2',
  159. 'pi sin x': 'pi*sin(x)',
  160. }
  161. transformations = standard_transformations + (convert_xor,)
  162. transformations2 = transformations + (implicit_multiplication_application,)
  163. for case in cases:
  164. implicit = parse_expr(case, transformations=transformations2)
  165. normal = parse_expr(cases[case], transformations=transformations)
  166. assert(implicit == normal)
  167. def test_no_methods_implicit_multiplication():
  168. # Issue 21020
  169. u = sympy.Symbol('u')
  170. transformations = standard_transformations + \
  171. (implicit_multiplication,)
  172. expr = parse_expr('x.is_polynomial(x)', transformations=transformations)
  173. assert expr == True
  174. expr = parse_expr('(exp(x) / (1 + exp(2x))).subs(exp(x), u)',
  175. transformations=transformations)
  176. assert expr == u/(u**2 + 1)