lambdarepr.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. from .pycode import (
  2. PythonCodePrinter,
  3. MpmathPrinter,
  4. )
  5. from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility
  6. from sympy.core.sorting import default_sort_key
  7. __all__ = [
  8. 'PythonCodePrinter',
  9. 'MpmathPrinter', # MpmathPrinter is published for backward compatibility
  10. 'NumPyPrinter',
  11. 'LambdaPrinter',
  12. 'NumPyPrinter',
  13. 'IntervalPrinter',
  14. 'lambdarepr',
  15. ]
  16. class LambdaPrinter(PythonCodePrinter):
  17. """
  18. This printer converts expressions into strings that can be used by
  19. lambdify.
  20. """
  21. printmethod = "_lambdacode"
  22. def _print_And(self, expr):
  23. result = ['(']
  24. for arg in sorted(expr.args, key=default_sort_key):
  25. result.extend(['(', self._print(arg), ')'])
  26. result.append(' and ')
  27. result = result[:-1]
  28. result.append(')')
  29. return ''.join(result)
  30. def _print_Or(self, expr):
  31. result = ['(']
  32. for arg in sorted(expr.args, key=default_sort_key):
  33. result.extend(['(', self._print(arg), ')'])
  34. result.append(' or ')
  35. result = result[:-1]
  36. result.append(')')
  37. return ''.join(result)
  38. def _print_Not(self, expr):
  39. result = ['(', 'not (', self._print(expr.args[0]), '))']
  40. return ''.join(result)
  41. def _print_BooleanTrue(self, expr):
  42. return "True"
  43. def _print_BooleanFalse(self, expr):
  44. return "False"
  45. def _print_ITE(self, expr):
  46. result = [
  47. '((', self._print(expr.args[1]),
  48. ') if (', self._print(expr.args[0]),
  49. ') else (', self._print(expr.args[2]), '))'
  50. ]
  51. return ''.join(result)
  52. def _print_NumberSymbol(self, expr):
  53. return str(expr)
  54. def _print_Pow(self, expr, **kwargs):
  55. # XXX Temporary workaround. Should Python math printer be
  56. # isolated from PythonCodePrinter?
  57. return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs)
  58. # numexpr works by altering the string passed to numexpr.evaluate
  59. # rather than by populating a namespace. Thus a special printer...
  60. class NumExprPrinter(LambdaPrinter):
  61. # key, value pairs correspond to SymPy name and numexpr name
  62. # functions not appearing in this dict will raise a TypeError
  63. printmethod = "_numexprcode"
  64. _numexpr_functions = {
  65. 'sin' : 'sin',
  66. 'cos' : 'cos',
  67. 'tan' : 'tan',
  68. 'asin': 'arcsin',
  69. 'acos': 'arccos',
  70. 'atan': 'arctan',
  71. 'atan2' : 'arctan2',
  72. 'sinh' : 'sinh',
  73. 'cosh' : 'cosh',
  74. 'tanh' : 'tanh',
  75. 'asinh': 'arcsinh',
  76. 'acosh': 'arccosh',
  77. 'atanh': 'arctanh',
  78. 'ln' : 'log',
  79. 'log': 'log',
  80. 'exp': 'exp',
  81. 'sqrt' : 'sqrt',
  82. 'Abs' : 'abs',
  83. 'conjugate' : 'conj',
  84. 'im' : 'imag',
  85. 're' : 'real',
  86. 'where' : 'where',
  87. 'complex' : 'complex',
  88. 'contains' : 'contains',
  89. }
  90. module = 'numexpr'
  91. def _print_ImaginaryUnit(self, expr):
  92. return '1j'
  93. def _print_seq(self, seq, delimiter=', '):
  94. # simplified _print_seq taken from pretty.py
  95. s = [self._print(item) for item in seq]
  96. if s:
  97. return delimiter.join(s)
  98. else:
  99. return ""
  100. def _print_Function(self, e):
  101. func_name = e.func.__name__
  102. nstr = self._numexpr_functions.get(func_name, None)
  103. if nstr is None:
  104. # check for implemented_function
  105. if hasattr(e, '_imp_'):
  106. return "(%s)" % self._print(e._imp_(*e.args))
  107. else:
  108. raise TypeError("numexpr does not support function '%s'" %
  109. func_name)
  110. return "%s(%s)" % (nstr, self._print_seq(e.args))
  111. def _print_Piecewise(self, expr):
  112. "Piecewise function printer"
  113. exprs = [self._print(arg.expr) for arg in expr.args]
  114. conds = [self._print(arg.cond) for arg in expr.args]
  115. # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
  116. # it will behave the same as passing the 'default' kwarg to select()
  117. # *as long as* it is the last element in expr.args.
  118. # If this is not the case, it may be triggered prematurely.
  119. ans = []
  120. parenthesis_count = 0
  121. is_last_cond_True = False
  122. for cond, expr in zip(conds, exprs):
  123. if cond == 'True':
  124. ans.append(expr)
  125. is_last_cond_True = True
  126. break
  127. else:
  128. ans.append('where(%s, %s, ' % (cond, expr))
  129. parenthesis_count += 1
  130. if not is_last_cond_True:
  131. # See https://github.com/pydata/numexpr/issues/298
  132. #
  133. # simplest way to put a nan but raises
  134. # 'RuntimeWarning: invalid value encountered in log'
  135. #
  136. # There are other ways to do this such as
  137. #
  138. # >>> import numexpr as ne
  139. # >>> nan = float('nan')
  140. # >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan})
  141. # array([-1., nan, nan])
  142. #
  143. # That needs to be handled in the lambdified function though rather
  144. # than here in the printer.
  145. ans.append('log(-1)')
  146. return ''.join(ans) + ')' * parenthesis_count
  147. def _print_ITE(self, expr):
  148. from sympy.functions.elementary.piecewise import Piecewise
  149. return self._print(expr.rewrite(Piecewise))
  150. def blacklisted(self, expr):
  151. raise TypeError("numexpr cannot be used with %s" %
  152. expr.__class__.__name__)
  153. # blacklist all Matrix printing
  154. _print_SparseRepMatrix = \
  155. _print_MutableSparseMatrix = \
  156. _print_ImmutableSparseMatrix = \
  157. _print_Matrix = \
  158. _print_DenseMatrix = \
  159. _print_MutableDenseMatrix = \
  160. _print_ImmutableMatrix = \
  161. _print_ImmutableDenseMatrix = \
  162. blacklisted
  163. # blacklist some Python expressions
  164. _print_list = \
  165. _print_tuple = \
  166. _print_Tuple = \
  167. _print_dict = \
  168. _print_Dict = \
  169. blacklisted
  170. def _print_NumExprEvaluate(self, expr):
  171. evaluate = self._module_format(self.module +".evaluate")
  172. return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr))
  173. def doprint(self, expr):
  174. from sympy.codegen.ast import CodegenAST
  175. from sympy.codegen.pynodes import NumExprEvaluate
  176. if not isinstance(expr, CodegenAST):
  177. expr = NumExprEvaluate(expr)
  178. return super().doprint(expr)
  179. def _print_Return(self, expr):
  180. from sympy.codegen.pynodes import NumExprEvaluate
  181. r, = expr.args
  182. if not isinstance(r, NumExprEvaluate):
  183. expr = expr.func(NumExprEvaluate(r))
  184. return super()._print_Return(expr)
  185. def _print_Assignment(self, expr):
  186. from sympy.codegen.pynodes import NumExprEvaluate
  187. lhs, rhs, *args = expr.args
  188. if not isinstance(rhs, NumExprEvaluate):
  189. expr = expr.func(lhs, NumExprEvaluate(rhs), *args)
  190. return super()._print_Assignment(expr)
  191. def _print_CodeBlock(self, expr):
  192. from sympy.codegen.ast import CodegenAST
  193. from sympy.codegen.pynodes import NumExprEvaluate
  194. args = [ arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg) for arg in expr.args ]
  195. return super()._print_CodeBlock(self, expr.func(*args))
  196. class IntervalPrinter(MpmathPrinter, LambdaPrinter):
  197. """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """
  198. def _print_Integer(self, expr):
  199. return "mpi('%s')" % super(PythonCodePrinter, self)._print_Integer(expr)
  200. def _print_Rational(self, expr):
  201. return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr)
  202. def _print_Half(self, expr):
  203. return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr)
  204. def _print_Pow(self, expr):
  205. return super(MpmathPrinter, self)._print_Pow(expr, rational=True)
  206. for k in NumExprPrinter._numexpr_functions:
  207. setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function)
  208. def lambdarepr(expr, **settings):
  209. """
  210. Returns a string usable for lambdifying.
  211. """
  212. return LambdaPrinter(settings).doprint(expr)