mathematica.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. """
  2. Mathematica code printer
  3. """
  4. from __future__ import annotations
  5. from typing import Any
  6. from sympy.core import Basic, Expr, Float
  7. from sympy.core.sorting import default_sort_key
  8. from sympy.printing.codeprinter import CodePrinter
  9. from sympy.printing.precedence import precedence
  10. # Used in MCodePrinter._print_Function(self)
  11. known_functions = {
  12. "exp": [(lambda x: True, "Exp")],
  13. "log": [(lambda x: True, "Log")],
  14. "sin": [(lambda x: True, "Sin")],
  15. "cos": [(lambda x: True, "Cos")],
  16. "tan": [(lambda x: True, "Tan")],
  17. "cot": [(lambda x: True, "Cot")],
  18. "sec": [(lambda x: True, "Sec")],
  19. "csc": [(lambda x: True, "Csc")],
  20. "asin": [(lambda x: True, "ArcSin")],
  21. "acos": [(lambda x: True, "ArcCos")],
  22. "atan": [(lambda x: True, "ArcTan")],
  23. "acot": [(lambda x: True, "ArcCot")],
  24. "asec": [(lambda x: True, "ArcSec")],
  25. "acsc": [(lambda x: True, "ArcCsc")],
  26. "atan2": [(lambda *x: True, "ArcTan")],
  27. "sinh": [(lambda x: True, "Sinh")],
  28. "cosh": [(lambda x: True, "Cosh")],
  29. "tanh": [(lambda x: True, "Tanh")],
  30. "coth": [(lambda x: True, "Coth")],
  31. "sech": [(lambda x: True, "Sech")],
  32. "csch": [(lambda x: True, "Csch")],
  33. "asinh": [(lambda x: True, "ArcSinh")],
  34. "acosh": [(lambda x: True, "ArcCosh")],
  35. "atanh": [(lambda x: True, "ArcTanh")],
  36. "acoth": [(lambda x: True, "ArcCoth")],
  37. "asech": [(lambda x: True, "ArcSech")],
  38. "acsch": [(lambda x: True, "ArcCsch")],
  39. "sinc": [(lambda x: True, "Sinc")],
  40. "conjugate": [(lambda x: True, "Conjugate")],
  41. "Max": [(lambda *x: True, "Max")],
  42. "Min": [(lambda *x: True, "Min")],
  43. "erf": [(lambda x: True, "Erf")],
  44. "erf2": [(lambda *x: True, "Erf")],
  45. "erfc": [(lambda x: True, "Erfc")],
  46. "erfi": [(lambda x: True, "Erfi")],
  47. "erfinv": [(lambda x: True, "InverseErf")],
  48. "erfcinv": [(lambda x: True, "InverseErfc")],
  49. "erf2inv": [(lambda *x: True, "InverseErf")],
  50. "expint": [(lambda *x: True, "ExpIntegralE")],
  51. "Ei": [(lambda x: True, "ExpIntegralEi")],
  52. "fresnelc": [(lambda x: True, "FresnelC")],
  53. "fresnels": [(lambda x: True, "FresnelS")],
  54. "gamma": [(lambda x: True, "Gamma")],
  55. "uppergamma": [(lambda *x: True, "Gamma")],
  56. "polygamma": [(lambda *x: True, "PolyGamma")],
  57. "loggamma": [(lambda x: True, "LogGamma")],
  58. "beta": [(lambda *x: True, "Beta")],
  59. "Ci": [(lambda x: True, "CosIntegral")],
  60. "Si": [(lambda x: True, "SinIntegral")],
  61. "Chi": [(lambda x: True, "CoshIntegral")],
  62. "Shi": [(lambda x: True, "SinhIntegral")],
  63. "li": [(lambda x: True, "LogIntegral")],
  64. "factorial": [(lambda x: True, "Factorial")],
  65. "factorial2": [(lambda x: True, "Factorial2")],
  66. "subfactorial": [(lambda x: True, "Subfactorial")],
  67. "catalan": [(lambda x: True, "CatalanNumber")],
  68. "harmonic": [(lambda *x: True, "HarmonicNumber")],
  69. "lucas": [(lambda x: True, "LucasL")],
  70. "RisingFactorial": [(lambda *x: True, "Pochhammer")],
  71. "FallingFactorial": [(lambda *x: True, "FactorialPower")],
  72. "laguerre": [(lambda *x: True, "LaguerreL")],
  73. "assoc_laguerre": [(lambda *x: True, "LaguerreL")],
  74. "hermite": [(lambda *x: True, "HermiteH")],
  75. "jacobi": [(lambda *x: True, "JacobiP")],
  76. "gegenbauer": [(lambda *x: True, "GegenbauerC")],
  77. "chebyshevt": [(lambda *x: True, "ChebyshevT")],
  78. "chebyshevu": [(lambda *x: True, "ChebyshevU")],
  79. "legendre": [(lambda *x: True, "LegendreP")],
  80. "assoc_legendre": [(lambda *x: True, "LegendreP")],
  81. "mathieuc": [(lambda *x: True, "MathieuC")],
  82. "mathieus": [(lambda *x: True, "MathieuS")],
  83. "mathieucprime": [(lambda *x: True, "MathieuCPrime")],
  84. "mathieusprime": [(lambda *x: True, "MathieuSPrime")],
  85. "stieltjes": [(lambda x: True, "StieltjesGamma")],
  86. "elliptic_e": [(lambda *x: True, "EllipticE")],
  87. "elliptic_f": [(lambda *x: True, "EllipticE")],
  88. "elliptic_k": [(lambda x: True, "EllipticK")],
  89. "elliptic_pi": [(lambda *x: True, "EllipticPi")],
  90. "zeta": [(lambda *x: True, "Zeta")],
  91. "dirichlet_eta": [(lambda x: True, "DirichletEta")],
  92. "riemann_xi": [(lambda x: True, "RiemannXi")],
  93. "besseli": [(lambda *x: True, "BesselI")],
  94. "besselj": [(lambda *x: True, "BesselJ")],
  95. "besselk": [(lambda *x: True, "BesselK")],
  96. "bessely": [(lambda *x: True, "BesselY")],
  97. "hankel1": [(lambda *x: True, "HankelH1")],
  98. "hankel2": [(lambda *x: True, "HankelH2")],
  99. "airyai": [(lambda x: True, "AiryAi")],
  100. "airybi": [(lambda x: True, "AiryBi")],
  101. "airyaiprime": [(lambda x: True, "AiryAiPrime")],
  102. "airybiprime": [(lambda x: True, "AiryBiPrime")],
  103. "polylog": [(lambda *x: True, "PolyLog")],
  104. "lerchphi": [(lambda *x: True, "LerchPhi")],
  105. "gcd": [(lambda *x: True, "GCD")],
  106. "lcm": [(lambda *x: True, "LCM")],
  107. "jn": [(lambda *x: True, "SphericalBesselJ")],
  108. "yn": [(lambda *x: True, "SphericalBesselY")],
  109. "hyper": [(lambda *x: True, "HypergeometricPFQ")],
  110. "meijerg": [(lambda *x: True, "MeijerG")],
  111. "appellf1": [(lambda *x: True, "AppellF1")],
  112. "DiracDelta": [(lambda x: True, "DiracDelta")],
  113. "Heaviside": [(lambda x: True, "HeavisideTheta")],
  114. "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
  115. "sqrt": [(lambda x: True, "Sqrt")], # For automatic rewrites
  116. }
  117. class MCodePrinter(CodePrinter):
  118. """A printer to convert Python expressions to
  119. strings of the Wolfram's Mathematica code
  120. """
  121. printmethod = "_mcode"
  122. language = "Wolfram Language"
  123. _default_settings: dict[str, Any] = {
  124. 'order': None,
  125. 'full_prec': 'auto',
  126. 'precision': 15,
  127. 'user_functions': {},
  128. 'human': True,
  129. 'allow_unknown_functions': False,
  130. }
  131. _number_symbols: set[tuple[Expr, Float]] = set()
  132. _not_supported: set[Basic] = set()
  133. def __init__(self, settings={}):
  134. """Register function mappings supplied by user"""
  135. CodePrinter.__init__(self, settings)
  136. self.known_functions = dict(known_functions)
  137. userfuncs = settings.get('user_functions', {}).copy()
  138. for k, v in userfuncs.items():
  139. if not isinstance(v, list):
  140. userfuncs[k] = [(lambda *x: True, v)]
  141. self.known_functions.update(userfuncs)
  142. def _format_code(self, lines):
  143. return lines
  144. def _print_Pow(self, expr):
  145. PREC = precedence(expr)
  146. return '%s^%s' % (self.parenthesize(expr.base, PREC),
  147. self.parenthesize(expr.exp, PREC))
  148. def _print_Mul(self, expr):
  149. PREC = precedence(expr)
  150. c, nc = expr.args_cnc()
  151. res = super()._print_Mul(expr.func(*c))
  152. if nc:
  153. res += '*'
  154. res += '**'.join(self.parenthesize(a, PREC) for a in nc)
  155. return res
  156. def _print_Relational(self, expr):
  157. lhs_code = self._print(expr.lhs)
  158. rhs_code = self._print(expr.rhs)
  159. op = expr.rel_op
  160. return "{} {} {}".format(lhs_code, op, rhs_code)
  161. # Primitive numbers
  162. def _print_Zero(self, expr):
  163. return '0'
  164. def _print_One(self, expr):
  165. return '1'
  166. def _print_NegativeOne(self, expr):
  167. return '-1'
  168. def _print_Half(self, expr):
  169. return '1/2'
  170. def _print_ImaginaryUnit(self, expr):
  171. return 'I'
  172. # Infinity and invalid numbers
  173. def _print_Infinity(self, expr):
  174. return 'Infinity'
  175. def _print_NegativeInfinity(self, expr):
  176. return '-Infinity'
  177. def _print_ComplexInfinity(self, expr):
  178. return 'ComplexInfinity'
  179. def _print_NaN(self, expr):
  180. return 'Indeterminate'
  181. # Mathematical constants
  182. def _print_Exp1(self, expr):
  183. return 'E'
  184. def _print_Pi(self, expr):
  185. return 'Pi'
  186. def _print_GoldenRatio(self, expr):
  187. return 'GoldenRatio'
  188. def _print_TribonacciConstant(self, expr):
  189. expanded = expr.expand(func=True)
  190. PREC = precedence(expr)
  191. return self.parenthesize(expanded, PREC)
  192. def _print_EulerGamma(self, expr):
  193. return 'EulerGamma'
  194. def _print_Catalan(self, expr):
  195. return 'Catalan'
  196. def _print_list(self, expr):
  197. return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
  198. _print_tuple = _print_list
  199. _print_Tuple = _print_list
  200. def _print_ImmutableDenseMatrix(self, expr):
  201. return self.doprint(expr.tolist())
  202. def _print_ImmutableSparseMatrix(self, expr):
  203. def print_rule(pos, val):
  204. return '{} -> {}'.format(
  205. self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))
  206. def print_data():
  207. items = sorted(expr.todok().items(), key=default_sort_key)
  208. return '{' + \
  209. ', '.join(print_rule(k, v) for k, v in items) + \
  210. '}'
  211. def print_dims():
  212. return self.doprint(expr.shape)
  213. return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
  214. def _print_ImmutableDenseNDimArray(self, expr):
  215. return self.doprint(expr.tolist())
  216. def _print_ImmutableSparseNDimArray(self, expr):
  217. def print_string_list(string_list):
  218. return '{' + ', '.join(a for a in string_list) + '}'
  219. def to_mathematica_index(*args):
  220. """Helper function to change Python style indexing to
  221. Pathematica indexing.
  222. Python indexing (0, 1 ... n-1)
  223. -> Mathematica indexing (1, 2 ... n)
  224. """
  225. return tuple(i + 1 for i in args)
  226. def print_rule(pos, val):
  227. """Helper function to print a rule of Mathematica"""
  228. return '{} -> {}'.format(self.doprint(pos), self.doprint(val))
  229. def print_data():
  230. """Helper function to print data part of Mathematica
  231. sparse array.
  232. It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
  233. from
  234. https://reference.wolfram.com/language/ref/SparseArray.html
  235. ``data`` must be formatted with rule.
  236. """
  237. return print_string_list(
  238. [print_rule(
  239. to_mathematica_index(*(expr._get_tuple_index(key))),
  240. value)
  241. for key, value in sorted(expr._sparse_array.items())]
  242. )
  243. def print_dims():
  244. """Helper function to print dimensions part of Mathematica
  245. sparse array.
  246. It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
  247. from
  248. https://reference.wolfram.com/language/ref/SparseArray.html
  249. """
  250. return self.doprint(expr.shape)
  251. return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
  252. def _print_Function(self, expr):
  253. if expr.func.__name__ in self.known_functions:
  254. cond_mfunc = self.known_functions[expr.func.__name__]
  255. for cond, mfunc in cond_mfunc:
  256. if cond(*expr.args):
  257. return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
  258. elif expr.func.__name__ in self._rewriteable_functions:
  259. # Simple rewrite to supported function possible
  260. target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
  261. if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
  262. return self._print(expr.rewrite(target_f))
  263. return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")
  264. _print_MinMaxBase = _print_Function
  265. def _print_LambertW(self, expr):
  266. if len(expr.args) == 1:
  267. return "ProductLog[{}]".format(self._print(expr.args[0]))
  268. return "ProductLog[{}, {}]".format(
  269. self._print(expr.args[1]), self._print(expr.args[0]))
  270. def _print_Integral(self, expr):
  271. if len(expr.variables) == 1 and not expr.limits[0][1:]:
  272. args = [expr.args[0], expr.variables[0]]
  273. else:
  274. args = expr.args
  275. return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"
  276. def _print_Sum(self, expr):
  277. return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"
  278. def _print_Derivative(self, expr):
  279. dexpr = expr.expr
  280. dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
  281. return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"
  282. def _get_comment(self, text):
  283. return "(* {} *)".format(text)
  284. def mathematica_code(expr, **settings):
  285. r"""Converts an expr to a string of the Wolfram Mathematica code
  286. Examples
  287. ========
  288. >>> from sympy import mathematica_code as mcode, symbols, sin
  289. >>> x = symbols('x')
  290. >>> mcode(sin(x).series(x).removeO())
  291. '(1/120)*x^5 - 1/6*x^3 + x'
  292. """
  293. return MCodePrinter(settings).doprint(expr)