123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604 |
- # Ported from latex2sympy by @augustt198
- # https://github.com/augustt198/latex2sympy
- # See license in LICENSE.txt
- from importlib.metadata import version
- import sympy
- from sympy.external import import_module
- from sympy.printing.str import StrPrinter
- from sympy.physics.quantum.state import Bra, Ket
- from .errors import LaTeXParsingError
# Optional dependencies: each name stays ``None`` unless the code below
# manages to bind it.
LaTeXParser = None
LaTeXLexer = None
MathErrorListener = None

try:
    LaTeXParser = import_module(
        'sympy.parsing.latex._antlr.latexparser',
        import_kwargs={'fromlist': ['LaTeXParser']},
    ).LaTeXParser
    LaTeXLexer = import_module(
        'sympy.parsing.latex._antlr.latexlexer',
        import_kwargs={'fromlist': ['LaTeXLexer']},
    ).LaTeXLexer
except Exception:
    # Deliberately best effort: any name not bound above simply stays None.
    pass

# Module holding the listener base class that MathErrorListener derives from.
ErrorListener = import_module(
    'antlr4.error.ErrorListener',
    warn_not_installed=True,
    import_kwargs={'fromlist': ['ErrorListener']},
)
if ErrorListener:
    class MathErrorListener(ErrorListener.ErrorListener):  # type: ignore
        """Error listener that raises ``LaTeXParsingError`` on a syntax error.

        Every message has three lines: a short description, the source text
        given to the constructor, and a ``~~~^`` marker pointing at the
        reported column.
        """

        def __init__(self, src):
            """Store ``src``, the text being parsed, for use in messages."""
            # Zero-argument super() continues the lookup after this class.
            # The former ``super(ErrorListener.ErrorListener, self)`` started
            # *after* the listener base class and therefore bypassed it.
            super().__init__()
            self.src = src

        def syntaxError(self, recog, symbol, line, col, msg, e):
            """Translate a reported syntax error into ``LaTeXParsingError``.

            Only ``col``, ``msg`` and ``e`` are used. ``msg`` selects the
            description by its prefix; for "mismatched" errors the expected
            literal tokens are listed when there are fewer than ten of them.

            Raises
            ------
            LaTeXParsingError
                Always.
            """
            marker = "~" * col + "^"

            if msg.startswith("missing"):
                headline = msg
            elif msg.startswith("no viable"):
                headline = "I expected something else here"
            elif msg.startswith("mismatched"):
                names = LaTeXParser.literalNames
                # Token types beyond the literal-name table have no printable
                # name, so they are left out of the list.
                expected = [
                    names[i] for i in e.getExpectedTokens() if i < len(names)
                ]
                if len(expected) < 10:
                    headline = "I expected one of these: " + " ".join(expected)
                else:
                    headline = "I expected something else here"
            else:
                headline = "I don't understand this"

            raise LaTeXParsingError(
                "%s\n%s\n%s" % (headline, self.src, marker))
def parse_latex(sympy):
    """Parse LaTeX source text and return the converted expression.

    Parameters
    ----------
    sympy
        The LaTeX source text. NOTE: the parameter name shadows the ``sympy``
        module inside this function, so the module must not be used here.

    Returns
    -------
    The result of ``convert_relation`` applied to the parsed relation.

    Raises
    ------
    ImportError
        If the antlr4 runtime is missing or not version 4.11, if its
        distribution metadata cannot be found, or if the generated
        lexer/parser could not be imported.
    LaTeXParsingError
        Raised by the error listener when the text has a syntax error.
    """
    from importlib.metadata import PackageNotFoundError

    antlr4 = import_module('antlr4')

    runtime_ok = None not in [antlr4, MathErrorListener]
    if runtime_ok:
        try:
            runtime_ok = version('antlr4-python3-runtime').startswith('4.11')
        except PackageNotFoundError:
            # Without distribution metadata the version cannot be checked;
            # report it as an unsupported runtime rather than leaking the
            # metadata lookup error.
            runtime_ok = False
    if not runtime_ok:
        raise ImportError("LaTeX parsing requires the antlr4 Python package,"
                          " provided by pip (antlr4-python3-runtime) or"
                          " conda (antlr-python-runtime), version 4.11")

    # The module-level import of the generated classes is best effort and may
    # have left these as None; fail clearly instead of with
    # "'NoneType' object is not callable".
    if LaTeXParser is None or LaTeXLexer is None:
        raise ImportError("LaTeX parsing requires the generated ANTLR lexer"
                          " and parser, which could not be imported")

    matherror = MathErrorListener(sympy)

    stream = antlr4.InputStream(sympy)
    lex = LaTeXLexer(stream)
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    tokens = antlr4.CommonTokenStream(lex)
    parser = LaTeXParser(tokens)

    # remove default console error listener
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    relation = parser.math().relation()
    return convert_relation(relation)
def convert_relation(rel):
    """Convert a relation parse node into a SymPy expression or relational.

    A node that wraps a plain expression is converted with ``convert_expr``.
    Otherwise both sides are converted recursively and combined according to
    the first relational token present. Returns ``None`` when no known token
    is present.
    """
    plain = rel.expr()
    if plain:
        return convert_expr(plain)

    lhs = convert_relation(rel.relation(0))
    rhs = convert_relation(rel.relation(1))

    # Checked in this order; token accessors are looked up lazily by name.
    relation_classes = (
        ("LT", sympy.StrictLessThan),
        ("LTE", sympy.LessThan),
        ("GT", sympy.StrictGreaterThan),
        ("GTE", sympy.GreaterThan),
        ("EQUAL", sympy.Eq),
        ("NEQ", sympy.Ne),
    )
    for token_name, relation_cls in relation_classes:
        if getattr(rel, token_name)():
            return relation_cls(lhs, rhs)
    return None
def convert_expr(expr):
    """Convert an ``expr`` parse node by converting its additive child."""
    additive_node = expr.additive()
    return convert_add(additive_node)
def convert_add(add):
    """Convert an additive parse node into an unevaluated SymPy expression.

    Sums become ``Add(lh, rh)``; differences become ``Add(lh, -rh)``, where
    the negation is ``-1 * rh`` for atoms and an unevaluated ``Mul(-1, rh)``
    otherwise. A node with neither operator is handed to ``convert_mp``.
    """
    is_sum = add.ADD()
    if not is_sum and not add.SUB():
        return convert_mp(add.mp())

    left = convert_add(add.additive(0))
    right = convert_add(add.additive(1))

    if is_sum:
        return sympy.Add(left, right, evaluate=False)

    # Subtraction: negate the right-hand side, then add.
    if getattr(right, "is_Atom", False):
        negated = -1 * right
    else:
        negated = sympy.Mul(-1, right, evaluate=False)
    return sympy.Add(left, negated, evaluate=False)
def convert_mp(mp):
    """Convert a multiplication/division parse node.

    Works for both node flavours: those exposing ``mp``/``unary`` and those
    exposing ``mp_nofunc``/``unary_nofunc``. Products and quotients are built
    unevaluated; a quotient is ``Mul(lh, Pow(rh, -1))``.
    """
    operand = mp.mp if hasattr(mp, 'mp') else mp.mp_nofunc
    left_node = operand(0)
    right_node = operand(1)

    if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT():
        factor_l = convert_mp(left_node)
        factor_r = convert_mp(right_node)
        return sympy.Mul(factor_l, factor_r, evaluate=False)

    if mp.DIV() or mp.CMD_DIV() or mp.COLON():
        numerator = convert_mp(left_node)
        denominator = convert_mp(right_node)
        reciprocal = sympy.Pow(denominator, -1, evaluate=False)
        return sympy.Mul(numerator, reciprocal, evaluate=False)

    # No binary operator: this node wraps a single unary.
    if hasattr(mp, 'unary'):
        return convert_unary(mp.unary())
    return convert_unary(mp.unary_nofunc())
def convert_unary(unary):
    """Convert a unary parse node (signed operand or list of postfix terms).

    A leading ``+`` is dropped, a leading ``-`` negates the converted operand,
    and otherwise the postfix terms go to ``convert_postfix_list``. Returns
    ``None`` if the node has no sign and no postfix terms.
    """
    if hasattr(unary, 'unary'):
        operand = unary.unary()
    else:
        operand = unary.unary_nofunc()

    if hasattr(unary, 'postfix_nofunc'):
        # This flavour has one leading postfix plus a list of further terms.
        terms = [unary.postfix()] + unary.postfix_nofunc()
    else:
        terms = unary.postfix()

    if unary.ADD():
        return convert_unary(operand)
    if unary.SUB():
        # Use Integer(-n) instead of Mul(-1, n)
        return -convert_unary(operand)
    if terms:
        return convert_postfix_list(terms)
    return None
def convert_postfix_list(arr, i=0):
    """Convert the postfix nodes ``arr[i:]`` into one SymPy expression.

    Adjacent terms are multiplied (unevaluated). A term that converts to a
    non-``Expr`` value is treated as a derivative operator whose first item is
    the differentiation variable, applied to everything that follows it.

    Parameters
    ----------
    arr
        Sequence of postfix parse nodes.
    i : int
        Index of the first node to convert.

    Raises
    ------
    LaTeXParsingError
        If ``i`` is past the end of ``arr`` or a derivative operator has no
        expression after it.
    """
    if i >= len(arr):
        raise LaTeXParsingError("Index out of bounds")

    res = convert_postfix(arr[i])
    last = len(arr) - 1

    if not isinstance(res, sympy.Expr):
        # must be derivative
        wrt = res[0]
        if i == last:
            raise LaTeXParsingError("Expected expression for derivative")
        expr = convert_postfix_list(arr, i + 1)
        return sympy.Derivative(expr, wrt)

    if i == last:
        return res  # nothing to multiply by

    # The neighbours only matter when this term prints as 'x', so test that
    # first and convert each neighbour once. Previously both neighbours were
    # converted on every step and converted again to collect their symbols.
    if i > 0 and str(res) == 'x':
        left = convert_postfix(arr[i - 1])
        right = convert_postfix(arr[i + 1])
        if isinstance(left, sympy.Expr) and isinstance(right, sympy.Expr):
            left_syms = left.atoms(sympy.Symbol)
            right_syms = right.atoms(sympy.Symbol)
            # if the left and right sides contain no variables and the
            # symbol in between is 'x', treat as multiplication.
            if not (left_syms or right_syms):
                return convert_postfix_list(arr, i + 1)

    # multiply by next
    return sympy.Mul(res, convert_postfix_list(arr, i + 1), evaluate=False)
def do_subs(expr, at):
    """Substitute into ``expr`` as described by an evaluate-at bound node.

    If ``at`` holds a bare expression, one symbol taken from that expression
    is replaced by it (``expr`` is returned unchanged when it has no symbols).
    If ``at`` holds an equality, its left side is replaced by its right side.
    Returns ``None`` when ``at`` holds neither.
    """
    bare = at.expr()
    if bare:
        at_expr = convert_expr(bare)
        symbols = at_expr.atoms(sympy.Symbol)
        if not symbols:
            return expr
        target = next(iter(symbols))
        return expr.subs(target, at_expr)

    equality = at.equality()
    if equality:
        old = convert_expr(equality.expr(0))
        new = convert_expr(equality.expr(1))
        return expr.subs(old, new)

    return None
def convert_postfix(postfix):
    """Convert a postfix parse node, applying its postfix operators in order.

    ``!`` wraps the value in an unevaluated factorial. An evaluate-at operator
    substitutes its upper and/or lower bound via ``do_subs``; with both bounds
    the result is ``upper - lower``.

    Raises
    ------
    LaTeXParsingError
        If ``!`` is applied to a derivative operator (a list).
    """
    if hasattr(postfix, 'exp'):
        inner = postfix.exp()
    else:
        inner = postfix.exp_nofunc()
    exp = convert_exp(inner)

    for op in postfix.postfix_op():
        if op.BANG():
            if isinstance(exp, list):
                raise LaTeXParsingError("Cannot apply postfix to derivative")
            exp = sympy.factorial(exp, evaluate=False)
            continue

        ev = op.eval_at()
        if not ev:
            continue

        upper = do_subs(exp, ev.eval_at_sup()) if ev.eval_at_sup() else None
        lower = do_subs(exp, ev.eval_at_sub()) if ev.eval_at_sub() else None

        if upper is not None and lower is not None:
            exp = sympy.Add(upper, -1 * lower, evaluate=False)
        elif upper is not None:
            exp = upper
        elif lower is not None:
            exp = lower

    return exp
def convert_exp(exp):
    """Convert an exponentiation parse node.

    With a nested base node the result is an unevaluated ``Pow(base,
    exponent)``, the exponent coming from the node's atom or expression.
    Without one, the node's ``comp``/``comp_nofunc`` child is converted.

    Raises
    ------
    LaTeXParsingError
        If the base is a derivative operator (a list).
    """
    if hasattr(exp, 'exp'):
        base_node = exp.exp()
    else:
        base_node = exp.exp_nofunc()

    if not base_node:
        if hasattr(exp, 'comp'):
            return convert_comp(exp.comp())
        return convert_comp(exp.comp_nofunc())

    base = convert_exp(base_node)
    if isinstance(base, list):
        raise LaTeXParsingError("Cannot raise derivative to power")

    if exp.atom():
        exponent = convert_atom(exp.atom())
    elif exp.expr():
        exponent = convert_expr(exp.expr())
    return sympy.Pow(base, exponent, evaluate=False)
def convert_comp(comp):
    """Convert a ``comp`` parse node by dispatching on the child it holds.

    Children are tried in the order group, abs_group, atom, floor, ceil, func;
    each accessor is only touched once the earlier ones came back empty.
    Returns ``None`` if none is present.
    """
    group = comp.group()
    if group:
        return convert_expr(group.expr())

    abs_group = comp.abs_group()
    if abs_group:
        return sympy.Abs(convert_expr(abs_group.expr()), evaluate=False)

    atom = comp.atom()
    if atom:
        return convert_atom(atom)

    floor = comp.floor()
    if floor:
        return convert_floor(floor)

    ceil = comp.ceil()
    if ceil:
        return convert_ceil(ceil)

    func = comp.func()
    if func:
        return convert_func(func)

    return None
def convert_atom(atom):
    """Convert an ``atom`` parse node into a SymPy object.

    Handles letters and backslash symbols (with optional subscript), numbers,
    differentials, ``mathit`` text, fractions, binomials, bras and kets.
    Returns ``None`` if the node holds none of these.
    """
    def subscript_suffix(subexpr):
        # Render a subscript node as the "_{...}" suffix of a symbol name.
        if subexpr.expr():  # subscript is expr
            inner = convert_expr(subexpr.expr())
        else:  # subscript is atom
            inner = convert_atom(subexpr.atom())
        return '_{' + StrPrinter().doprint(inner) + '}'

    if atom.LETTER():
        name = atom.LETTER().getText()
        if atom.subexpr():
            name += subscript_suffix(atom.subexpr())
        if atom.SINGLE_QUOTES():
            # put after subscript for easy identify
            name += atom.SINGLE_QUOTES().getText()
        return sympy.Symbol(name)

    if atom.SYMBOL():
        # Drop the leading backslash of the command.
        name = atom.SYMBOL().getText()[1:]
        if name == "infty":
            return sympy.oo
        if atom.subexpr():
            name += subscript_suffix(atom.subexpr())
        return sympy.Symbol(name)

    if atom.number():
        digits = atom.number().getText().replace(",", "")
        return sympy.Number(digits)

    if atom.DIFFERENTIAL():
        var = get_differential_var(atom.DIFFERENTIAL())
        return sympy.Symbol('d' + var.name)

    if atom.mathit():
        return sympy.Symbol(rule2text(atom.mathit().mathit_text()))

    if atom.frac():
        return convert_frac(atom.frac())

    if atom.binom():
        return convert_binom(atom.binom())

    if atom.bra():
        return Bra(convert_expr(atom.bra().expr()))

    if atom.ket():
        return Ket(convert_expr(atom.ket().expr()))

    return None
def rule2text(ctx):
    """Return the source text spanned by the parse node ``ctx``.

    The span runs from the start index of the node's first token to the stop
    index of its last token, as passed to the input stream's ``getText``.
    """
    first_token = ctx.start
    last_token = ctx.stop
    source = first_token.getInputStream()
    return source.getText(first_token.start, last_token.stop)
def convert_frac(frac):
    """Convert a ``\\frac`` parse node.

    Returns one of:

    * ``[wrt]`` (a one-element list) for a bare derivative operator such as
      ``\\frac{d}{dx}`` or ``\\frac{\\partial}{\\partial x}``;
    * ``Derivative(expr, wrt)`` when the numerator is ``d``/``\\partial``
      followed by an expression, which is re-parsed with ``parse_latex``;
    * otherwise the unevaluated quotient, or just the inverse denominator
      when the numerator equals 1.
    """
    diff_op = False
    partial_op = False
    if frac.lower and frac.upper:
        lower_itv = frac.lower.getSourceInterval()
        lower_itv_len = lower_itv[1] - lower_itv[0] + 1
        if (frac.lower.start == frac.lower.stop
                and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL):
            # Denominator is a single differential token.
            wrt = get_differential_var_str(frac.lower.start.text)
            diff_op = True
        elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL
              and frac.lower.start.text == '\\partial'
              and (frac.lower.stop.type == LaTeXLexer.LETTER
                   or frac.lower.stop.type == LaTeXLexer.SYMBOL)):
            # Denominator is "\partial" followed by one letter or symbol.
            partial_op = True
            wrt = frac.lower.stop.text
            if frac.lower.stop.type == LaTeXLexer.SYMBOL:
                wrt = wrt[1:]  # drop the backslash

        if diff_op or partial_op:
            wrt = sympy.Symbol(wrt)
            if (diff_op and frac.upper.start == frac.upper.stop
                    and frac.upper.start.type == LaTeXLexer.LETTER
                    and frac.upper.start.text == 'd'):
                return [wrt]
            elif (partial_op and frac.upper.start == frac.upper.stop
                  and frac.upper.start.type == LaTeXLexer.SYMBOL
                  and frac.upper.start.text == '\\partial'):
                return [wrt]

            upper_text = rule2text(frac.upper)

            expr_top = None
            if diff_op and upper_text.startswith('d'):
                expr_top = parse_latex(upper_text[1:])
            elif partial_op and frac.upper.start.text == '\\partial':
                expr_top = parse_latex(upper_text[len('\\partial'):])
            # Compare against None instead of testing truthiness: a numerator
            # that parses to zero is falsy and used to fall through to the
            # plain-fraction path below.
            if expr_top is not None:
                return sympy.Derivative(expr_top, wrt)

    if frac.upper:
        expr_top = convert_expr(frac.upper)
    else:
        expr_top = sympy.Number(frac.upperd.text)
    if frac.lower:
        expr_bot = convert_expr(frac.lower)
    else:
        expr_bot = sympy.Number(frac.lowerd.text)

    inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False)
    if expr_top == 1:
        return inverse_denom
    return sympy.Mul(expr_top, inverse_denom, evaluate=False)
def convert_binom(binom):
    """Convert a binomial parse node into an unevaluated ``binomial(n, k)``."""
    upper = convert_expr(binom.n)
    lower = convert_expr(binom.k)
    return sympy.binomial(upper, lower, evaluate=False)
def convert_floor(floor):
    """Convert a floor parse node into an unevaluated ``sympy.floor``."""
    return sympy.floor(convert_expr(floor.val), evaluate=False)
def convert_ceil(ceil):
    """Convert a ceiling parse node into an unevaluated ``sympy.ceiling``."""
    return sympy.ceiling(convert_expr(ceil.val), evaluate=False)
def convert_func(func):
    """Convert a ``func`` parse node into a SymPy expression.

    Covers named functions (trigonometric, hyperbolic, ``exp`` and the
    logarithms, with an optional superscript power), user functions written
    as a letter or symbol with arguments, integrals, roots, overlines
    (conjugates), sums, products and limits. Returns ``None`` if the node
    matches none of these.
    """
    if func.func_normal():
        if func.L_PAREN():  # function called with parenthesis
            arg = convert_func_arg(func.func_arg())
        else:
            arg = convert_func_arg(func.func_arg_noparens())

        # Command name without its leading backslash.
        name = func.func_normal().start.text[1:]

        # change arc<trig> -> a<trig>
        if name in ("arcsin", "arccos", "arctan",
                    "arccsc", "arcsec", "arccot"):
            name = "a" + name[3:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)
        # change ar<hyperbolic> -> a<hyperbolic>
        if name in ("arsinh", "arcosh", "artanh"):
            name = "a" + name[2:]
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        if name == "exp":
            expr = sympy.exp(arg, evaluate=False)

        if name in ("log", "lg", "ln"):
            if func.subexpr():
                # Explicit base given as a subscript.
                if func.subexpr().expr():
                    base = convert_expr(func.subexpr().expr())
                else:
                    base = convert_atom(func.subexpr().atom())
            elif name == "lg":  # ISO 80000-2:2019
                base = 10
            elif name in ("ln", "log"):  # SymPy's latex printer prints ln as log by default
                base = sympy.E
            expr = sympy.log(arg, base, evaluate=False)

        func_pow = None
        should_pow = True
        if func.supexpr():
            if func.supexpr().expr():
                func_pow = convert_expr(func.supexpr().expr())
            else:
                func_pow = convert_atom(func.supexpr().atom())

        if name in ("sin", "cos", "tan", "csc", "sec", "cot",
                    "sinh", "cosh", "tanh"):
            if func_pow == -1:
                # A superscript of -1 selects the inverse function, not a power.
                name = "a" + name
                should_pow = False
            expr = getattr(sympy.functions, name)(arg, evaluate=False)

        # Test for presence, not truthiness: a zero exponent is falsy and used
        # to be dropped silently.
        if func_pow is not None and should_pow:
            expr = sympy.Pow(expr, func_pow, evaluate=False)

        return expr
    elif func.LETTER() or func.SYMBOL():
        if func.LETTER():
            fname = func.LETTER().getText()
        elif func.SYMBOL():
            fname = func.SYMBOL().getText()[1:]
        fname = str(fname)  # can't be unicode
        if func.subexpr():
            if func.subexpr().expr():  # subscript is expr
                subscript = convert_expr(func.subexpr().expr())
            else:  # subscript is atom
                subscript = convert_atom(func.subexpr().atom())
            subscriptName = StrPrinter().doprint(subscript)
            fname += '_{' + subscriptName + '}'
        if func.SINGLE_QUOTES():
            fname += func.SINGLE_QUOTES().getText()
        input_args = func.args()
        output_args = []
        # The argument list is nested: each node holds one expression plus,
        # optionally, a further ``args`` node with the remaining arguments.
        while input_args.args():  # handle multiple arguments to function
            output_args.append(convert_expr(input_args.expr()))
            input_args = input_args.args()
        output_args.append(convert_expr(input_args.expr()))
        return sympy.Function(fname)(*output_args)
    elif func.FUNC_INT():
        return handle_integral(func)
    elif func.FUNC_SQRT():
        expr = convert_expr(func.base)
        if func.root:
            r = convert_expr(func.root)
            return sympy.root(expr, r, evaluate=False)
        else:
            return sympy.sqrt(expr, evaluate=False)
    elif func.FUNC_OVERLINE():
        expr = convert_expr(func.base)
        return sympy.conjugate(expr, evaluate=False)
    elif func.FUNC_SUM():
        return handle_sum_or_prod(func, "summation")
    elif func.FUNC_PROD():
        return handle_sum_or_prod(func, "product")
    elif func.FUNC_LIM():
        return handle_limit(func)
def convert_func_arg(arg):
    """Convert a function-argument node.

    Nodes exposing ``expr`` are converted as expressions; the others are
    converted from their ``mp_nofunc`` child.
    """
    if not hasattr(arg, 'expr'):
        return convert_mp(arg.mp_nofunc())
    return convert_expr(arg.expr())
def handle_integral(func):
    """Build a ``sympy.Integral`` from an integral ``func`` node.

    The integration variable comes from an explicit differential token when
    there is one. Otherwise the integrand is searched for a symbol named
    ``d<var>`` (which is then replaced by 1), and ``x`` is assumed if none is
    found. Bounds are added when the node has a subscript.
    """
    if func.additive():
        integrand = convert_add(func.additive())
    elif func.frac():
        integrand = convert_frac(func.frac())
    else:
        integrand = 1

    differential = func.DIFFERENTIAL()
    if differential:
        int_var = get_differential_var(differential)
    else:
        int_var = None
        for candidate in integrand.atoms(sympy.Symbol):
            label = str(candidate)
            if len(label) > 1 and label[0] == 'd':
                # Skip the 'd', and a backslash directly after it if present.
                stripped = label[2:] if label[1] == '\\' else label[1:]
                int_var = sympy.Symbol(stripped)
                int_sym = candidate
        if int_var:
            integrand = integrand.subs(int_sym, 1)
        else:
            # Assume dx by default
            int_var = sympy.Symbol('x')

    lower_node = func.subexpr()
    if not lower_node:
        return sympy.Integral(integrand, int_var)

    if lower_node.atom():
        lower = convert_atom(lower_node.atom())
    else:
        lower = convert_expr(lower_node.expr())

    upper_node = func.supexpr()
    if upper_node.atom():
        upper = convert_atom(upper_node.atom())
    else:
        upper = convert_expr(upper_node.expr())

    return sympy.Integral(integrand, (int_var, lower, upper))
def handle_sum_or_prod(func, name):
    """Build a ``sympy.Sum`` or ``sympy.Product`` from a ``func`` node.

    ``name`` is ``"summation"`` or ``"product"``; any other value yields
    ``None``. The index variable and start value come from the subscript
    equality, the end value from the superscript.
    """
    body = convert_mp(func.mp())

    equality = func.subeq().equality()
    index = convert_expr(equality.expr(0))
    first = convert_expr(equality.expr(1))

    sup = func.supexpr()
    if sup.expr():  # ^{expr}
        final = convert_expr(sup.expr())
    else:  # ^atom
        final = convert_atom(sup.atom())

    limits = (index, first, final)
    if name == "summation":
        return sympy.Sum(body, limits)
    if name == "product":
        return sympy.Product(body, limits)
    return None
def handle_limit(func):
    """Build a ``sympy.Limit`` from a limit ``func`` node.

    The variable is the letter or symbol in the limit subscript (``x`` if
    neither is present). The direction is ``"-"`` or ``"+"`` when the
    subscript carries that sign, and ``"+-"`` otherwise.
    """
    sub = func.limit_sub()

    if sub.LETTER():
        var_name = sub.LETTER().getText()
    elif sub.SYMBOL():
        var_name = sub.SYMBOL().getText()[1:]  # drop the backslash
    else:
        var_name = 'x'
    var = sympy.Symbol(var_name)

    if sub.SUB():
        direction = "-"
    elif sub.ADD():
        direction = "+"
    else:
        direction = "+-"

    target = convert_expr(sub.expr())
    body = convert_mp(func.mp())
    return sympy.Limit(body, var, target, direction)
def get_differential_var(d):
    """Return the variable of differential token ``d`` as a ``sympy.Symbol``."""
    return sympy.Symbol(get_differential_var_str(d.getText()))
def get_differential_var_str(text):
    """Return the variable name contained in differential text such as ``dx``.

    The first character (the ``d``) is dropped, then any spaces, tabs,
    carriage returns and newlines after it, then one leading backslash if
    present. For example ``"d \\theta"`` gives ``"theta"``.

    Returns an empty string when nothing follows the ``d``; the previous
    hand-written scan raised ``UnboundLocalError`` in that case.
    """
    name = text[1:].lstrip(" \r\n\t")
    if name.startswith("\\"):
        name = name[1:]
    return name
|