repr.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
"""
A Printer for generating executable code.

The most important function here is ``srepr``, which returns a string ``s``
such that the relation ``eval(s) == expr`` holds in an appropriate
environment, i.e. one in which the printed constructor names (for example
``Symbol``, ``Integer`` or ``Rational``) are defined.
"""
  6. from __future__ import annotations
  7. from typing import Any
  8. from sympy.core.function import AppliedUndef
  9. from sympy.core.mul import Mul
  10. from mpmath.libmp import repr_dps, to_str as mlib_to_str
  11. from .printer import Printer, print_function
  12. class ReprPrinter(Printer):
  13. printmethod = "_sympyrepr"
  14. _default_settings: dict[str, Any] = {
  15. "order": None,
  16. "perm_cyclic" : True,
  17. }
  18. def reprify(self, args, sep):
  19. """
  20. Prints each item in `args` and joins them with `sep`.
  21. """
  22. return sep.join([self.doprint(item) for item in args])
  23. def emptyPrinter(self, expr):
  24. """
  25. The fallback printer.
  26. """
  27. if isinstance(expr, str):
  28. return expr
  29. elif hasattr(expr, "__srepr__"):
  30. return expr.__srepr__()
  31. elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"):
  32. l = []
  33. for o in expr.args:
  34. l.append(self._print(o))
  35. return expr.__class__.__name__ + '(%s)' % ', '.join(l)
  36. elif hasattr(expr, "__module__") and hasattr(expr, "__name__"):
  37. return "<'%s.%s'>" % (expr.__module__, expr.__name__)
  38. else:
  39. return str(expr)
  40. def _print_Add(self, expr, order=None):
  41. args = self._as_ordered_terms(expr, order=order)
  42. args = map(self._print, args)
  43. clsname = type(expr).__name__
  44. return clsname + "(%s)" % ", ".join(args)
  45. def _print_Cycle(self, expr):
  46. return expr.__repr__()
  47. def _print_Permutation(self, expr):
  48. from sympy.combinatorics.permutations import Permutation, Cycle
  49. from sympy.utilities.exceptions import sympy_deprecation_warning
  50. perm_cyclic = Permutation.print_cyclic
  51. if perm_cyclic is not None:
  52. sympy_deprecation_warning(
  53. f"""
  54. Setting Permutation.print_cyclic is deprecated. Instead use
  55. init_printing(perm_cyclic={perm_cyclic}).
  56. """,
  57. deprecated_since_version="1.6",
  58. active_deprecations_target="deprecated-permutation-print_cyclic",
  59. stacklevel=7,
  60. )
  61. else:
  62. perm_cyclic = self._settings.get("perm_cyclic", True)
  63. if perm_cyclic:
  64. if not expr.size:
  65. return 'Permutation()'
  66. # before taking Cycle notation, see if the last element is
  67. # a singleton and move it to the head of the string
  68. s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):]
  69. last = s.rfind('(')
  70. if not last == 0 and ',' not in s[last:]:
  71. s = s[last:] + s[:last]
  72. return 'Permutation%s' %s
  73. else:
  74. s = expr.support()
  75. if not s:
  76. if expr.size < 5:
  77. return 'Permutation(%s)' % str(expr.array_form)
  78. return 'Permutation([], size=%s)' % expr.size
  79. trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size
  80. use = full = str(expr.array_form)
  81. if len(trim) < len(full):
  82. use = trim
  83. return 'Permutation(%s)' % use
  84. def _print_Function(self, expr):
  85. r = self._print(expr.func)
  86. r += '(%s)' % ', '.join([self._print(a) for a in expr.args])
  87. return r
  88. def _print_Heaviside(self, expr):
  89. # Same as _print_Function but uses pargs to suppress default value for
  90. # 2nd arg.
  91. r = self._print(expr.func)
  92. r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs])
  93. return r
  94. def _print_FunctionClass(self, expr):
  95. if issubclass(expr, AppliedUndef):
  96. return 'Function(%r)' % (expr.__name__)
  97. else:
  98. return expr.__name__
  99. def _print_Half(self, expr):
  100. return 'Rational(1, 2)'
  101. def _print_RationalConstant(self, expr):
  102. return str(expr)
  103. def _print_AtomicExpr(self, expr):
  104. return str(expr)
  105. def _print_NumberSymbol(self, expr):
  106. return str(expr)
  107. def _print_Integer(self, expr):
  108. return 'Integer(%i)' % expr.p
  109. def _print_Complexes(self, expr):
  110. return 'Complexes'
  111. def _print_Integers(self, expr):
  112. return 'Integers'
  113. def _print_Naturals(self, expr):
  114. return 'Naturals'
  115. def _print_Naturals0(self, expr):
  116. return 'Naturals0'
  117. def _print_Rationals(self, expr):
  118. return 'Rationals'
  119. def _print_Reals(self, expr):
  120. return 'Reals'
  121. def _print_EmptySet(self, expr):
  122. return 'EmptySet'
  123. def _print_UniversalSet(self, expr):
  124. return 'UniversalSet'
  125. def _print_EmptySequence(self, expr):
  126. return 'EmptySequence'
  127. def _print_list(self, expr):
  128. return "[%s]" % self.reprify(expr, ", ")
  129. def _print_dict(self, expr):
  130. sep = ", "
  131. dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()]
  132. return "{%s}" % sep.join(dict_kvs)
  133. def _print_set(self, expr):
  134. if not expr:
  135. return "set()"
  136. return "{%s}" % self.reprify(expr, ", ")
  137. def _print_MatrixBase(self, expr):
  138. # special case for some empty matrices
  139. if (expr.rows == 0) ^ (expr.cols == 0):
  140. return '%s(%s, %s, %s)' % (expr.__class__.__name__,
  141. self._print(expr.rows),
  142. self._print(expr.cols),
  143. self._print([]))
  144. l = []
  145. for i in range(expr.rows):
  146. l.append([])
  147. for j in range(expr.cols):
  148. l[-1].append(expr[i, j])
  149. return '%s(%s)' % (expr.__class__.__name__, self._print(l))
  150. def _print_BooleanTrue(self, expr):
  151. return "true"
  152. def _print_BooleanFalse(self, expr):
  153. return "false"
  154. def _print_NaN(self, expr):
  155. return "nan"
  156. def _print_Mul(self, expr, order=None):
  157. if self.order not in ('old', 'none'):
  158. args = expr.as_ordered_factors()
  159. else:
  160. # use make_args in case expr was something like -x -> x
  161. args = Mul.make_args(expr)
  162. args = map(self._print, args)
  163. clsname = type(expr).__name__
  164. return clsname + "(%s)" % ", ".join(args)
  165. def _print_Rational(self, expr):
  166. return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q))
  167. def _print_PythonRational(self, expr):
  168. return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q)
  169. def _print_Fraction(self, expr):
  170. return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator))
  171. def _print_Float(self, expr):
  172. r = mlib_to_str(expr._mpf_, repr_dps(expr._prec))
  173. return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec)
  174. def _print_Sum2(self, expr):
  175. return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i),
  176. self._print(expr.a), self._print(expr.b))
  177. def _print_Str(self, s):
  178. return "%s(%s)" % (s.__class__.__name__, self._print(s.name))
  179. def _print_Symbol(self, expr):
  180. d = expr._assumptions_orig
  181. # print the dummy_index like it was an assumption
  182. if expr.is_Dummy:
  183. d['dummy_index'] = expr.dummy_index
  184. if d == {}:
  185. return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name))
  186. else:
  187. attr = ['%s=%s' % (k, v) for k, v in d.items()]
  188. return "%s(%s, %s)" % (expr.__class__.__name__,
  189. self._print(expr.name), ', '.join(attr))
  190. def _print_CoordinateSymbol(self, expr):
  191. d = expr._assumptions.generator
  192. if d == {}:
  193. return "%s(%s, %s)" % (
  194. expr.__class__.__name__,
  195. self._print(expr.coord_sys),
  196. self._print(expr.index)
  197. )
  198. else:
  199. attr = ['%s=%s' % (k, v) for k, v in d.items()]
  200. return "%s(%s, %s, %s)" % (
  201. expr.__class__.__name__,
  202. self._print(expr.coord_sys),
  203. self._print(expr.index),
  204. ', '.join(attr)
  205. )
  206. def _print_Predicate(self, expr):
  207. return "Q.%s" % expr.name
  208. def _print_AppliedPredicate(self, expr):
  209. # will be changed to just expr.args when args overriding is removed
  210. args = expr._args
  211. return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", "))
  212. def _print_str(self, expr):
  213. return repr(expr)
  214. def _print_tuple(self, expr):
  215. if len(expr) == 1:
  216. return "(%s,)" % self._print(expr[0])
  217. else:
  218. return "(%s)" % self.reprify(expr, ", ")
  219. def _print_WildFunction(self, expr):
  220. return "%s('%s')" % (expr.__class__.__name__, expr.name)
  221. def _print_AlgebraicNumber(self, expr):
  222. return "%s(%s, %s)" % (expr.__class__.__name__,
  223. self._print(expr.root), self._print(expr.coeffs()))
  224. def _print_PolyRing(self, ring):
  225. return "%s(%s, %s, %s)" % (ring.__class__.__name__,
  226. self._print(ring.symbols), self._print(ring.domain), self._print(ring.order))
  227. def _print_FracField(self, field):
  228. return "%s(%s, %s, %s)" % (field.__class__.__name__,
  229. self._print(field.symbols), self._print(field.domain), self._print(field.order))
  230. def _print_PolyElement(self, poly):
  231. terms = list(poly.terms())
  232. terms.sort(key=poly.ring.order, reverse=True)
  233. return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms))
  234. def _print_FracElement(self, frac):
  235. numer_terms = list(frac.numer.terms())
  236. numer_terms.sort(key=frac.field.order, reverse=True)
  237. denom_terms = list(frac.denom.terms())
  238. denom_terms.sort(key=frac.field.order, reverse=True)
  239. numer = self._print(numer_terms)
  240. denom = self._print(denom_terms)
  241. return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom)
  242. def _print_FractionField(self, domain):
  243. cls = domain.__class__.__name__
  244. field = self._print(domain.field)
  245. return "%s(%s)" % (cls, field)
  246. def _print_PolynomialRingBase(self, ring):
  247. cls = ring.__class__.__name__
  248. dom = self._print(ring.domain)
  249. gens = ', '.join(map(self._print, ring.gens))
  250. order = str(ring.order)
  251. if order != ring.default_order:
  252. orderstr = ", order=" + order
  253. else:
  254. orderstr = ""
  255. return "%s(%s, %s%s)" % (cls, dom, gens, orderstr)
  256. def _print_DMP(self, p):
  257. cls = p.__class__.__name__
  258. rep = self._print(p.rep)
  259. dom = self._print(p.dom)
  260. if p.ring is not None:
  261. ringstr = ", ring=" + self._print(p.ring)
  262. else:
  263. ringstr = ""
  264. return "%s(%s, %s%s)" % (cls, rep, dom, ringstr)
  265. def _print_MonogenicFiniteExtension(self, ext):
  266. # The expanded tree shown by srepr(ext.modulus)
  267. # is not practical.
  268. return "FiniteExtension(%s)" % str(ext.modulus)
  269. def _print_ExtensionElement(self, f):
  270. rep = self._print(f.rep)
  271. ext = self._print(f.ext)
  272. return "ExtElem(%s, %s)" % (rep, ext)
  273. @print_function(ReprPrinter)
  274. def srepr(expr, **settings):
  275. """return expr in repr form"""
  276. return ReprPrinter(settings).doprint(expr)