printing.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. from sympy.core.function import Derivative
  2. from sympy.core.function import UndefinedFunction, AppliedUndef
  3. from sympy.core.symbol import Symbol
  4. from sympy.interactive.printing import init_printing
  5. from sympy.printing.latex import LatexPrinter
  6. from sympy.printing.pretty.pretty import PrettyPrinter
  7. from sympy.printing.pretty.pretty_symbology import center_accent
  8. from sympy.printing.str import StrPrinter
  9. from sympy.printing.precedence import PRECEDENCE
  10. __all__ = ['vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex',
  11. 'init_vprinting']
  12. class VectorStrPrinter(StrPrinter):
  13. """String Printer for vector expressions. """
  14. def _print_Derivative(self, e):
  15. from sympy.physics.vector.functions import dynamicsymbols
  16. t = dynamicsymbols._t
  17. if (bool(sum([i == t for i in e.variables])) &
  18. isinstance(type(e.args[0]), UndefinedFunction)):
  19. ol = str(e.args[0].func)
  20. for i, v in enumerate(e.variables):
  21. ol += dynamicsymbols._str
  22. return ol
  23. else:
  24. return StrPrinter().doprint(e)
  25. def _print_Function(self, e):
  26. from sympy.physics.vector.functions import dynamicsymbols
  27. t = dynamicsymbols._t
  28. if isinstance(type(e), UndefinedFunction):
  29. return StrPrinter().doprint(e).replace("(%s)" % t, '')
  30. return e.func.__name__ + "(%s)" % self.stringify(e.args, ", ")
  31. class VectorStrReprPrinter(VectorStrPrinter):
  32. """String repr printer for vector expressions."""
  33. def _print_str(self, s):
  34. return repr(s)
  35. class VectorLatexPrinter(LatexPrinter):
  36. """Latex Printer for vector expressions. """
  37. def _print_Function(self, expr, exp=None):
  38. from sympy.physics.vector.functions import dynamicsymbols
  39. func = expr.func.__name__
  40. t = dynamicsymbols._t
  41. if (hasattr(self, '_print_' + func) and not
  42. isinstance(type(expr), UndefinedFunction)):
  43. return getattr(self, '_print_' + func)(expr, exp)
  44. elif isinstance(type(expr), UndefinedFunction) and (expr.args == (t,)):
  45. # treat this function like a symbol
  46. expr = Symbol(func)
  47. if exp is not None:
  48. # copied from LatexPrinter._helper_print_standard_power, which
  49. # we can't call because we only have exp as a string.
  50. base = self.parenthesize(expr, PRECEDENCE['Pow'])
  51. base = self.parenthesize_super(base)
  52. return r"%s^{%s}" % (base, exp)
  53. else:
  54. return super()._print(expr)
  55. else:
  56. return super()._print_Function(expr, exp)
  57. def _print_Derivative(self, der_expr):
  58. from sympy.physics.vector.functions import dynamicsymbols
  59. # make sure it is in the right form
  60. der_expr = der_expr.doit()
  61. if not isinstance(der_expr, Derivative):
  62. return r"\left(%s\right)" % self.doprint(der_expr)
  63. # check if expr is a dynamicsymbol
  64. t = dynamicsymbols._t
  65. expr = der_expr.expr
  66. red = expr.atoms(AppliedUndef)
  67. syms = der_expr.variables
  68. test1 = not all(True for i in red if i.free_symbols == {t})
  69. test2 = not all(t == i for i in syms)
  70. if test1 or test2:
  71. return super()._print_Derivative(der_expr)
  72. # done checking
  73. dots = len(syms)
  74. base = self._print_Function(expr)
  75. base_split = base.split('_', 1)
  76. base = base_split[0]
  77. if dots == 1:
  78. base = r"\dot{%s}" % base
  79. elif dots == 2:
  80. base = r"\ddot{%s}" % base
  81. elif dots == 3:
  82. base = r"\dddot{%s}" % base
  83. elif dots == 4:
  84. base = r"\ddddot{%s}" % base
  85. else: # Fallback to standard printing
  86. return super()._print_Derivative(der_expr)
  87. if len(base_split) != 1:
  88. base += '_' + base_split[1]
  89. return base
  90. class VectorPrettyPrinter(PrettyPrinter):
  91. """Pretty Printer for vectorialexpressions. """
  92. def _print_Derivative(self, deriv):
  93. from sympy.physics.vector.functions import dynamicsymbols
  94. # XXX use U('PARTIAL DIFFERENTIAL') here ?
  95. t = dynamicsymbols._t
  96. dot_i = 0
  97. syms = list(reversed(deriv.variables))
  98. while len(syms) > 0:
  99. if syms[-1] == t:
  100. syms.pop()
  101. dot_i += 1
  102. else:
  103. return super()._print_Derivative(deriv)
  104. if not (isinstance(type(deriv.expr), UndefinedFunction) and
  105. (deriv.expr.args == (t,))):
  106. return super()._print_Derivative(deriv)
  107. else:
  108. pform = self._print_Function(deriv.expr)
  109. # the following condition would happen with some sort of non-standard
  110. # dynamic symbol I guess, so we'll just print the SymPy way
  111. if len(pform.picture) > 1:
  112. return super()._print_Derivative(deriv)
  113. # There are only special symbols up to fourth-order derivatives
  114. if dot_i >= 5:
  115. return super()._print_Derivative(deriv)
  116. # Deal with special symbols
  117. dots = {0: "",
  118. 1: "\N{COMBINING DOT ABOVE}",
  119. 2: "\N{COMBINING DIAERESIS}",
  120. 3: "\N{COMBINING THREE DOTS ABOVE}",
  121. 4: "\N{COMBINING FOUR DOTS ABOVE}"}
  122. d = pform.__dict__
  123. # if unicode is false then calculate number of apostrophes needed and
  124. # add to output
  125. if not self._use_unicode:
  126. apostrophes = ""
  127. for i in range(0, dot_i):
  128. apostrophes += "'"
  129. d['picture'][0] += apostrophes + "(t)"
  130. else:
  131. d['picture'] = [center_accent(d['picture'][0], dots[dot_i])]
  132. return pform
  133. def _print_Function(self, e):
  134. from sympy.physics.vector.functions import dynamicsymbols
  135. t = dynamicsymbols._t
  136. # XXX works only for applied functions
  137. func = e.func
  138. args = e.args
  139. func_name = func.__name__
  140. pform = self._print_Symbol(Symbol(func_name))
  141. # If this function is an Undefined function of t, it is probably a
  142. # dynamic symbol, so we'll skip the (t). The rest of the code is
  143. # identical to the normal PrettyPrinter code
  144. if not (isinstance(func, UndefinedFunction) and (args == (t,))):
  145. return super()._print_Function(e)
  146. return pform
  147. def vprint(expr, **settings):
  148. r"""Function for printing of expressions generated in the
  149. sympy.physics vector package.
  150. Extends SymPy's StrPrinter, takes the same setting accepted by SymPy's
  151. :func:`~.sstr`, and is equivalent to ``print(sstr(foo))``.
  152. Parameters
  153. ==========
  154. expr : valid SymPy object
  155. SymPy expression to print.
  156. settings : args
  157. Same as the settings accepted by SymPy's sstr().
  158. Examples
  159. ========
  160. >>> from sympy.physics.vector import vprint, dynamicsymbols
  161. >>> u1 = dynamicsymbols('u1')
  162. >>> print(u1)
  163. u1(t)
  164. >>> vprint(u1)
  165. u1
  166. """
  167. outstr = vsprint(expr, **settings)
  168. import builtins
  169. if (outstr != 'None'):
  170. builtins._ = outstr
  171. print(outstr)
  172. def vsstrrepr(expr, **settings):
  173. """Function for displaying expression representation's with vector
  174. printing enabled.
  175. Parameters
  176. ==========
  177. expr : valid SymPy object
  178. SymPy expression to print.
  179. settings : args
  180. Same as the settings accepted by SymPy's sstrrepr().
  181. """
  182. p = VectorStrReprPrinter(settings)
  183. return p.doprint(expr)
  184. def vsprint(expr, **settings):
  185. r"""Function for displaying expressions generated in the
  186. sympy.physics vector package.
  187. Returns the output of vprint() as a string.
  188. Parameters
  189. ==========
  190. expr : valid SymPy object
  191. SymPy expression to print
  192. settings : args
  193. Same as the settings accepted by SymPy's sstr().
  194. Examples
  195. ========
  196. >>> from sympy.physics.vector import vsprint, dynamicsymbols
  197. >>> u1, u2 = dynamicsymbols('u1 u2')
  198. >>> u2d = dynamicsymbols('u2', level=1)
  199. >>> print("%s = %s" % (u1, u2 + u2d))
  200. u1(t) = u2(t) + Derivative(u2(t), t)
  201. >>> print("%s = %s" % (vsprint(u1), vsprint(u2 + u2d)))
  202. u1 = u2 + u2'
  203. """
  204. string_printer = VectorStrPrinter(settings)
  205. return string_printer.doprint(expr)
  206. def vpprint(expr, **settings):
  207. r"""Function for pretty printing of expressions generated in the
  208. sympy.physics vector package.
  209. Mainly used for expressions not inside a vector; the output of running
  210. scripts and generating equations of motion. Takes the same options as
  211. SymPy's :func:`~.pretty_print`; see that function for more information.
  212. Parameters
  213. ==========
  214. expr : valid SymPy object
  215. SymPy expression to pretty print
  216. settings : args
  217. Same as those accepted by SymPy's pretty_print.
  218. """
  219. pp = VectorPrettyPrinter(settings)
  220. # Note that this is copied from sympy.printing.pretty.pretty_print:
  221. # XXX: this is an ugly hack, but at least it works
  222. use_unicode = pp._settings['use_unicode']
  223. from sympy.printing.pretty.pretty_symbology import pretty_use_unicode
  224. uflag = pretty_use_unicode(use_unicode)
  225. try:
  226. return pp.doprint(expr)
  227. finally:
  228. pretty_use_unicode(uflag)
  229. def vlatex(expr, **settings):
  230. r"""Function for printing latex representation of sympy.physics.vector
  231. objects.
  232. For latex representation of Vectors, Dyadics, and dynamicsymbols. Takes the
  233. same options as SymPy's :func:`~.latex`; see that function for more
  234. information;
  235. Parameters
  236. ==========
  237. expr : valid SymPy object
  238. SymPy expression to represent in LaTeX form
  239. settings : args
  240. Same as latex()
  241. Examples
  242. ========
  243. >>> from sympy.physics.vector import vlatex, ReferenceFrame, dynamicsymbols
  244. >>> N = ReferenceFrame('N')
  245. >>> q1, q2 = dynamicsymbols('q1 q2')
  246. >>> q1d, q2d = dynamicsymbols('q1 q2', 1)
  247. >>> q1dd, q2dd = dynamicsymbols('q1 q2', 2)
  248. >>> vlatex(N.x + N.y)
  249. '\\mathbf{\\hat{n}_x} + \\mathbf{\\hat{n}_y}'
  250. >>> vlatex(q1 + q2)
  251. 'q_{1} + q_{2}'
  252. >>> vlatex(q1d)
  253. '\\dot{q}_{1}'
  254. >>> vlatex(q1 * q2d)
  255. 'q_{1} \\dot{q}_{2}'
  256. >>> vlatex(q1dd * q1 / q1d)
  257. '\\frac{q_{1} \\ddot{q}_{1}}{\\dot{q}_{1}}'
  258. """
  259. latex_printer = VectorLatexPrinter(settings)
  260. return latex_printer.doprint(expr)
  261. def init_vprinting(**kwargs):
  262. """Initializes time derivative printing for all SymPy objects, i.e. any
  263. functions of time will be displayed in a more compact notation. The main
  264. benefit of this is for printing of time derivatives; instead of
  265. displaying as ``Derivative(f(t),t)``, it will display ``f'``. This is
  266. only actually needed for when derivatives are present and are not in a
  267. physics.vector.Vector or physics.vector.Dyadic object. This function is a
  268. light wrapper to :func:`~.init_printing`. Any keyword
  269. arguments for it are valid here.
  270. {0}
  271. Examples
  272. ========
  273. >>> from sympy import Function, symbols
  274. >>> t, x = symbols('t, x')
  275. >>> omega = Function('omega')
  276. >>> omega(x).diff()
  277. Derivative(omega(x), x)
  278. >>> omega(t).diff()
  279. Derivative(omega(t), t)
  280. Now use the string printer:
  281. >>> from sympy.physics.vector import init_vprinting
  282. >>> init_vprinting(pretty_print=False)
  283. >>> omega(x).diff()
  284. Derivative(omega(x), x)
  285. >>> omega(t).diff()
  286. omega'
  287. """
  288. kwargs['str_printer'] = vsstrrepr
  289. kwargs['pretty_printer'] = vpprint
  290. kwargs['latex_printer'] = vlatex
  291. init_printing(**kwargs)
  292. params = init_printing.__doc__.split('Examples\n ========')[0] # type: ignore
  293. init_vprinting.__doc__ = init_vprinting.__doc__.format(params) # type: ignore