mathematica.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080
  1. from __future__ import annotations
  2. import re
  3. import typing
  4. from itertools import product
  5. from typing import Any, Callable
  6. import sympy
  7. from sympy import Mul, Add, Pow, log, exp, sqrt, cos, sin, tan, asin, acos, acot, asec, acsc, sinh, cosh, tanh, asinh, \
  8. acosh, atanh, acoth, asech, acsch, expand, im, flatten, polylog, cancel, expand_trig, sign, simplify, \
  9. UnevaluatedExpr, S, atan, atan2, Mod, Max, Min, rf, Ei, Si, Ci, airyai, airyaiprime, airybi, primepi, prime, \
  10. isprime, cot, sec, csc, csch, sech, coth, Function, I, pi, Tuple, GreaterThan, StrictGreaterThan, StrictLessThan, \
  11. LessThan, Equality, Or, And, Lambda, Integer, Dummy, symbols
  12. from sympy.core.sympify import sympify, _sympify
  13. from sympy.functions.special.bessel import airybiprime
  14. from sympy.functions.special.error_functions import li
  15. from sympy.utilities.exceptions import sympy_deprecation_warning
  16. def mathematica(s, additional_translations=None):
  17. sympy_deprecation_warning(
  18. """The ``mathematica`` function for the Mathematica parser is now
  19. deprecated. Use ``parse_mathematica`` instead.
  20. The parameter ``additional_translation`` can be replaced by SymPy's
  21. .replace( ) or .subs( ) methods on the output expression instead.""",
  22. deprecated_since_version="1.11",
  23. active_deprecations_target="mathematica-parser-new",
  24. )
  25. parser = MathematicaParser(additional_translations)
  26. return sympify(parser._parse_old(s))
  27. def parse_mathematica(s):
  28. """
  29. Translate a string containing a Wolfram Mathematica expression to a SymPy
  30. expression.
  31. If the translator is unable to find a suitable SymPy expression, the
  32. ``FullForm`` of the Mathematica expression will be output, using SymPy
  33. ``Function`` objects as nodes of the syntax tree.
  34. Examples
  35. ========
  36. >>> from sympy.parsing.mathematica import parse_mathematica
  37. >>> parse_mathematica("Sin[x]^2 Tan[y]")
  38. sin(x)**2*tan(y)
  39. >>> e = parse_mathematica("F[7,5,3]")
  40. >>> e
  41. F(7, 5, 3)
  42. >>> from sympy import Function, Max, Min
  43. >>> e.replace(Function("F"), lambda *x: Max(*x)*Min(*x))
  44. 21
  45. Both standard input form and Mathematica full form are supported:
  46. >>> parse_mathematica("x*(a + b)")
  47. x*(a + b)
  48. >>> parse_mathematica("Times[x, Plus[a, b]]")
  49. x*(a + b)
  50. To get a matrix from Wolfram's code:
  51. >>> m = parse_mathematica("{{a, b}, {c, d}}")
  52. >>> m
  53. ((a, b), (c, d))
  54. >>> from sympy import Matrix
  55. >>> Matrix(m)
  56. Matrix([
  57. [a, b],
  58. [c, d]])
  59. If the translation into equivalent SymPy expressions fails, an SymPy
  60. expression equivalent to Wolfram Mathematica's "FullForm" will be created:
  61. >>> parse_mathematica("x_.")
  62. Optional(Pattern(x, Blank()))
  63. >>> parse_mathematica("Plus @@ {x, y, z}")
  64. Apply(Plus, (x, y, z))
  65. >>> parse_mathematica("f[x_, 3] := x^3 /; x > 0")
  66. SetDelayed(f(Pattern(x, Blank()), 3), Condition(x**3, x > 0))
  67. """
  68. parser = MathematicaParser()
  69. return parser.parse(s)
  70. def _parse_Function(*args):
  71. if len(args) == 1:
  72. arg = args[0]
  73. Slot = Function("Slot")
  74. slots = arg.atoms(Slot)
  75. numbers = [a.args[0] for a in slots]
  76. number_of_arguments = max(numbers)
  77. if isinstance(number_of_arguments, Integer):
  78. variables = symbols(f"dummy0:{number_of_arguments}", cls=Dummy)
  79. return Lambda(variables, arg.xreplace({Slot(i+1): v for i, v in enumerate(variables)}))
  80. return Lambda((), arg)
  81. elif len(args) == 2:
  82. variables = args[0]
  83. body = args[1]
  84. return Lambda(variables, body)
  85. else:
  86. raise SyntaxError("Function node expects 1 or 2 arguments")
  87. def _deco(cls):
  88. cls._initialize_class()
  89. return cls
  90. @_deco
  91. class MathematicaParser:
  92. """
  93. An instance of this class converts a string of a Wolfram Mathematica
  94. expression to a SymPy expression.
  95. The main parser acts internally in three stages:
  96. 1. tokenizer: tokenizes the Mathematica expression and adds the missing *
  97. operators. Handled by ``_from_mathematica_to_tokens(...)``
  98. 2. full form list: sort the list of strings output by the tokenizer into a
  99. syntax tree of nested lists and strings, equivalent to Mathematica's
  100. ``FullForm`` expression output. This is handled by the function
  101. ``_from_tokens_to_fullformlist(...)``.
  102. 3. SymPy expression: the syntax tree expressed as full form list is visited
  103. and the nodes with equivalent classes in SymPy are replaced. Unknown
  104. syntax tree nodes are cast to SymPy ``Function`` objects. This is
  105. handled by ``_from_fullformlist_to_sympy(...)``.
  106. """
  107. # left: Mathematica, right: SymPy
  108. CORRESPONDENCES = {
  109. 'Sqrt[x]': 'sqrt(x)',
  110. 'Exp[x]': 'exp(x)',
  111. 'Log[x]': 'log(x)',
  112. 'Log[x,y]': 'log(y,x)',
  113. 'Log2[x]': 'log(x,2)',
  114. 'Log10[x]': 'log(x,10)',
  115. 'Mod[x,y]': 'Mod(x,y)',
  116. 'Max[*x]': 'Max(*x)',
  117. 'Min[*x]': 'Min(*x)',
  118. 'Pochhammer[x,y]':'rf(x,y)',
  119. 'ArcTan[x,y]':'atan2(y,x)',
  120. 'ExpIntegralEi[x]': 'Ei(x)',
  121. 'SinIntegral[x]': 'Si(x)',
  122. 'CosIntegral[x]': 'Ci(x)',
  123. 'AiryAi[x]': 'airyai(x)',
  124. 'AiryAiPrime[x]': 'airyaiprime(x)',
  125. 'AiryBi[x]' :'airybi(x)',
  126. 'AiryBiPrime[x]' :'airybiprime(x)',
  127. 'LogIntegral[x]':' li(x)',
  128. 'PrimePi[x]': 'primepi(x)',
  129. 'Prime[x]': 'prime(x)',
  130. 'PrimeQ[x]': 'isprime(x)'
  131. }
  132. # trigonometric, e.t.c.
  133. for arc, tri, h in product(('', 'Arc'), (
  134. 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')):
  135. fm = arc + tri + h + '[x]'
  136. if arc: # arc func
  137. fs = 'a' + tri.lower() + h + '(x)'
  138. else: # non-arc func
  139. fs = tri.lower() + h + '(x)'
  140. CORRESPONDENCES.update({fm: fs})
  141. REPLACEMENTS = {
  142. ' ': '',
  143. '^': '**',
  144. '{': '[',
  145. '}': ']',
  146. }
  147. RULES = {
  148. # a single whitespace to '*'
  149. 'whitespace': (
  150. re.compile(r'''
  151. (?:(?<=[a-zA-Z\d])|(?<=\d\.)) # a letter or a number
  152. \s+ # any number of whitespaces
  153. (?:(?=[a-zA-Z\d])|(?=\.\d)) # a letter or a number
  154. ''', re.VERBOSE),
  155. '*'),
  156. # add omitted '*' character
  157. 'add*_1': (
  158. re.compile(r'''
  159. (?:(?<=[])\d])|(?<=\d\.)) # ], ) or a number
  160. # ''
  161. (?=[(a-zA-Z]) # ( or a single letter
  162. ''', re.VERBOSE),
  163. '*'),
  164. # add omitted '*' character (variable letter preceding)
  165. 'add*_2': (
  166. re.compile(r'''
  167. (?<=[a-zA-Z]) # a letter
  168. \( # ( as a character
  169. (?=.) # any characters
  170. ''', re.VERBOSE),
  171. '*('),
  172. # convert 'Pi' to 'pi'
  173. 'Pi': (
  174. re.compile(r'''
  175. (?:
  176. \A|(?<=[^a-zA-Z])
  177. )
  178. Pi # 'Pi' is 3.14159... in Mathematica
  179. (?=[^a-zA-Z])
  180. ''', re.VERBOSE),
  181. 'pi'),
  182. }
  183. # Mathematica function name pattern
  184. FM_PATTERN = re.compile(r'''
  185. (?:
  186. \A|(?<=[^a-zA-Z]) # at the top or a non-letter
  187. )
  188. [A-Z][a-zA-Z\d]* # Function
  189. (?=\[) # [ as a character
  190. ''', re.VERBOSE)
  191. # list or matrix pattern (for future usage)
  192. ARG_MTRX_PATTERN = re.compile(r'''
  193. \{.*\}
  194. ''', re.VERBOSE)
  195. # regex string for function argument pattern
  196. ARGS_PATTERN_TEMPLATE = r'''
  197. (?:
  198. \A|(?<=[^a-zA-Z])
  199. )
  200. {arguments} # model argument like x, y,...
  201. (?=[^a-zA-Z])
  202. '''
  203. # will contain transformed CORRESPONDENCES dictionary
  204. TRANSLATIONS: dict[tuple[str, int], dict[str, Any]] = {}
  205. # cache for a raw users' translation dictionary
  206. cache_original: dict[tuple[str, int], dict[str, Any]] = {}
  207. # cache for a compiled users' translation dictionary
  208. cache_compiled: dict[tuple[str, int], dict[str, Any]] = {}
  209. @classmethod
  210. def _initialize_class(cls):
  211. # get a transformed CORRESPONDENCES dictionary
  212. d = cls._compile_dictionary(cls.CORRESPONDENCES)
  213. cls.TRANSLATIONS.update(d)
  214. def __init__(self, additional_translations=None):
  215. self.translations = {}
  216. # update with TRANSLATIONS (class constant)
  217. self.translations.update(self.TRANSLATIONS)
  218. if additional_translations is None:
  219. additional_translations = {}
  220. # check the latest added translations
  221. if self.__class__.cache_original != additional_translations:
  222. if not isinstance(additional_translations, dict):
  223. raise ValueError('The argument must be dict type')
  224. # get a transformed additional_translations dictionary
  225. d = self._compile_dictionary(additional_translations)
  226. # update cache
  227. self.__class__.cache_original = additional_translations
  228. self.__class__.cache_compiled = d
  229. # merge user's own translations
  230. self.translations.update(self.__class__.cache_compiled)
  231. @classmethod
  232. def _compile_dictionary(cls, dic):
  233. # for return
  234. d = {}
  235. for fm, fs in dic.items():
  236. # check function form
  237. cls._check_input(fm)
  238. cls._check_input(fs)
  239. # uncover '*' hiding behind a whitespace
  240. fm = cls._apply_rules(fm, 'whitespace')
  241. fs = cls._apply_rules(fs, 'whitespace')
  242. # remove whitespace(s)
  243. fm = cls._replace(fm, ' ')
  244. fs = cls._replace(fs, ' ')
  245. # search Mathematica function name
  246. m = cls.FM_PATTERN.search(fm)
  247. # if no-hit
  248. if m is None:
  249. err = "'{f}' function form is invalid.".format(f=fm)
  250. raise ValueError(err)
  251. # get Mathematica function name like 'Log'
  252. fm_name = m.group()
  253. # get arguments of Mathematica function
  254. args, end = cls._get_args(m)
  255. # function side check. (e.g.) '2*Func[x]' is invalid.
  256. if m.start() != 0 or end != len(fm):
  257. err = "'{f}' function form is invalid.".format(f=fm)
  258. raise ValueError(err)
  259. # check the last argument's 1st character
  260. if args[-1][0] == '*':
  261. key_arg = '*'
  262. else:
  263. key_arg = len(args)
  264. key = (fm_name, key_arg)
  265. # convert '*x' to '\\*x' for regex
  266. re_args = [x if x[0] != '*' else '\\' + x for x in args]
  267. # for regex. Example: (?:(x|y|z))
  268. xyz = '(?:(' + '|'.join(re_args) + '))'
  269. # string for regex compile
  270. patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)
  271. pat = re.compile(patStr, re.VERBOSE)
  272. # update dictionary
  273. d[key] = {}
  274. d[key]['fs'] = fs # SymPy function template
  275. d[key]['args'] = args # args are ['x', 'y'] for example
  276. d[key]['pat'] = pat
  277. return d
  278. def _convert_function(self, s):
  279. '''Parse Mathematica function to SymPy one'''
  280. # compiled regex object
  281. pat = self.FM_PATTERN
  282. scanned = '' # converted string
  283. cur = 0 # position cursor
  284. while True:
  285. m = pat.search(s)
  286. if m is None:
  287. # append the rest of string
  288. scanned += s
  289. break
  290. # get Mathematica function name
  291. fm = m.group()
  292. # get arguments, and the end position of fm function
  293. args, end = self._get_args(m)
  294. # the start position of fm function
  295. bgn = m.start()
  296. # convert Mathematica function to SymPy one
  297. s = self._convert_one_function(s, fm, args, bgn, end)
  298. # update cursor
  299. cur = bgn
  300. # append converted part
  301. scanned += s[:cur]
  302. # shrink s
  303. s = s[cur:]
  304. return scanned
  305. def _convert_one_function(self, s, fm, args, bgn, end):
  306. # no variable-length argument
  307. if (fm, len(args)) in self.translations:
  308. key = (fm, len(args))
  309. # x, y,... model arguments
  310. x_args = self.translations[key]['args']
  311. # make CORRESPONDENCES between model arguments and actual ones
  312. d = {k: v for k, v in zip(x_args, args)}
  313. # with variable-length argument
  314. elif (fm, '*') in self.translations:
  315. key = (fm, '*')
  316. # x, y,..*args (model arguments)
  317. x_args = self.translations[key]['args']
  318. # make CORRESPONDENCES between model arguments and actual ones
  319. d = {}
  320. for i, x in enumerate(x_args):
  321. if x[0] == '*':
  322. d[x] = ','.join(args[i:])
  323. break
  324. d[x] = args[i]
  325. # out of self.translations
  326. else:
  327. err = "'{f}' is out of the whitelist.".format(f=fm)
  328. raise ValueError(err)
  329. # template string of converted function
  330. template = self.translations[key]['fs']
  331. # regex pattern for x_args
  332. pat = self.translations[key]['pat']
  333. scanned = ''
  334. cur = 0
  335. while True:
  336. m = pat.search(template)
  337. if m is None:
  338. scanned += template
  339. break
  340. # get model argument
  341. x = m.group()
  342. # get a start position of the model argument
  343. xbgn = m.start()
  344. # add the corresponding actual argument
  345. scanned += template[:xbgn] + d[x]
  346. # update cursor to the end of the model argument
  347. cur = m.end()
  348. # shrink template
  349. template = template[cur:]
  350. # update to swapped string
  351. s = s[:bgn] + scanned + s[end:]
  352. return s
  353. @classmethod
  354. def _get_args(cls, m):
  355. '''Get arguments of a Mathematica function'''
  356. s = m.string # whole string
  357. anc = m.end() + 1 # pointing the first letter of arguments
  358. square, curly = [], [] # stack for brakets
  359. args = []
  360. # current cursor
  361. cur = anc
  362. for i, c in enumerate(s[anc:], anc):
  363. # extract one argument
  364. if c == ',' and (not square) and (not curly):
  365. args.append(s[cur:i]) # add an argument
  366. cur = i + 1 # move cursor
  367. # handle list or matrix (for future usage)
  368. if c == '{':
  369. curly.append(c)
  370. elif c == '}':
  371. curly.pop()
  372. # seek corresponding ']' with skipping irrevant ones
  373. if c == '[':
  374. square.append(c)
  375. elif c == ']':
  376. if square:
  377. square.pop()
  378. else: # empty stack
  379. args.append(s[cur:i])
  380. break
  381. # the next position to ']' bracket (the function end)
  382. func_end = i + 1
  383. return args, func_end
  384. @classmethod
  385. def _replace(cls, s, bef):
  386. aft = cls.REPLACEMENTS[bef]
  387. s = s.replace(bef, aft)
  388. return s
  389. @classmethod
  390. def _apply_rules(cls, s, bef):
  391. pat, aft = cls.RULES[bef]
  392. return pat.sub(aft, s)
  393. @classmethod
  394. def _check_input(cls, s):
  395. for bracket in (('[', ']'), ('{', '}'), ('(', ')')):
  396. if s.count(bracket[0]) != s.count(bracket[1]):
  397. err = "'{f}' function form is invalid.".format(f=s)
  398. raise ValueError(err)
  399. if '{' in s:
  400. err = "Currently list is not supported."
  401. raise ValueError(err)
  402. def _parse_old(self, s):
  403. # input check
  404. self._check_input(s)
  405. # uncover '*' hiding behind a whitespace
  406. s = self._apply_rules(s, 'whitespace')
  407. # remove whitespace(s)
  408. s = self._replace(s, ' ')
  409. # add omitted '*' character
  410. s = self._apply_rules(s, 'add*_1')
  411. s = self._apply_rules(s, 'add*_2')
  412. # translate function
  413. s = self._convert_function(s)
  414. # '^' to '**'
  415. s = self._replace(s, '^')
  416. # 'Pi' to 'pi'
  417. s = self._apply_rules(s, 'Pi')
  418. # '{', '}' to '[', ']', respectively
  419. # s = cls._replace(s, '{') # currently list is not taken into account
  420. # s = cls._replace(s, '}')
  421. return s
  422. def parse(self, s):
  423. s2 = self._from_mathematica_to_tokens(s)
  424. s3 = self._from_tokens_to_fullformlist(s2)
  425. s4 = self._from_fullformlist_to_sympy(s3)
  426. return s4
    # Operator kinds: where an operator sits relative to its operand(s).
    INFIX = "Infix"
    PREFIX = "Prefix"
    POSTFIX = "Postfix"
    # Grouping strategies for infix operators (see ``_parse_after_braces``):
    # FLAT collects a chain into one node (Plus[a, b, c]); RIGHT nests to the
    # right (Power[a, Power[b, c]]); LEFT nests to the left.
    FLAT = "Flat"
    RIGHT = "Right"
    LEFT = "Left"

    # Operator table.  ``_parse_after_braces`` walks it in *reverse*, so
    # entries later in the list are reduced first, i.e. they bind tighter.
    # Each entry is (kind, grouping strategy, {token: head}).  A string head
    # becomes the first element of the resulting FullForm list node.  A
    # callable head receives the operand(s) and builds the node itself.
    _mathematica_op_precedence: list[tuple[str, str | None, dict[str, str | Callable]]] = [
        # Trailing ";" appends a Null statement to a compound expression.
        (POSTFIX, None, {";": lambda x: x + ["Null"] if isinstance(x, list) and x and x[0] == "CompoundExpression" else ["CompoundExpression", x, "Null"]}),
        (INFIX, FLAT, {";": "CompoundExpression"}),
        (INFIX, RIGHT, {"=": "Set", ":=": "SetDelayed", "+=": "AddTo", "-=": "SubtractFrom", "*=": "TimesBy", "/=": "DivideBy"}),
        # NOTE(review): this builds [x, y], i.e. the *left* operand becomes the
        # head.  Mathematica's `x // f` means f[x] -- confirm this is intended.
        (INFIX, LEFT, {"//": lambda x, y: [x, y]}),
        (POSTFIX, None, {"&": "Function"}),
        (INFIX, LEFT, {"/.": "ReplaceAll"}),
        (INFIX, RIGHT, {"->": "Rule", ":>": "RuleDelayed"}),
        (INFIX, LEFT, {"/;": "Condition"}),
        (INFIX, FLAT, {"|": "Alternatives"}),
        (POSTFIX, None, {"..": "Repeated", "...": "RepeatedNull"}),
        (INFIX, FLAT, {"||": "Or"}),
        (INFIX, FLAT, {"&&": "And"}),
        (PREFIX, None, {"!": "Not"}),
        (INFIX, FLAT, {"===": "SameQ", "=!=": "UnsameQ"}),
        (INFIX, FLAT, {"==": "Equal", "!=": "Unequal", "<=": "LessEqual", "<": "Less", ">=": "GreaterEqual", ">": "Greater"}),
        (INFIX, None, {";;": "Span"}),
        # "a - b" becomes Plus[a, -b]; _parse_after_braces negates the right
        # operand with _get_neg.
        (INFIX, FLAT, {"+": "Plus", "-": "Plus"}),
        # "a / b" becomes Times[a, Power[b, -1]] via _get_inv.  This level is
        # also where _parse_after_braces inserts implicit "*" tokens.
        (INFIX, FLAT, {"*": "Times", "/": "Times"}),
        (INFIX, FLAT, {".": "Dot"}),
        (PREFIX, None, {"-": lambda x: MathematicaParser._get_neg(x),
                        "+": lambda x: x}),
        (INFIX, RIGHT, {"^": "Power"}),
        (INFIX, RIGHT, {"@@": "Apply", "/@": "Map", "//@": "MapAll", "@@@": lambda x, y: ["Apply", x, y, ["List", "1"]]}),
        (POSTFIX, None, {"'": "Derivative", "!": "Factorial", "!!": "Factorial2", "--": "Decrement"}),
        # Function call / Part: ``y`` is the list of comma-separated arguments
        # that _from_tokens_to_fullformlist attached after the bracket.
        (INFIX, None, {"[": lambda x, y: [x, *y], "[[": lambda x, y: ["Part", x, *y]}),
        (PREFIX, None, {"{": lambda x: ["List", *x], "(": lambda x: x[0]}),
        (INFIX, None, {"?": "PatternTest"}),
        (POSTFIX, None, {
            "_": lambda x: ["Pattern", x, ["Blank"]],
            "_.": lambda x: ["Optional", ["Pattern", x, ["Blank"]]],
            "__": lambda x: ["Pattern", x, ["BlankSequence"]],
            "___": lambda x: ["Pattern", x, ["BlankNullSequence"]],
        }),
        (INFIX, None, {"_": lambda x, y: ["Pattern", x, ["Blank", y]]}),
        (PREFIX, None, {"#": "Slot", "##": "SlotSequence"}),
    ]

    # Node used when a prefix/postfix operator listed here has no operand,
    # e.g. a bare "#" stands for Slot[1].
    _missing_arguments_default = {
        "#": lambda: ["Slot", "1"],
        "##": lambda: ["SlotSequence", "1"],
    }

    # Identifier and numeric-literal regexes used by the tokenizer and _is_op.
    _literal = r"[A-Za-z][A-Za-z0-9]*"
    _number = r"(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)"

    # Enclosure tokens; the two lists are aligned index-by-index so that
    # _enclosure_open[i] is closed by _enclosure_close[i].
    _enclosure_open = ["(", "[", "[[", "{"]
    _enclosure_close = [")", "]", "]]", "}"]
  478. @classmethod
  479. def _get_neg(cls, x):
  480. return f"-{x}" if isinstance(x, str) and re.match(MathematicaParser._number, x) else ["Times", "-1", x]
  481. @classmethod
  482. def _get_inv(cls, x):
  483. return ["Power", x, "-1"]
  484. _regex_tokenizer = None
  485. def _get_tokenizer(self):
  486. if self._regex_tokenizer is not None:
  487. # Check if the regular expression has already been compiled:
  488. return self._regex_tokenizer
  489. tokens = [self._literal, self._number]
  490. tokens_escape = self._enclosure_open[:] + self._enclosure_close[:]
  491. for typ, strat, symdict in self._mathematica_op_precedence:
  492. for k in symdict:
  493. tokens_escape.append(k)
  494. tokens_escape.sort(key=lambda x: -len(x))
  495. tokens.extend(map(re.escape, tokens_escape))
  496. tokens.append(",")
  497. tokens.append("\n")
  498. tokenizer = re.compile("(" + "|".join(tokens) + ")")
  499. self._regex_tokenizer = tokenizer
  500. return self._regex_tokenizer
  501. def _from_mathematica_to_tokens(self, code: str):
  502. tokenizer = self._get_tokenizer()
  503. # Find strings:
  504. code_splits: list[str | list] = []
  505. while True:
  506. string_start = code.find("\"")
  507. if string_start == -1:
  508. if len(code) > 0:
  509. code_splits.append(code)
  510. break
  511. match_end = re.search(r'(?<!\\)"', code[string_start+1:])
  512. if match_end is None:
  513. raise SyntaxError('mismatch in string " " expression')
  514. string_end = string_start + match_end.start() + 1
  515. if string_start > 0:
  516. code_splits.append(code[:string_start])
  517. code_splits.append(["_Str", code[string_start+1:string_end].replace('\\"', '"')])
  518. code = code[string_end+1:]
  519. # Remove comments:
  520. for i, code_split in enumerate(code_splits):
  521. if isinstance(code_split, list):
  522. continue
  523. while True:
  524. pos_comment_start = code_split.find("(*")
  525. if pos_comment_start == -1:
  526. break
  527. pos_comment_end = code_split.find("*)")
  528. if pos_comment_end == -1 or pos_comment_end < pos_comment_start:
  529. raise SyntaxError("mismatch in comment (* *) code")
  530. code_split = code_split[:pos_comment_start] + code_split[pos_comment_end+2:]
  531. code_splits[i] = code_split
  532. # Tokenize the input strings with a regular expression:
  533. token_lists = [tokenizer.findall(i) if isinstance(i, str) and i.isascii() else [i] for i in code_splits]
  534. tokens = [j for i in token_lists for j in i]
  535. # Remove newlines at the beginning
  536. while tokens and tokens[0] == "\n":
  537. tokens.pop(0)
  538. # Remove newlines at the end
  539. while tokens and tokens[-1] == "\n":
  540. tokens.pop(-1)
  541. return tokens
  542. def _is_op(self, token: str | list) -> bool:
  543. if isinstance(token, list):
  544. return False
  545. if re.match(self._literal, token):
  546. return False
  547. if re.match("-?" + self._number, token):
  548. return False
  549. return True
  550. def _is_valid_star1(self, token: str | list) -> bool:
  551. if token in (")", "}"):
  552. return True
  553. return not self._is_op(token)
  554. def _is_valid_star2(self, token: str | list) -> bool:
  555. if token in ("(", "{"):
  556. return True
  557. return not self._is_op(token)
  558. def _from_tokens_to_fullformlist(self, tokens: list):
  559. stack: list[list] = [[]]
  560. open_seq = []
  561. pointer: int = 0
  562. while pointer < len(tokens):
  563. token = tokens[pointer]
  564. if token in self._enclosure_open:
  565. stack[-1].append(token)
  566. open_seq.append(token)
  567. stack.append([])
  568. elif token == ",":
  569. if len(stack[-1]) == 0 and stack[-2][-1] == open_seq[-1]:
  570. raise SyntaxError("%s cannot be followed by comma ," % open_seq[-1])
  571. stack[-1] = self._parse_after_braces(stack[-1])
  572. stack.append([])
  573. elif token in self._enclosure_close:
  574. ind = self._enclosure_close.index(token)
  575. if self._enclosure_open[ind] != open_seq[-1]:
  576. unmatched_enclosure = SyntaxError("unmatched enclosure")
  577. if token == "]]" and open_seq[-1] == "[":
  578. if open_seq[-2] == "[":
  579. # These two lines would be logically correct, but are
  580. # unnecessary:
  581. # token = "]"
  582. # tokens[pointer] = "]"
  583. tokens.insert(pointer+1, "]")
  584. elif open_seq[-2] == "[[":
  585. if tokens[pointer+1] == "]":
  586. tokens[pointer+1] = "]]"
  587. elif tokens[pointer+1] == "]]":
  588. tokens[pointer+1] = "]]"
  589. tokens.insert(pointer+2, "]")
  590. else:
  591. raise unmatched_enclosure
  592. else:
  593. raise unmatched_enclosure
  594. if len(stack[-1]) == 0 and stack[-2][-1] == "(":
  595. raise SyntaxError("( ) not valid syntax")
  596. last_stack = self._parse_after_braces(stack[-1], True)
  597. stack[-1] = last_stack
  598. new_stack_element = []
  599. while stack[-1][-1] != open_seq[-1]:
  600. new_stack_element.append(stack.pop())
  601. new_stack_element.reverse()
  602. if open_seq[-1] == "(" and len(new_stack_element) != 1:
  603. raise SyntaxError("( must be followed by one expression, %i detected" % len(new_stack_element))
  604. stack[-1].append(new_stack_element)
  605. open_seq.pop(-1)
  606. else:
  607. stack[-1].append(token)
  608. pointer += 1
  609. assert len(stack) == 1
  610. return self._parse_after_braces(stack[0])
    def _util_remove_newlines(self, lines: list, tokens: list, inside_enclosure: bool):
        """Handle ``"\\n"`` tokens in ``tokens``, which is modified in place.

        Inside an enclosure every newline is simply dropped.  At top level:

        * a leading newline is dropped;
        * if the tokens before a newline parse on their own, they form a
          finished statement.  The statement is appended to ``lines`` (the
          parts of a CompoundExpression are appended individually) and its
          tokens are removed;
        * if they do not parse (SyntaxError), the expression continues on the
          next line and only the newline is dropped.

        Returns None; the results are the mutated ``lines`` and ``tokens``.
        """
        pointer = 0
        size = len(tokens)
        while pointer < size:
            token = tokens[pointer]
            if token == "\n":
                if inside_enclosure:
                    # Ignore newlines inside enclosures
                    tokens.pop(pointer)
                    size -= 1
                    continue
                if pointer == 0:
                    tokens.pop(0)
                    size -= 1
                    continue
                if pointer > 1:
                    try:
                        prev_expr = self._parse_after_braces(tokens[:pointer], inside_enclosure)
                    except SyntaxError:
                        # Incomplete expression: treat the newline as a line
                        # continuation.
                        tokens.pop(pointer)
                        size -= 1
                        continue
                else:
                    # A single token before the newline is already a statement.
                    prev_expr = tokens[0]
                if len(prev_expr) > 0 and prev_expr[0] == "CompoundExpression":
                    lines.extend(prev_expr[1:])
                else:
                    lines.append(prev_expr)
                # Drop the consumed statement tokens; the newline is now at
                # index 0 and is removed on the next iteration.
                for i in range(pointer):
                    tokens.pop(0)
                size -= pointer
                pointer = 0
                continue
            pointer += 1
  645. def _util_add_missing_asterisks(self, tokens: list):
  646. size: int = len(tokens)
  647. pointer: int = 0
  648. while pointer < size:
  649. if (pointer > 0 and
  650. self._is_valid_star1(tokens[pointer - 1]) and
  651. self._is_valid_star2(tokens[pointer])):
  652. # This is a trick to add missing * operators in the expression,
  653. # `"*" in op_dict` makes sure the precedence level is the same as "*",
  654. # while `not self._is_op( ... )` makes sure this and the previous
  655. # expression are not operators.
  656. if tokens[pointer] == "(":
  657. # ( has already been processed by now, replace:
  658. tokens[pointer] = "*"
  659. tokens[pointer + 1] = tokens[pointer + 1][0]
  660. else:
  661. tokens.insert(pointer, "*")
  662. pointer += 1
  663. size += 1
  664. pointer += 1
    def _parse_after_braces(self, tokens: list, inside_enclosure: bool = False):
        """Reduce a flat token list to a single FullForm node.

        ``tokens`` holds one expression level (enclosure contents have already
        been grouped by ``_from_tokens_to_fullformlist``) and is modified in
        place.  Top-level newlines are first split into separate statements
        by ``_util_remove_newlines``.  Then every level of
        ``_mathematica_op_precedence`` is applied, in reverse table order, and
        each operator with its operand(s) is replaced by a nested list node.

        Returns the single resulting node, or a
        ``["CompoundExpression", ...]`` node when several statements were
        found.

        Raises SyntaxError if the tokens cannot be reduced to one expression.
        """
        op_dict: dict
        changed: bool = False
        lines: list = []

        self._util_remove_newlines(lines, tokens, inside_enclosure)

        for op_type, grouping_strat, op_dict in reversed(self._mathematica_op_precedence):
            if "*" in op_dict:
                # At the Times level, make implicit multiplications explicit.
                self._util_add_missing_asterisks(tokens)
            size: int = len(tokens)
            pointer: int = 0
            while pointer < size:
                token = tokens[pointer]
                if isinstance(token, str) and token in op_dict:
                    op_name: str | Callable = op_dict[token]
                    node: list
                    first_index: int
                    if isinstance(op_name, str):
                        node = [op_name]
                        first_index = 1
                    else:
                        node = []
                        first_index = 0
                    if token in ("+", "-") and op_type == self.PREFIX and pointer > 0 and not self._is_op(tokens[pointer - 1]):
                        # Make sure that PREFIX + - don't match expressions like a + b or a - b,
                        # the INFIX + - are supposed to match that expression:
                        pointer += 1
                        continue
                    if op_type == self.INFIX:
                        # An infix operator needs an operand on both sides.
                        if pointer == 0 or pointer == size - 1 or self._is_op(tokens[pointer - 1]) or self._is_op(tokens[pointer + 1]):
                            pointer += 1
                            continue
                    changed = True
                    tokens[pointer] = node
                    if op_type == self.INFIX:
                        arg1 = tokens.pop(pointer-1)
                        arg2 = tokens.pop(pointer)
                        # a / b -> a * b^-1 and a - b -> a + (-b)
                        if token == "/":
                            arg2 = self._get_inv(arg2)
                        elif token == "-":
                            arg2 = self._get_neg(arg2)
                        pointer -= 1
                        size -= 2
                        node.append(arg1)
                        node_p = node
                        if grouping_strat == self.FLAT:
                            # Absorb a chain of compatible operators (e.g.
                            # a + b - c) into one flat node.
                            while pointer + 2 < size and self._check_op_compatible(tokens[pointer+1], token):
                                node_p.append(arg2)
                                other_op = tokens.pop(pointer+1)
                                arg2 = tokens.pop(pointer+1)
                                if other_op == "/":
                                    arg2 = self._get_inv(arg2)
                                elif other_op == "-":
                                    arg2 = self._get_neg(arg2)
                                size -= 2
                            node_p.append(arg2)
                        elif grouping_strat == self.RIGHT:
                            # a ^ b ^ c -> [op, a, [op, b, c]]: each repeat
                            # nests inside the previous node's last slot.
                            while pointer + 2 < size and tokens[pointer+1] == token:
                                node_p.append([op_name, arg2])
                                node_p = node_p[-1]
                                tokens.pop(pointer+1)
                                arg2 = tokens.pop(pointer+1)
                                size -= 2
                            node_p.append(arg2)
                        elif grouping_strat == self.LEFT:
                            # a op b op c -> [op, [op, a, b], c]: fold the
                            # node built so far into the first operand slot.
                            while pointer + 1 < size and tokens[pointer+1] == token:
                                if isinstance(op_name, str):
                                    node_p[first_index] = [op_name, node_p[first_index], arg2]
                                else:
                                    node_p[first_index] = op_name(node_p[first_index], arg2)
                                tokens.pop(pointer+1)
                                arg2 = tokens.pop(pointer+1)
                                size -= 2
                            node_p.append(arg2)
                        else:
                            node.append(arg2)
                    elif op_type == self.PREFIX:
                        assert grouping_strat is None
                        if pointer == size - 1 or self._is_op(tokens[pointer + 1]):
                            # No operand (e.g. a bare "#"): use the default node.
                            tokens[pointer] = self._missing_arguments_default[token]()
                        else:
                            node.append(tokens.pop(pointer+1))
                            size -= 1
                    elif op_type == self.POSTFIX:
                        assert grouping_strat is None
                        if pointer == 0 or self._is_op(tokens[pointer - 1]):
                            tokens[pointer] = self._missing_arguments_default[token]()
                        else:
                            node.append(tokens.pop(pointer-1))
                            pointer -= 1
                            size -= 1
                    if isinstance(op_name, Callable):  # type: ignore
                        # Callable heads build the node from the collected
                        # operands.  A list result replaces the node's contents
                        # in place; any other result replaces the token.
                        op_call: Callable = typing.cast(Callable, op_name)
                        new_node = op_call(*node)
                        node.clear()
                        if isinstance(new_node, list):
                            node.extend(new_node)
                        else:
                            tokens[pointer] = new_node
                pointer += 1
        if len(tokens) > 1 or (len(lines) == 0 and len(tokens) == 0):
            if changed:
                # Trick to deal with cases in which an operator with lower
                # precedence should be transformed before an operator of higher
                # precedence. Such as in the case of `#&[x]` (that is
                # equivalent to `Lambda(d_, d_)(x)` in SymPy). In this case the
                # operator `&` has lower precedence than `[`, but needs to be
                # evaluated first because otherwise `# (&[x])` is not a valid
                # expression:
                return self._parse_after_braces(tokens, inside_enclosure)
            raise SyntaxError("unable to create a single AST for the expression")
        if len(lines) > 0:
            # Merge the statements split off by newlines with the final one.
            if tokens[0] and tokens[0][0] == "CompoundExpression":
                tokens = tokens[0][1:]
            compound_expression = ["CompoundExpression", *lines, *tokens]
            return compound_expression
        return tokens[0]
  781. def _check_op_compatible(self, op1: str, op2: str):
  782. if op1 == op2:
  783. return True
  784. muldiv = {"*", "/"}
  785. addsub = {"+", "-"}
  786. if op1 in muldiv and op2 in muldiv:
  787. return True
  788. if op1 in addsub and op2 in addsub:
  789. return True
  790. return False
  791. def _from_fullform_to_fullformlist(self, wmexpr: str):
  792. """
  793. Parses FullForm[Downvalues[]] generated by Mathematica
  794. """
  795. out: list = []
  796. stack = [out]
  797. generator = re.finditer(r'[\[\],]', wmexpr)
  798. last_pos = 0
  799. for match in generator:
  800. if match is None:
  801. break
  802. position = match.start()
  803. last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip()
  804. if match.group() == ',':
  805. if last_expr != '':
  806. stack[-1].append(last_expr)
  807. elif match.group() == ']':
  808. if last_expr != '':
  809. stack[-1].append(last_expr)
  810. stack.pop()
  811. elif match.group() == '[':
  812. stack[-1].append([last_expr])
  813. stack.append(stack[-1][-1])
  814. last_pos = match.end()
  815. return out[0]
  816. def _from_fullformlist_to_fullformsympy(self, pylist: list):
  817. from sympy import Function, Symbol
  818. def converter(expr):
  819. if isinstance(expr, list):
  820. if len(expr) > 0:
  821. head = expr[0]
  822. args = [converter(arg) for arg in expr[1:]]
  823. return Function(head)(*args)
  824. else:
  825. raise ValueError("Empty list of expressions")
  826. elif isinstance(expr, str):
  827. return Symbol(expr)
  828. else:
  829. return _sympify(expr)
  830. return converter(pylist)
# Map Mathematica FullForm head names to SymPy callables.
#
# _from_fullformlist_to_sympy calls each value with the already-converted
# arguments of the matching head. _from_fullformsympy_to_sympy uses each
# value as the replacement for Function(<head name>) and applies the entries
# one after another in insertion order, so the order of entries matters there.
_node_conversions = {
    "Times": Mul,
    "Plus": Add,
    "Power": Pow,
    # Mathematica's Log[b, z] takes the base first while SymPy's log(z, b)
    # takes it second, hence the reversal; one-argument Log is unaffected.
    "Log": lambda *a: log(*reversed(a)),
    "Log2": lambda x: log(x, 2),
    "Log10": lambda x: log(x, 10),
    "Exp": exp,
    "Sqrt": sqrt,
    "Sin": sin,
    "Cos": cos,
    "Tan": tan,
    "Cot": cot,
    "Sec": sec,
    "Csc": csc,
    "ArcSin": asin,
    "ArcCos": acos,
    # Two-argument ArcTan[x, y] is the angle of the point (x, y), i.e.
    # atan2(y, x); the one-argument form maps to atan.
    "ArcTan": lambda *a: atan2(*reversed(a)) if len(a) == 2 else atan(*a),
    "ArcCot": acot,
    "ArcSec": asec,
    "ArcCsc": acsc,
    "Sinh": sinh,
    "Cosh": cosh,
    "Tanh": tanh,
    "Coth": coth,
    "Sech": sech,
    "Csch": csch,
    "ArcSinh": asinh,
    "ArcCosh": acosh,
    "ArcTanh": atanh,
    "ArcCoth": acoth,
    "ArcSech": asech,
    "ArcCsch": acsch,
    "Expand": expand,
    "Im": im,
    "Re": sympy.re,
    "Flatten": flatten,
    "Polylog": polylog,
    "Cancel": cancel,
    # Not mapped yet (kept for reference):
    # Gamma=gamma,
    "TrigExpand": expand_trig,
    "Sign": sign,
    "Simplify": simplify,
    "Defer": UnevaluatedExpr,
    "Identity": S,
    # Not mapped yet (kept for reference):
    # Sum=Sum_doit,
    # Module=With,
    # Block=With,
    # Null ignores any arguments and always yields zero.
    "Null": lambda *a: S.Zero,
    "Mod": Mod,
    "Max": Max,
    "Min": Min,
    "Pochhammer": rf,
    "ExpIntegralEi": Ei,
    "SinIntegral": Si,
    "CosIntegral": Ci,
    "AiryAi": airyai,
    "AiryAiPrime": airyaiprime,
    "AiryBi": airybi,
    "AiryBiPrime": airybiprime,
    "LogIntegral": li,
    "PrimePi": primepi,
    "Prime": prime,
    "PrimeQ": isprime,
    "List": Tuple,
    # Relational heads map to SymPy's relational classes.
    "Greater": StrictGreaterThan,
    "GreaterEqual": GreaterThan,
    "Less": StrictLessThan,
    "LessEqual": LessThan,
    "Equal": Equality,
    "Or": Or,
    "And": And,
    # _parse_Function is defined outside this view; presumably it builds a
    # Lambda from Mathematica's Function[...] — TODO confirm.
    "Function": _parse_Function,
}
  905. _atom_conversions = {
  906. "I": I,
  907. "Pi": pi,
  908. }
  909. def _from_fullformlist_to_sympy(self, full_form_list):
  910. def recurse(expr):
  911. if isinstance(expr, list):
  912. if isinstance(expr[0], list):
  913. head = recurse(expr[0])
  914. else:
  915. head = self._node_conversions.get(expr[0], Function(expr[0]))
  916. return head(*[recurse(arg) for arg in expr[1:]])
  917. else:
  918. return self._atom_conversions.get(expr, sympify(expr))
  919. return recurse(full_form_list)
  920. def _from_fullformsympy_to_sympy(self, mform):
  921. expr = mform
  922. for mma_form, sympy_node in self._node_conversions.items():
  923. expr = expr.replace(Function(mma_form), sympy_node)
  924. return expr