pycode.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750
  1. """
  2. Python code printers
  3. This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
  4. """
  5. from collections import defaultdict
  6. from itertools import chain
  7. from sympy.core import S
  8. from sympy.core.mod import Mod
  9. from .precedence import precedence
  10. from .codeprinter import CodePrinter
  11. _kw = {
  12. 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
  13. 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in',
  14. 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
  15. 'with', 'yield', 'None', 'False', 'nonlocal', 'True'
  16. }
  17. _known_functions = {
  18. 'Abs': 'abs',
  19. 'Min': 'min',
  20. 'Max': 'max',
  21. }
  22. _known_functions_math = {
  23. 'acos': 'acos',
  24. 'acosh': 'acosh',
  25. 'asin': 'asin',
  26. 'asinh': 'asinh',
  27. 'atan': 'atan',
  28. 'atan2': 'atan2',
  29. 'atanh': 'atanh',
  30. 'ceiling': 'ceil',
  31. 'cos': 'cos',
  32. 'cosh': 'cosh',
  33. 'erf': 'erf',
  34. 'erfc': 'erfc',
  35. 'exp': 'exp',
  36. 'expm1': 'expm1',
  37. 'factorial': 'factorial',
  38. 'floor': 'floor',
  39. 'gamma': 'gamma',
  40. 'hypot': 'hypot',
  41. 'loggamma': 'lgamma',
  42. 'log': 'log',
  43. 'ln': 'log',
  44. 'log10': 'log10',
  45. 'log1p': 'log1p',
  46. 'log2': 'log2',
  47. 'sin': 'sin',
  48. 'sinh': 'sinh',
  49. 'Sqrt': 'sqrt',
  50. 'tan': 'tan',
  51. 'tanh': 'tanh'
  52. } # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf
  53. # radians trunc fmod fsum gcd degrees fabs]
  54. _known_constants_math = {
  55. 'Exp1': 'e',
  56. 'Pi': 'pi',
  57. 'E': 'e',
  58. 'Infinity': 'inf',
  59. 'NaN': 'nan',
  60. 'ComplexInfinity': 'nan'
  61. }
  62. def _print_known_func(self, expr):
  63. known = self.known_functions[expr.__class__.__name__]
  64. return '{name}({args})'.format(name=self._module_format(known),
  65. args=', '.join((self._print(arg) for arg in expr.args)))
  66. def _print_known_const(self, expr):
  67. known = self.known_constants[expr.__class__.__name__]
  68. return self._module_format(known)
  69. class AbstractPythonCodePrinter(CodePrinter):
  70. printmethod = "_pythoncode"
  71. language = "Python"
  72. reserved_words = _kw
  73. modules = None # initialized to a set in __init__
  74. tab = ' '
  75. _kf = dict(chain(
  76. _known_functions.items(),
  77. [(k, 'math.' + v) for k, v in _known_functions_math.items()]
  78. ))
  79. _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
  80. _operators = {'and': 'and', 'or': 'or', 'not': 'not'}
  81. _default_settings = dict(
  82. CodePrinter._default_settings,
  83. user_functions={},
  84. precision=17,
  85. inline=True,
  86. fully_qualified_modules=True,
  87. contract=False,
  88. standard='python3',
  89. )
  90. def __init__(self, settings=None):
  91. super().__init__(settings)
  92. # Python standard handler
  93. std = self._settings['standard']
  94. if std is None:
  95. import sys
  96. std = 'python{}'.format(sys.version_info.major)
  97. if std != 'python3':
  98. raise ValueError('Only Python 3 is supported.')
  99. self.standard = std
  100. self.module_imports = defaultdict(set)
  101. # Known functions and constants handler
  102. self.known_functions = dict(self._kf, **(settings or {}).get(
  103. 'user_functions', {}))
  104. self.known_constants = dict(self._kc, **(settings or {}).get(
  105. 'user_constants', {}))
  106. def _declare_number_const(self, name, value):
  107. return "%s = %s" % (name, value)
  108. def _module_format(self, fqn, register=True):
  109. parts = fqn.split('.')
  110. if register and len(parts) > 1:
  111. self.module_imports['.'.join(parts[:-1])].add(parts[-1])
  112. if self._settings['fully_qualified_modules']:
  113. return fqn
  114. else:
  115. return fqn.split('(')[0].split('[')[0].split('.')[-1]
  116. def _format_code(self, lines):
  117. return lines
  118. def _get_statement(self, codestring):
  119. return "{}".format(codestring)
  120. def _get_comment(self, text):
  121. return " # {}".format(text)
  122. def _expand_fold_binary_op(self, op, args):
  123. """
  124. This method expands a fold on binary operations.
  125. ``functools.reduce`` is an example of a folded operation.
  126. For example, the expression
  127. `A + B + C + D`
  128. is folded into
  129. `((A + B) + C) + D`
  130. """
  131. if len(args) == 1:
  132. return self._print(args[0])
  133. else:
  134. return "%s(%s, %s)" % (
  135. self._module_format(op),
  136. self._expand_fold_binary_op(op, args[:-1]),
  137. self._print(args[-1]),
  138. )
  139. def _expand_reduce_binary_op(self, op, args):
  140. """
  141. This method expands a reductin on binary operations.
  142. Notice: this is NOT the same as ``functools.reduce``.
  143. For example, the expression
  144. `A + B + C + D`
  145. is reduced into:
  146. `(A + B) + (C + D)`
  147. """
  148. if len(args) == 1:
  149. return self._print(args[0])
  150. else:
  151. N = len(args)
  152. Nhalf = N // 2
  153. return "%s(%s, %s)" % (
  154. self._module_format(op),
  155. self._expand_reduce_binary_op(args[:Nhalf]),
  156. self._expand_reduce_binary_op(args[Nhalf:]),
  157. )
  158. def _print_NaN(self, expr):
  159. return "float('nan')"
  160. def _print_Infinity(self, expr):
  161. return "float('inf')"
  162. def _print_NegativeInfinity(self, expr):
  163. return "float('-inf')"
  164. def _print_ComplexInfinity(self, expr):
  165. return self._print_NaN(expr)
  166. def _print_Mod(self, expr):
  167. PREC = precedence(expr)
  168. return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args)))
  169. def _print_Piecewise(self, expr):
  170. result = []
  171. i = 0
  172. for arg in expr.args:
  173. e = arg.expr
  174. c = arg.cond
  175. if i == 0:
  176. result.append('(')
  177. result.append('(')
  178. result.append(self._print(e))
  179. result.append(')')
  180. result.append(' if ')
  181. result.append(self._print(c))
  182. result.append(' else ')
  183. i += 1
  184. result = result[:-1]
  185. if result[-1] == 'True':
  186. result = result[:-2]
  187. result.append(')')
  188. else:
  189. result.append(' else None)')
  190. return ''.join(result)
  191. def _print_Relational(self, expr):
  192. "Relational printer for Equality and Unequality"
  193. op = {
  194. '==' :'equal',
  195. '!=' :'not_equal',
  196. '<' :'less',
  197. '<=' :'less_equal',
  198. '>' :'greater',
  199. '>=' :'greater_equal',
  200. }
  201. if expr.rel_op in op:
  202. lhs = self._print(expr.lhs)
  203. rhs = self._print(expr.rhs)
  204. return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
  205. return super()._print_Relational(expr)
  206. def _print_ITE(self, expr):
  207. from sympy.functions.elementary.piecewise import Piecewise
  208. return self._print(expr.rewrite(Piecewise))
  209. def _print_Sum(self, expr):
  210. loops = (
  211. 'for {i} in range({a}, {b}+1)'.format(
  212. i=self._print(i),
  213. a=self._print(a),
  214. b=self._print(b))
  215. for i, a, b in expr.limits)
  216. return '(builtins.sum({function} {loops}))'.format(
  217. function=self._print(expr.function),
  218. loops=' '.join(loops))
  219. def _print_ImaginaryUnit(self, expr):
  220. return '1j'
  221. def _print_KroneckerDelta(self, expr):
  222. a, b = expr.args
  223. return '(1 if {a} == {b} else 0)'.format(
  224. a = self._print(a),
  225. b = self._print(b)
  226. )
  227. def _print_MatrixBase(self, expr):
  228. name = expr.__class__.__name__
  229. func = self.known_functions.get(name, name)
  230. return "%s(%s)" % (func, self._print(expr.tolist()))
  231. _print_SparseRepMatrix = \
  232. _print_MutableSparseMatrix = \
  233. _print_ImmutableSparseMatrix = \
  234. _print_Matrix = \
  235. _print_DenseMatrix = \
  236. _print_MutableDenseMatrix = \
  237. _print_ImmutableMatrix = \
  238. _print_ImmutableDenseMatrix = \
  239. lambda self, expr: self._print_MatrixBase(expr)
  240. def _indent_codestring(self, codestring):
  241. return '\n'.join([self.tab + line for line in codestring.split('\n')])
  242. def _print_FunctionDefinition(self, fd):
  243. body = '\n'.join((self._print(arg) for arg in fd.body))
  244. return "def {name}({parameters}):\n{body}".format(
  245. name=self._print(fd.name),
  246. parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
  247. body=self._indent_codestring(body)
  248. )
  249. def _print_While(self, whl):
  250. body = '\n'.join((self._print(arg) for arg in whl.body))
  251. return "while {cond}:\n{body}".format(
  252. cond=self._print(whl.condition),
  253. body=self._indent_codestring(body)
  254. )
  255. def _print_Declaration(self, decl):
  256. return '%s = %s' % (
  257. self._print(decl.variable.symbol),
  258. self._print(decl.variable.value)
  259. )
  260. def _print_Return(self, ret):
  261. arg, = ret.args
  262. return 'return %s' % self._print(arg)
  263. def _print_Print(self, prnt):
  264. print_args = ', '.join((self._print(arg) for arg in prnt.print_args))
  265. if prnt.format_string != None: # Must be '!= None', cannot be 'is not None'
  266. print_args = '{} % ({})'.format(
  267. self._print(prnt.format_string), print_args)
  268. if prnt.file != None: # Must be '!= None', cannot be 'is not None'
  269. print_args += ', file=%s' % self._print(prnt.file)
  270. return 'print(%s)' % print_args
  271. def _print_Stream(self, strm):
  272. if str(strm.name) == 'stdout':
  273. return self._module_format('sys.stdout')
  274. elif str(strm.name) == 'stderr':
  275. return self._module_format('sys.stderr')
  276. else:
  277. return self._print(strm.name)
  278. def _print_NoneToken(self, arg):
  279. return 'None'
  280. def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
  281. """Printing helper function for ``Pow``
  282. Notes
  283. =====
  284. This preprocesses the ``sqrt`` as math formatter and prints division
  285. Examples
  286. ========
  287. >>> from sympy import sqrt
  288. >>> from sympy.printing.pycode import PythonCodePrinter
  289. >>> from sympy.abc import x
  290. Python code printer automatically looks up ``math.sqrt``.
  291. >>> printer = PythonCodePrinter()
  292. >>> printer._hprint_Pow(sqrt(x), rational=True)
  293. 'x**(1/2)'
  294. >>> printer._hprint_Pow(sqrt(x), rational=False)
  295. 'math.sqrt(x)'
  296. >>> printer._hprint_Pow(1/sqrt(x), rational=True)
  297. 'x**(-1/2)'
  298. >>> printer._hprint_Pow(1/sqrt(x), rational=False)
  299. '1/math.sqrt(x)'
  300. >>> printer._hprint_Pow(1/x, rational=False)
  301. '1/x'
  302. >>> printer._hprint_Pow(1/x, rational=True)
  303. 'x**(-1)'
  304. Using sqrt from numpy or mpmath
  305. >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
  306. 'numpy.sqrt(x)'
  307. >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
  308. 'mpmath.sqrt(x)'
  309. See Also
  310. ========
  311. sympy.printing.str.StrPrinter._print_Pow
  312. """
  313. PREC = precedence(expr)
  314. if expr.exp == S.Half and not rational:
  315. func = self._module_format(sqrt)
  316. arg = self._print(expr.base)
  317. return '{func}({arg})'.format(func=func, arg=arg)
  318. if expr.is_commutative and not rational:
  319. if -expr.exp is S.Half:
  320. func = self._module_format(sqrt)
  321. num = self._print(S.One)
  322. arg = self._print(expr.base)
  323. return f"{num}/{func}({arg})"
  324. if expr.exp is S.NegativeOne:
  325. num = self._print(S.One)
  326. arg = self.parenthesize(expr.base, PREC, strict=False)
  327. return f"{num}/{arg}"
  328. base_str = self.parenthesize(expr.base, PREC, strict=False)
  329. exp_str = self.parenthesize(expr.exp, PREC, strict=False)
  330. return "{}**{}".format(base_str, exp_str)
  331. class ArrayPrinter:
  332. def _arrayify(self, indexed):
  333. from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
  334. try:
  335. return convert_indexed_to_array(indexed)
  336. except Exception:
  337. return indexed
  338. def _get_einsum_string(self, subranks, contraction_indices):
  339. letters = self._get_letter_generator_for_einsum()
  340. contraction_string = ""
  341. counter = 0
  342. d = {j: min(i) for i in contraction_indices for j in i}
  343. indices = []
  344. for rank_arg in subranks:
  345. lindices = []
  346. for i in range(rank_arg):
  347. if counter in d:
  348. lindices.append(d[counter])
  349. else:
  350. lindices.append(counter)
  351. counter += 1
  352. indices.append(lindices)
  353. mapping = {}
  354. letters_free = []
  355. letters_dum = []
  356. for i in indices:
  357. for j in i:
  358. if j not in mapping:
  359. l = next(letters)
  360. mapping[j] = l
  361. else:
  362. l = mapping[j]
  363. contraction_string += l
  364. if j in d:
  365. if l not in letters_dum:
  366. letters_dum.append(l)
  367. else:
  368. letters_free.append(l)
  369. contraction_string += ","
  370. contraction_string = contraction_string[:-1]
  371. return contraction_string, letters_free, letters_dum
  372. def _get_letter_generator_for_einsum(self):
  373. for i in range(97, 123):
  374. yield chr(i)
  375. for i in range(65, 91):
  376. yield chr(i)
  377. raise ValueError("out of letters")
  378. def _print_ArrayTensorProduct(self, expr):
  379. letters = self._get_letter_generator_for_einsum()
  380. contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks])
  381. return '%s("%s", %s)' % (
  382. self._module_format(self._module + "." + self._einsum),
  383. contraction_string,
  384. ", ".join([self._print(arg) for arg in expr.args])
  385. )
  386. def _print_ArrayContraction(self, expr):
  387. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  388. base = expr.expr
  389. contraction_indices = expr.contraction_indices
  390. if isinstance(base, ArrayTensorProduct):
  391. elems = ",".join(["%s" % (self._print(arg)) for arg in base.args])
  392. ranks = base.subranks
  393. else:
  394. elems = self._print(base)
  395. ranks = [len(base.shape)]
  396. contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices)
  397. if not contraction_indices:
  398. return self._print(base)
  399. if isinstance(base, ArrayTensorProduct):
  400. elems = ",".join(["%s" % (self._print(arg)) for arg in base.args])
  401. else:
  402. elems = self._print(base)
  403. return "%s(\"%s\", %s)" % (
  404. self._module_format(self._module + "." + self._einsum),
  405. "{}->{}".format(contraction_string, "".join(sorted(letters_free))),
  406. elems,
  407. )
  408. def _print_ArrayDiagonal(self, expr):
  409. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  410. diagonal_indices = list(expr.diagonal_indices)
  411. if isinstance(expr.expr, ArrayTensorProduct):
  412. subranks = expr.expr.subranks
  413. elems = expr.expr.args
  414. else:
  415. subranks = expr.subranks
  416. elems = [expr.expr]
  417. diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices)
  418. elems = [self._print(i) for i in elems]
  419. return '%s("%s", %s)' % (
  420. self._module_format(self._module + "." + self._einsum),
  421. "{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)),
  422. ", ".join(elems)
  423. )
  424. def _print_PermuteDims(self, expr):
  425. return "%s(%s, %s)" % (
  426. self._module_format(self._module + "." + self._transpose),
  427. self._print(expr.expr),
  428. self._print(expr.permutation.array_form),
  429. )
  430. def _print_ArrayAdd(self, expr):
  431. return self._expand_fold_binary_op(self._module + "." + self._add, expr.args)
  432. def _print_OneArray(self, expr):
  433. return "%s((%s,))" % (
  434. self._module_format(self._module+ "." + self._ones),
  435. ','.join(map(self._print,expr.args))
  436. )
  437. def _print_ZeroArray(self, expr):
  438. return "%s((%s,))" % (
  439. self._module_format(self._module+ "." + self._zeros),
  440. ','.join(map(self._print,expr.args))
  441. )
  442. def _print_Assignment(self, expr):
  443. #XXX: maybe this needs to happen at a higher level e.g. at _print or
  444. #doprint?
  445. lhs = self._print(self._arrayify(expr.lhs))
  446. rhs = self._print(self._arrayify(expr.rhs))
  447. return "%s = %s" % ( lhs, rhs )
  448. def _print_IndexedBase(self, expr):
  449. return self._print_ArraySymbol(expr)
  450. class PythonCodePrinter(AbstractPythonCodePrinter):
  451. def _print_sign(self, e):
  452. return '(0.0 if {e} == 0 else {f}(1, {e}))'.format(
  453. f=self._module_format('math.copysign'), e=self._print(e.args[0]))
  454. def _print_Not(self, expr):
  455. PREC = precedence(expr)
  456. return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
  457. def _print_Indexed(self, expr):
  458. base = expr.args[0]
  459. index = expr.args[1:]
  460. return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index]))
  461. def _print_Pow(self, expr, rational=False):
  462. return self._hprint_Pow(expr, rational=rational)
  463. def _print_Rational(self, expr):
  464. return '{}/{}'.format(expr.p, expr.q)
  465. def _print_Half(self, expr):
  466. return self._print_Rational(expr)
  467. def _print_frac(self, expr):
  468. return self._print_Mod(Mod(expr.args[0], 1))
  469. def _print_Symbol(self, expr):
  470. name = super()._print_Symbol(expr)
  471. if name in self.reserved_words:
  472. if self._settings['error_on_reserved']:
  473. msg = ('This expression includes the symbol "{}" which is a '
  474. 'reserved keyword in this language.')
  475. raise ValueError(msg.format(name))
  476. return name + self._settings['reserved_word_suffix']
  477. elif '{' in name: # Remove curly braces from subscripted variables
  478. return name.replace('{', '').replace('}', '')
  479. else:
  480. return name
  481. _print_lowergamma = CodePrinter._print_not_supported
  482. _print_uppergamma = CodePrinter._print_not_supported
  483. _print_fresnelc = CodePrinter._print_not_supported
  484. _print_fresnels = CodePrinter._print_not_supported
  485. for k in PythonCodePrinter._kf:
  486. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func)
  487. for k in _known_constants_math:
  488. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const)
  489. def pycode(expr, **settings):
  490. """ Converts an expr to a string of Python code
  491. Parameters
  492. ==========
  493. expr : Expr
  494. A SymPy expression.
  495. fully_qualified_modules : bool
  496. Whether or not to write out full module names of functions
  497. (``math.sin`` vs. ``sin``). default: ``True``.
  498. standard : str or None, optional
  499. Only 'python3' (default) is supported.
  500. This parameter may be removed in the future.
  501. Examples
  502. ========
  503. >>> from sympy import pycode, tan, Symbol
  504. >>> pycode(tan(Symbol('x')) + 1)
  505. 'math.tan(x) + 1'
  506. """
  507. return PythonCodePrinter(settings).doprint(expr)
  508. _not_in_mpmath = 'log1p log2'.split()
  509. _in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath]
  510. _known_functions_mpmath = dict(_in_mpmath, **{
  511. 'beta': 'beta',
  512. 'frac': 'frac',
  513. 'fresnelc': 'fresnelc',
  514. 'fresnels': 'fresnels',
  515. 'sign': 'sign',
  516. 'loggamma': 'loggamma',
  517. 'hyper': 'hyper',
  518. 'meijerg': 'meijerg',
  519. 'besselj': 'besselj',
  520. 'bessely': 'bessely',
  521. 'besseli': 'besseli',
  522. 'besselk': 'besselk',
  523. })
  524. _known_constants_mpmath = {
  525. 'Exp1': 'e',
  526. 'Pi': 'pi',
  527. 'GoldenRatio': 'phi',
  528. 'EulerGamma': 'euler',
  529. 'Catalan': 'catalan',
  530. 'NaN': 'nan',
  531. 'Infinity': 'inf',
  532. 'NegativeInfinity': 'ninf'
  533. }
  534. def _unpack_integral_limits(integral_expr):
  535. """ helper function for _print_Integral that
  536. - accepts an Integral expression
  537. - returns a tuple of
  538. - a list variables of integration
  539. - a list of tuples of the upper and lower limits of integration
  540. """
  541. integration_vars = []
  542. limits = []
  543. for integration_range in integral_expr.limits:
  544. if len(integration_range) == 3:
  545. integration_var, lower_limit, upper_limit = integration_range
  546. else:
  547. raise NotImplementedError("Only definite integrals are supported")
  548. integration_vars.append(integration_var)
  549. limits.append((lower_limit, upper_limit))
  550. return integration_vars, limits
  551. class MpmathPrinter(PythonCodePrinter):
  552. """
  553. Lambda printer for mpmath which maintains precision for floats
  554. """
  555. printmethod = "_mpmathcode"
  556. language = "Python with mpmath"
  557. _kf = dict(chain(
  558. _known_functions.items(),
  559. [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
  560. ))
  561. _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()}
  562. def _print_Float(self, e):
  563. # XXX: This does not handle setting mpmath.mp.dps. It is assumed that
  564. # the caller of the lambdified function will have set it to sufficient
  565. # precision to match the Floats in the expression.
  566. # Remove 'mpz' if gmpy is installed.
  567. args = str(tuple(map(int, e._mpf_)))
  568. return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args)
  569. def _print_Rational(self, e):
  570. return "{func}({p})/{func}({q})".format(
  571. func=self._module_format('mpmath.mpf'),
  572. q=self._print(e.q),
  573. p=self._print(e.p)
  574. )
  575. def _print_Half(self, e):
  576. return self._print_Rational(e)
  577. def _print_uppergamma(self, e):
  578. return "{}({}, {}, {})".format(
  579. self._module_format('mpmath.gammainc'),
  580. self._print(e.args[0]),
  581. self._print(e.args[1]),
  582. self._module_format('mpmath.inf'))
  583. def _print_lowergamma(self, e):
  584. return "{}({}, 0, {})".format(
  585. self._module_format('mpmath.gammainc'),
  586. self._print(e.args[0]),
  587. self._print(e.args[1]))
  588. def _print_log2(self, e):
  589. return '{0}({1})/{0}(2)'.format(
  590. self._module_format('mpmath.log'), self._print(e.args[0]))
  591. def _print_log1p(self, e):
  592. return '{}({})'.format(
  593. self._module_format('mpmath.log1p'), self._print(e.args[0]))
  594. def _print_Pow(self, expr, rational=False):
  595. return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')
  596. def _print_Integral(self, e):
  597. integration_vars, limits = _unpack_integral_limits(e)
  598. return "{}(lambda {}: {}, {})".format(
  599. self._module_format("mpmath.quad"),
  600. ", ".join(map(self._print, integration_vars)),
  601. self._print(e.args[0]),
  602. ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits))
  603. for k in MpmathPrinter._kf:
  604. setattr(MpmathPrinter, '_print_%s' % k, _print_known_func)
  605. for k in _known_constants_mpmath:
  606. setattr(MpmathPrinter, '_print_%s' % k, _print_known_const)
  607. class SymPyPrinter(AbstractPythonCodePrinter):
  608. language = "Python with SymPy"
  609. def _print_Function(self, expr):
  610. mod = expr.func.__module__ or ''
  611. return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__),
  612. ', '.join((self._print(arg) for arg in expr.args)))
  613. def _print_Pow(self, expr, rational=False):
  614. return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')