_parse_latex_antlr.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604
  1. # Ported from latex2sympy by @augustt198
  2. # https://github.com/augustt198/latex2sympy
  3. # See license in LICENSE.txt
  4. from importlib.metadata import version
  5. import sympy
  6. from sympy.external import import_module
  7. from sympy.printing.str import StrPrinter
  8. from sympy.physics.quantum.state import Bra, Ket
  9. from .errors import LaTeXParsingError
  10. LaTeXParser = LaTeXLexer = MathErrorListener = None
  11. try:
  12. LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser',
  13. import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser
  14. LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer',
  15. import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer
  16. except Exception:
  17. pass
  18. ErrorListener = import_module('antlr4.error.ErrorListener',
  19. warn_not_installed=True,
  20. import_kwargs={'fromlist': ['ErrorListener']}
  21. )
  22. if ErrorListener:
  23. class MathErrorListener(ErrorListener.ErrorListener): # type: ignore
  24. def __init__(self, src):
  25. super(ErrorListener.ErrorListener, self).__init__()
  26. self.src = src
  27. def syntaxError(self, recog, symbol, line, col, msg, e):
  28. fmt = "%s\n%s\n%s"
  29. marker = "~" * col + "^"
  30. if msg.startswith("missing"):
  31. err = fmt % (msg, self.src, marker)
  32. elif msg.startswith("no viable"):
  33. err = fmt % ("I expected something else here", self.src, marker)
  34. elif msg.startswith("mismatched"):
  35. names = LaTeXParser.literalNames
  36. expected = [
  37. names[i] for i in e.getExpectedTokens() if i < len(names)
  38. ]
  39. if len(expected) < 10:
  40. expected = " ".join(expected)
  41. err = (fmt % ("I expected one of these: " + expected, self.src,
  42. marker))
  43. else:
  44. err = (fmt % ("I expected something else here", self.src,
  45. marker))
  46. else:
  47. err = fmt % ("I don't understand this", self.src, marker)
  48. raise LaTeXParsingError(err)
  49. def parse_latex(sympy):
  50. antlr4 = import_module('antlr4')
  51. if None in [antlr4, MathErrorListener] or \
  52. not version('antlr4-python3-runtime').startswith('4.11'):
  53. raise ImportError("LaTeX parsing requires the antlr4 Python package,"
  54. " provided by pip (antlr4-python3-runtime) or"
  55. " conda (antlr-python-runtime), version 4.11")
  56. matherror = MathErrorListener(sympy)
  57. stream = antlr4.InputStream(sympy)
  58. lex = LaTeXLexer(stream)
  59. lex.removeErrorListeners()
  60. lex.addErrorListener(matherror)
  61. tokens = antlr4.CommonTokenStream(lex)
  62. parser = LaTeXParser(tokens)
  63. # remove default console error listener
  64. parser.removeErrorListeners()
  65. parser.addErrorListener(matherror)
  66. relation = parser.math().relation()
  67. expr = convert_relation(relation)
  68. return expr
  69. def convert_relation(rel):
  70. if rel.expr():
  71. return convert_expr(rel.expr())
  72. lh = convert_relation(rel.relation(0))
  73. rh = convert_relation(rel.relation(1))
  74. if rel.LT():
  75. return sympy.StrictLessThan(lh, rh)
  76. elif rel.LTE():
  77. return sympy.LessThan(lh, rh)
  78. elif rel.GT():
  79. return sympy.StrictGreaterThan(lh, rh)
  80. elif rel.GTE():
  81. return sympy.GreaterThan(lh, rh)
  82. elif rel.EQUAL():
  83. return sympy.Eq(lh, rh)
  84. elif rel.NEQ():
  85. return sympy.Ne(lh, rh)
  86. def convert_expr(expr):
  87. return convert_add(expr.additive())
  88. def convert_add(add):
  89. if add.ADD():
  90. lh = convert_add(add.additive(0))
  91. rh = convert_add(add.additive(1))
  92. return sympy.Add(lh, rh, evaluate=False)
  93. elif add.SUB():
  94. lh = convert_add(add.additive(0))
  95. rh = convert_add(add.additive(1))
  96. if hasattr(rh, "is_Atom") and rh.is_Atom:
  97. return sympy.Add(lh, -1 * rh, evaluate=False)
  98. return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False), evaluate=False)
  99. else:
  100. return convert_mp(add.mp())
  101. def convert_mp(mp):
  102. if hasattr(mp, 'mp'):
  103. mp_left = mp.mp(0)
  104. mp_right = mp.mp(1)
  105. else:
  106. mp_left = mp.mp_nofunc(0)
  107. mp_right = mp.mp_nofunc(1)
  108. if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
  109. lh = convert_mp(mp_left)
  110. rh = convert_mp(mp_right)
  111. return sympy.Mul(lh, rh, evaluate=False)
  112. elif mp.DIV() or mp.CMD_DIV() or mp.COLON():
  113. lh = convert_mp(mp_left)
  114. rh = convert_mp(mp_right)
  115. return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False)
  116. else:
  117. if hasattr(mp, 'unary'):
  118. return convert_unary(mp.unary())
  119. else:
  120. return convert_unary(mp.unary_nofunc())
  121. def convert_unary(unary):
  122. if hasattr(unary, 'unary'):
  123. nested_unary = unary.unary()
  124. else:
  125. nested_unary = unary.unary_nofunc()
  126. if hasattr(unary, 'postfix_nofunc'):
  127. first = unary.postfix()
  128. tail = unary.postfix_nofunc()
  129. postfix = [first] + tail
  130. else:
  131. postfix = unary.postfix()
  132. if unary.ADD():
  133. return convert_unary(nested_unary)
  134. elif unary.SUB():
  135. numabs = convert_unary(nested_unary)
  136. # Use Integer(-n) instead of Mul(-1, n)
  137. return -numabs
  138. elif postfix:
  139. return convert_postfix_list(postfix)
  140. def convert_postfix_list(arr, i=0):
  141. if i >= len(arr):
  142. raise LaTeXParsingError("Index out of bounds")
  143. res = convert_postfix(arr[i])
  144. if isinstance(res, sympy.Expr):
  145. if i == len(arr) - 1:
  146. return res # nothing to multiply by
  147. else:
  148. if i > 0:
  149. left = convert_postfix(arr[i - 1])
  150. right = convert_postfix(arr[i + 1])
  151. if isinstance(left, sympy.Expr) and isinstance(
  152. right, sympy.Expr):
  153. left_syms = convert_postfix(arr[i - 1]).atoms(sympy.Symbol)
  154. right_syms = convert_postfix(arr[i + 1]).atoms(
  155. sympy.Symbol)
  156. # if the left and right sides contain no variables and the
  157. # symbol in between is 'x', treat as multiplication.
  158. if not (left_syms or right_syms) and str(res) == 'x':
  159. return convert_postfix_list(arr, i + 1)
  160. # multiply by next
  161. return sympy.Mul(
  162. res, convert_postfix_list(arr, i + 1), evaluate=False)
  163. else: # must be derivative
  164. wrt = res[0]
  165. if i == len(arr) - 1:
  166. raise LaTeXParsingError("Expected expression for derivative")
  167. else:
  168. expr = convert_postfix_list(arr, i + 1)
  169. return sympy.Derivative(expr, wrt)
  170. def do_subs(expr, at):
  171. if at.expr():
  172. at_expr = convert_expr(at.expr())
  173. syms = at_expr.atoms(sympy.Symbol)
  174. if len(syms) == 0:
  175. return expr
  176. elif len(syms) > 0:
  177. sym = next(iter(syms))
  178. return expr.subs(sym, at_expr)
  179. elif at.equality():
  180. lh = convert_expr(at.equality().expr(0))
  181. rh = convert_expr(at.equality().expr(1))
  182. return expr.subs(lh, rh)
  183. def convert_postfix(postfix):
  184. if hasattr(postfix, 'exp'):
  185. exp_nested = postfix.exp()
  186. else:
  187. exp_nested = postfix.exp_nofunc()
  188. exp = convert_exp(exp_nested)
  189. for op in postfix.postfix_op():
  190. if op.BANG():
  191. if isinstance(exp, list):
  192. raise LaTeXParsingError("Cannot apply postfix to derivative")
  193. exp = sympy.factorial(exp, evaluate=False)
  194. elif op.eval_at():
  195. ev = op.eval_at()
  196. at_b = None
  197. at_a = None
  198. if ev.eval_at_sup():
  199. at_b = do_subs(exp, ev.eval_at_sup())
  200. if ev.eval_at_sub():
  201. at_a = do_subs(exp, ev.eval_at_sub())
  202. if at_b is not None and at_a is not None:
  203. exp = sympy.Add(at_b, -1 * at_a, evaluate=False)
  204. elif at_b is not None:
  205. exp = at_b
  206. elif at_a is not None:
  207. exp = at_a
  208. return exp
  209. def convert_exp(exp):
  210. if hasattr(exp, 'exp'):
  211. exp_nested = exp.exp()
  212. else:
  213. exp_nested = exp.exp_nofunc()
  214. if exp_nested:
  215. base = convert_exp(exp_nested)
  216. if isinstance(base, list):
  217. raise LaTeXParsingError("Cannot raise derivative to power")
  218. if exp.atom():
  219. exponent = convert_atom(exp.atom())
  220. elif exp.expr():
  221. exponent = convert_expr(exp.expr())
  222. return sympy.Pow(base, exponent, evaluate=False)
  223. else:
  224. if hasattr(exp, 'comp'):
  225. return convert_comp(exp.comp())
  226. else:
  227. return convert_comp(exp.comp_nofunc())
  228. def convert_comp(comp):
  229. if comp.group():
  230. return convert_expr(comp.group().expr())
  231. elif comp.abs_group():
  232. return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False)
  233. elif comp.atom():
  234. return convert_atom(comp.atom())
  235. elif comp.floor():
  236. return convert_floor(comp.floor())
  237. elif comp.ceil():
  238. return convert_ceil(comp.ceil())
  239. elif comp.func():
  240. return convert_func(comp.func())
  241. def convert_atom(atom):
  242. if atom.LETTER():
  243. sname = atom.LETTER().getText()
  244. if atom.subexpr():
  245. if atom.subexpr().expr(): # subscript is expr
  246. subscript = convert_expr(atom.subexpr().expr())
  247. else: # subscript is atom
  248. subscript = convert_atom(atom.subexpr().atom())
  249. sname += '_{' + StrPrinter().doprint(subscript) + '}'
  250. if atom.SINGLE_QUOTES():
  251. sname += atom.SINGLE_QUOTES().getText() # put after subscript for easy identify
  252. return sympy.Symbol(sname)
  253. elif atom.SYMBOL():
  254. s = atom.SYMBOL().getText()[1:]
  255. if s == "infty":
  256. return sympy.oo
  257. else:
  258. if atom.subexpr():
  259. subscript = None
  260. if atom.subexpr().expr(): # subscript is expr
  261. subscript = convert_expr(atom.subexpr().expr())
  262. else: # subscript is atom
  263. subscript = convert_atom(atom.subexpr().atom())
  264. subscriptName = StrPrinter().doprint(subscript)
  265. s += '_{' + subscriptName + '}'
  266. return sympy.Symbol(s)
  267. elif atom.number():
  268. s = atom.number().getText().replace(",", "")
  269. return sympy.Number(s)
  270. elif atom.DIFFERENTIAL():
  271. var = get_differential_var(atom.DIFFERENTIAL())
  272. return sympy.Symbol('d' + var.name)
  273. elif atom.mathit():
  274. text = rule2text(atom.mathit().mathit_text())
  275. return sympy.Symbol(text)
  276. elif atom.frac():
  277. return convert_frac(atom.frac())
  278. elif atom.binom():
  279. return convert_binom(atom.binom())
  280. elif atom.bra():
  281. val = convert_expr(atom.bra().expr())
  282. return Bra(val)
  283. elif atom.ket():
  284. val = convert_expr(atom.ket().expr())
  285. return Ket(val)
  286. def rule2text(ctx):
  287. stream = ctx.start.getInputStream()
  288. # starting index of starting token
  289. startIdx = ctx.start.start
  290. # stopping index of stopping token
  291. stopIdx = ctx.stop.stop
  292. return stream.getText(startIdx, stopIdx)
  293. def convert_frac(frac):
  294. diff_op = False
  295. partial_op = False
  296. if frac.lower and frac.upper:
  297. lower_itv = frac.lower.getSourceInterval()
  298. lower_itv_len = lower_itv[1] - lower_itv[0] + 1
  299. if (frac.lower.start == frac.lower.stop
  300. and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
  301. wrt = get_differential_var_str(frac.lower.start.text)
  302. diff_op = True
  303. elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
  304. and frac.lower.start.text == '\\partial'
  305. and (frac.lower.stop.type == LaTeXLexer.LETTER
  306. or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
  307. partial_op = True
  308. wrt = frac.lower.stop.text
  309. if frac.lower.stop.type == LaTeXLexer.SYMBOL:
  310. wrt = wrt[1:]
  311. if diff_op or partial_op:
  312. wrt = sympy.Symbol(wrt)
  313. if (diff_op and frac.upper.start == frac.upper.stop
  314. and frac.upper.start.type == LaTeXLexer.LETTER
  315. and frac.upper.start.text == 'd'):
  316. return [wrt]
  317. elif (partial_op and frac.upper.start == frac.upper.stop
  318. and frac.upper.start.type == LaTeXLexer.SYMBOL
  319. and frac.upper.start.text == '\\partial'):
  320. return [wrt]
  321. upper_text = rule2text(frac.upper)
  322. expr_top = None
  323. if diff_op and upper_text.startswith('d'):
  324. expr_top = parse_latex(upper_text[1:])
  325. elif partial_op and frac.upper.start.text == '\\partial':
  326. expr_top = parse_latex(upper_text[len('\\partial'):])
  327. if expr_top:
  328. return sympy.Derivative(expr_top, wrt)
  329. if frac.upper:
  330. expr_top = convert_expr(frac.upper)
  331. else:
  332. expr_top = sympy.Number(frac.upperd.text)
  333. if frac.lower:
  334. expr_bot = convert_expr(frac.lower)
  335. else:
  336. expr_bot = sympy.Number(frac.lowerd.text)
  337. inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
  338. if expr_top == 1:
  339. return inverse_denom
  340. else:
  341. return sympy.Mul(expr_top, inverse_denom, evaluate=False)
  342. def convert_binom(binom):
  343. expr_n = convert_expr(binom.n)
  344. expr_k = convert_expr(binom.k)
  345. return sympy.binomial(expr_n, expr_k, evaluate=False)
  346. def convert_floor(floor):
  347. val = convert_expr(floor.val)
  348. return sympy.floor(val, evaluate=False)
  349. def convert_ceil(ceil):
  350. val = convert_expr(ceil.val)
  351. return sympy.ceiling(val, evaluate=False)
  352. def convert_func(func):
  353. if func.func_normal():
  354. if func.L_PAREN(): # function called with parenthesis
  355. arg = convert_func_arg(func.func_arg())
  356. else:
  357. arg = convert_func_arg(func.func_arg_noparens())
  358. name = func.func_normal().start.text[1:]
  359. # change arc<trig> -> a<trig>
  360. if name in [
  361. "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot"
  362. ]:
  363. name = "a" + name[3:]
  364. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  365. if name in ["arsinh", "arcosh", "artanh"]:
  366. name = "a" + name[2:]
  367. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  368. if name == "exp":
  369. expr = sympy.exp(arg, evaluate=False)
  370. if name in ("log", "lg", "ln"):
  371. if func.subexpr():
  372. if func.subexpr().expr():
  373. base = convert_expr(func.subexpr().expr())
  374. else:
  375. base = convert_atom(func.subexpr().atom())
  376. elif name == "lg": # ISO 80000-2:2019
  377. base = 10
  378. elif name in ("ln", "log"): # SymPy's latex printer prints ln as log by default
  379. base = sympy.E
  380. expr = sympy.log(arg, base, evaluate=False)
  381. func_pow = None
  382. should_pow = True
  383. if func.supexpr():
  384. if func.supexpr().expr():
  385. func_pow = convert_expr(func.supexpr().expr())
  386. else:
  387. func_pow = convert_atom(func.supexpr().atom())
  388. if name in [
  389. "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh",
  390. "tanh"
  391. ]:
  392. if func_pow == -1:
  393. name = "a" + name
  394. should_pow = False
  395. expr = getattr(sympy.functions, name)(arg, evaluate=False)
  396. if func_pow and should_pow:
  397. expr = sympy.Pow(expr, func_pow, evaluate=False)
  398. return expr
  399. elif func.LETTER() or func.SYMBOL():
  400. if func.LETTER():
  401. fname = func.LETTER().getText()
  402. elif func.SYMBOL():
  403. fname = func.SYMBOL().getText()[1:]
  404. fname = str(fname) # can't be unicode
  405. if func.subexpr():
  406. if func.subexpr().expr(): # subscript is expr
  407. subscript = convert_expr(func.subexpr().expr())
  408. else: # subscript is atom
  409. subscript = convert_atom(func.subexpr().atom())
  410. subscriptName = StrPrinter().doprint(subscript)
  411. fname += '_{' + subscriptName + '}'
  412. if func.SINGLE_QUOTES():
  413. fname += func.SINGLE_QUOTES().getText()
  414. input_args = func.args()
  415. output_args = []
  416. while input_args.args(): # handle multiple arguments to function
  417. output_args.append(convert_expr(input_args.expr()))
  418. input_args = input_args.args()
  419. output_args.append(convert_expr(input_args.expr()))
  420. return sympy.Function(fname)(*output_args)
  421. elif func.FUNC_INT():
  422. return handle_integral(func)
  423. elif func.FUNC_SQRT():
  424. expr = convert_expr(func.base)
  425. if func.root:
  426. r = convert_expr(func.root)
  427. return sympy.root(expr, r, evaluate=False)
  428. else:
  429. return sympy.sqrt(expr, evaluate=False)
  430. elif func.FUNC_OVERLINE():
  431. expr = convert_expr(func.base)
  432. return sympy.conjugate(expr, evaluate=False)
  433. elif func.FUNC_SUM():
  434. return handle_sum_or_prod(func, "summation")
  435. elif func.FUNC_PROD():
  436. return handle_sum_or_prod(func, "product")
  437. elif func.FUNC_LIM():
  438. return handle_limit(func)
  439. def convert_func_arg(arg):
  440. if hasattr(arg, 'expr'):
  441. return convert_expr(arg.expr())
  442. else:
  443. return convert_mp(arg.mp_nofunc())
  444. def handle_integral(func):
  445. if func.additive():
  446. integrand = convert_add(func.additive())
  447. elif func.frac():
  448. integrand = convert_frac(func.frac())
  449. else:
  450. integrand = 1
  451. int_var = None
  452. if func.DIFFERENTIAL():
  453. int_var = get_differential_var(func.DIFFERENTIAL())
  454. else:
  455. for sym in integrand.atoms(sympy.Symbol):
  456. s = str(sym)
  457. if len(s) > 1 and s[0] == 'd':
  458. if s[1] == '\\':
  459. int_var = sympy.Symbol(s[2:])
  460. else:
  461. int_var = sympy.Symbol(s[1:])
  462. int_sym = sym
  463. if int_var:
  464. integrand = integrand.subs(int_sym, 1)
  465. else:
  466. # Assume dx by default
  467. int_var = sympy.Symbol('x')
  468. if func.subexpr():
  469. if func.subexpr().atom():
  470. lower = convert_atom(func.subexpr().atom())
  471. else:
  472. lower = convert_expr(func.subexpr().expr())
  473. if func.supexpr().atom():
  474. upper = convert_atom(func.supexpr().atom())
  475. else:
  476. upper = convert_expr(func.supexpr().expr())
  477. return sympy.Integral(integrand, (int_var, lower, upper))
  478. else:
  479. return sympy.Integral(integrand, int_var)
  480. def handle_sum_or_prod(func, name):
  481. val = convert_mp(func.mp())
  482. iter_var = convert_expr(func.subeq().equality().expr(0))
  483. start = convert_expr(func.subeq().equality().expr(1))
  484. if func.supexpr().expr(): # ^{expr}
  485. end = convert_expr(func.supexpr().expr())
  486. else: # ^atom
  487. end = convert_atom(func.supexpr().atom())
  488. if name == "summation":
  489. return sympy.Sum(val, (iter_var, start, end))
  490. elif name == "product":
  491. return sympy.Product(val, (iter_var, start, end))
  492. def handle_limit(func):
  493. sub = func.limit_sub()
  494. if sub.LETTER():
  495. var = sympy.Symbol(sub.LETTER().getText())
  496. elif sub.SYMBOL():
  497. var = sympy.Symbol(sub.SYMBOL().getText()[1:])
  498. else:
  499. var = sympy.Symbol('x')
  500. if sub.SUB():
  501. direction = "-"
  502. elif sub.ADD():
  503. direction = "+"
  504. else:
  505. direction = "+-"
  506. approaching = convert_expr(sub.expr())
  507. content = convert_mp(func.mp())
  508. return sympy.Limit(content, var, approaching, direction)
  509. def get_differential_var(d):
  510. text = get_differential_var_str(d.getText())
  511. return sympy.Symbol(text)
  512. def get_differential_var_str(text):
  513. for i in range(1, len(text)):
  514. c = text[i]
  515. if not (c == " " or c == "\r" or c == "\n" or c == "\t"):
  516. idx = i
  517. break
  518. text = text[idx:]
  519. if text[0] == "\\":
  520. text = text[1:]
  521. return text