octave.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719
  1. """
  2. Octave (and Matlab) code printer
  3. The `OctaveCodePrinter` converts SymPy expressions into Octave expressions.
  4. It uses a subset of the Octave language for Matlab compatibility.
  5. A complete code generator, which uses `octave_code` extensively, can be found
  6. in `sympy.utilities.codegen`. The `codegen` module can be used to generate
  7. complete source code files.
  8. """
  9. from __future__ import annotations
  10. from typing import Any
  11. from sympy.core import Mul, Pow, S, Rational
  12. from sympy.core.mul import _keep_coeff
  13. from sympy.core.numbers import equal_valued
  14. from sympy.printing.codeprinter import CodePrinter
  15. from sympy.printing.precedence import precedence, PRECEDENCE
  16. from re import search
  17. # List of known functions. First, those that have the same name in
  18. # SymPy and Octave. This is almost certainly incomplete!
  19. known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc",
  20. "asin", "acos", "acot", "atan", "atan2", "asec", "acsc",
  21. "sinh", "cosh", "tanh", "coth", "csch", "sech",
  22. "asinh", "acosh", "atanh", "acoth", "asech", "acsch",
  23. "erfc", "erfi", "erf", "erfinv", "erfcinv",
  24. "besseli", "besselj", "besselk", "bessely",
  25. "bernoulli", "beta", "euler", "exp", "factorial", "floor",
  26. "fresnelc", "fresnels", "gamma", "harmonic", "log",
  27. "polylog", "sign", "zeta", "legendre"]
  28. # These functions have different names ("SymPy": "Octave"), more
  29. # generally a mapping to (argument_conditions, octave_function).
  30. known_fcns_src2 = {
  31. "Abs": "abs",
  32. "arg": "angle", # arg/angle ok in Octave but only angle in Matlab
  33. "binomial": "bincoeff",
  34. "ceiling": "ceil",
  35. "chebyshevu": "chebyshevU",
  36. "chebyshevt": "chebyshevT",
  37. "Chi": "coshint",
  38. "Ci": "cosint",
  39. "conjugate": "conj",
  40. "DiracDelta": "dirac",
  41. "Heaviside": "heaviside",
  42. "im": "imag",
  43. "laguerre": "laguerreL",
  44. "LambertW": "lambertw",
  45. "li": "logint",
  46. "loggamma": "gammaln",
  47. "Max": "max",
  48. "Min": "min",
  49. "Mod": "mod",
  50. "polygamma": "psi",
  51. "re": "real",
  52. "RisingFactorial": "pochhammer",
  53. "Shi": "sinhint",
  54. "Si": "sinint",
  55. }
  56. class OctaveCodePrinter(CodePrinter):
  57. """
  58. A printer to convert expressions to strings of Octave/Matlab code.
  59. """
  60. printmethod = "_octave"
  61. language = "Octave"
  62. _operators = {
  63. 'and': '&',
  64. 'or': '|',
  65. 'not': '~',
  66. }
  67. _default_settings: dict[str, Any] = {
  68. 'order': None,
  69. 'full_prec': 'auto',
  70. 'precision': 17,
  71. 'user_functions': {},
  72. 'human': True,
  73. 'allow_unknown_functions': False,
  74. 'contract': True,
  75. 'inline': True,
  76. }
  77. # Note: contract is for expressing tensors as loops (if True), or just
  78. # assignment (if False). FIXME: this should be looked a more carefully
  79. # for Octave.
  80. def __init__(self, settings={}):
  81. super().__init__(settings)
  82. self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1))
  83. self.known_functions.update(dict(known_fcns_src2))
  84. userfuncs = settings.get('user_functions', {})
  85. self.known_functions.update(userfuncs)
  86. def _rate_index_position(self, p):
  87. return p*5
  88. def _get_statement(self, codestring):
  89. return "%s;" % codestring
  90. def _get_comment(self, text):
  91. return "% {}".format(text)
  92. def _declare_number_const(self, name, value):
  93. return "{} = {};".format(name, value)
  94. def _format_code(self, lines):
  95. return self.indent_code(lines)
  96. def _traverse_matrix_indices(self, mat):
  97. # Octave uses Fortran order (column-major)
  98. rows, cols = mat.shape
  99. return ((i, j) for j in range(cols) for i in range(rows))
  100. def _get_loop_opening_ending(self, indices):
  101. open_lines = []
  102. close_lines = []
  103. for i in indices:
  104. # Octave arrays start at 1 and end at dimension
  105. var, start, stop = map(self._print,
  106. [i.label, i.lower + 1, i.upper + 1])
  107. open_lines.append("for %s = %s:%s" % (var, start, stop))
  108. close_lines.append("end")
  109. return open_lines, close_lines
  110. def _print_Mul(self, expr):
  111. # print complex numbers nicely in Octave
  112. if (expr.is_number and expr.is_imaginary and
  113. (S.ImaginaryUnit*expr).is_Integer):
  114. return "%si" % self._print(-S.ImaginaryUnit*expr)
  115. # cribbed from str.py
  116. prec = precedence(expr)
  117. c, e = expr.as_coeff_Mul()
  118. if c < 0:
  119. expr = _keep_coeff(-c, e)
  120. sign = "-"
  121. else:
  122. sign = ""
  123. a = [] # items in the numerator
  124. b = [] # items that are in the denominator (if any)
  125. pow_paren = [] # Will collect all pow with more than one base element and exp = -1
  126. if self.order not in ('old', 'none'):
  127. args = expr.as_ordered_factors()
  128. else:
  129. # use make_args in case expr was something like -x -> x
  130. args = Mul.make_args(expr)
  131. # Gather args for numerator/denominator
  132. for item in args:
  133. if (item.is_commutative and item.is_Pow and item.exp.is_Rational
  134. and item.exp.is_negative):
  135. if item.exp != -1:
  136. b.append(Pow(item.base, -item.exp, evaluate=False))
  137. else:
  138. if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160
  139. pow_paren.append(item)
  140. b.append(Pow(item.base, -item.exp))
  141. elif item.is_Rational and item is not S.Infinity:
  142. if item.p != 1:
  143. a.append(Rational(item.p))
  144. if item.q != 1:
  145. b.append(Rational(item.q))
  146. else:
  147. a.append(item)
  148. a = a or [S.One]
  149. a_str = [self.parenthesize(x, prec) for x in a]
  150. b_str = [self.parenthesize(x, prec) for x in b]
  151. # To parenthesize Pow with exp = -1 and having more than one Symbol
  152. for item in pow_paren:
  153. if item.base in b:
  154. b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
  155. # from here it differs from str.py to deal with "*" and ".*"
  156. def multjoin(a, a_str):
  157. # here we probably are assuming the constants will come first
  158. r = a_str[0]
  159. for i in range(1, len(a)):
  160. mulsym = '*' if a[i-1].is_number else '.*'
  161. r = r + mulsym + a_str[i]
  162. return r
  163. if not b:
  164. return sign + multjoin(a, a_str)
  165. elif len(b) == 1:
  166. divsym = '/' if b[0].is_number else './'
  167. return sign + multjoin(a, a_str) + divsym + b_str[0]
  168. else:
  169. divsym = '/' if all(bi.is_number for bi in b) else './'
  170. return (sign + multjoin(a, a_str) +
  171. divsym + "(%s)" % multjoin(b, b_str))
  172. def _print_Relational(self, expr):
  173. lhs_code = self._print(expr.lhs)
  174. rhs_code = self._print(expr.rhs)
  175. op = expr.rel_op
  176. return "{} {} {}".format(lhs_code, op, rhs_code)
  177. def _print_Pow(self, expr):
  178. powsymbol = '^' if all(x.is_number for x in expr.args) else '.^'
  179. PREC = precedence(expr)
  180. if equal_valued(expr.exp, 0.5):
  181. return "sqrt(%s)" % self._print(expr.base)
  182. if expr.is_commutative:
  183. if equal_valued(expr.exp, -0.5):
  184. sym = '/' if expr.base.is_number else './'
  185. return "1" + sym + "sqrt(%s)" % self._print(expr.base)
  186. if equal_valued(expr.exp, -1):
  187. sym = '/' if expr.base.is_number else './'
  188. return "1" + sym + "%s" % self.parenthesize(expr.base, PREC)
  189. return '%s%s%s' % (self.parenthesize(expr.base, PREC), powsymbol,
  190. self.parenthesize(expr.exp, PREC))
  191. def _print_MatPow(self, expr):
  192. PREC = precedence(expr)
  193. return '%s^%s' % (self.parenthesize(expr.base, PREC),
  194. self.parenthesize(expr.exp, PREC))
  195. def _print_MatrixSolve(self, expr):
  196. PREC = precedence(expr)
  197. return "%s \\ %s" % (self.parenthesize(expr.matrix, PREC),
  198. self.parenthesize(expr.vector, PREC))
  199. def _print_Pi(self, expr):
  200. return 'pi'
  201. def _print_ImaginaryUnit(self, expr):
  202. return "1i"
  203. def _print_Exp1(self, expr):
  204. return "exp(1)"
  205. def _print_GoldenRatio(self, expr):
  206. # FIXME: how to do better, e.g., for octave_code(2*GoldenRatio)?
  207. #return self._print((1+sqrt(S(5)))/2)
  208. return "(1+sqrt(5))/2"
  209. def _print_Assignment(self, expr):
  210. from sympy.codegen.ast import Assignment
  211. from sympy.functions.elementary.piecewise import Piecewise
  212. from sympy.tensor.indexed import IndexedBase
  213. # Copied from codeprinter, but remove special MatrixSymbol treatment
  214. lhs = expr.lhs
  215. rhs = expr.rhs
  216. # We special case assignments that take multiple lines
  217. if not self._settings["inline"] and isinstance(expr.rhs, Piecewise):
  218. # Here we modify Piecewise so each expression is now
  219. # an Assignment, and then continue on the print.
  220. expressions = []
  221. conditions = []
  222. for (e, c) in rhs.args:
  223. expressions.append(Assignment(lhs, e))
  224. conditions.append(c)
  225. temp = Piecewise(*zip(expressions, conditions))
  226. return self._print(temp)
  227. if self._settings["contract"] and (lhs.has(IndexedBase) or
  228. rhs.has(IndexedBase)):
  229. # Here we check if there is looping to be done, and if so
  230. # print the required loops.
  231. return self._doprint_loops(rhs, lhs)
  232. else:
  233. lhs_code = self._print(lhs)
  234. rhs_code = self._print(rhs)
  235. return self._get_statement("%s = %s" % (lhs_code, rhs_code))
  236. def _print_Infinity(self, expr):
  237. return 'inf'
  238. def _print_NegativeInfinity(self, expr):
  239. return '-inf'
  240. def _print_NaN(self, expr):
  241. return 'NaN'
  242. def _print_list(self, expr):
  243. return '{' + ', '.join(self._print(a) for a in expr) + '}'
  244. _print_tuple = _print_list
  245. _print_Tuple = _print_list
  246. _print_List = _print_list
  247. def _print_BooleanTrue(self, expr):
  248. return "true"
  249. def _print_BooleanFalse(self, expr):
  250. return "false"
  251. def _print_bool(self, expr):
  252. return str(expr).lower()
  253. # Could generate quadrature code for definite Integrals?
  254. #_print_Integral = _print_not_supported
  255. def _print_MatrixBase(self, A):
  256. # Handle zero dimensions:
  257. if (A.rows, A.cols) == (0, 0):
  258. return '[]'
  259. elif S.Zero in A.shape:
  260. return 'zeros(%s, %s)' % (A.rows, A.cols)
  261. elif (A.rows, A.cols) == (1, 1):
  262. # Octave does not distinguish between scalars and 1x1 matrices
  263. return self._print(A[0, 0])
  264. return "[%s]" % "; ".join(" ".join([self._print(a) for a in A[r, :]])
  265. for r in range(A.rows))
  266. def _print_SparseRepMatrix(self, A):
  267. from sympy.matrices import Matrix
  268. L = A.col_list();
  269. # make row vectors of the indices and entries
  270. I = Matrix([[k[0] + 1 for k in L]])
  271. J = Matrix([[k[1] + 1 for k in L]])
  272. AIJ = Matrix([[k[2] for k in L]])
  273. return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J),
  274. self._print(AIJ), A.rows, A.cols)
  275. def _print_MatrixElement(self, expr):
  276. return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \
  277. + '(%s, %s)' % (expr.i + 1, expr.j + 1)
  278. def _print_MatrixSlice(self, expr):
  279. def strslice(x, lim):
  280. l = x[0] + 1
  281. h = x[1]
  282. step = x[2]
  283. lstr = self._print(l)
  284. hstr = 'end' if h == lim else self._print(h)
  285. if step == 1:
  286. if l == 1 and h == lim:
  287. return ':'
  288. if l == h:
  289. return lstr
  290. else:
  291. return lstr + ':' + hstr
  292. else:
  293. return ':'.join((lstr, self._print(step), hstr))
  294. return (self._print(expr.parent) + '(' +
  295. strslice(expr.rowslice, expr.parent.shape[0]) + ', ' +
  296. strslice(expr.colslice, expr.parent.shape[1]) + ')')
  297. def _print_Indexed(self, expr):
  298. inds = [ self._print(i) for i in expr.indices ]
  299. return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))
  300. def _print_Idx(self, expr):
  301. return self._print(expr.label)
  302. def _print_KroneckerDelta(self, expr):
  303. prec = PRECEDENCE["Pow"]
  304. return "double(%s == %s)" % tuple(self.parenthesize(x, prec)
  305. for x in expr.args)
  306. def _print_HadamardProduct(self, expr):
  307. return '.*'.join([self.parenthesize(arg, precedence(expr))
  308. for arg in expr.args])
  309. def _print_HadamardPower(self, expr):
  310. PREC = precedence(expr)
  311. return '.**'.join([
  312. self.parenthesize(expr.base, PREC),
  313. self.parenthesize(expr.exp, PREC)
  314. ])
  315. def _print_Identity(self, expr):
  316. shape = expr.shape
  317. if len(shape) == 2 and shape[0] == shape[1]:
  318. shape = [shape[0]]
  319. s = ", ".join(self._print(n) for n in shape)
  320. return "eye(" + s + ")"
  321. def _print_lowergamma(self, expr):
  322. # Octave implements regularized incomplete gamma function
  323. return "(gammainc({1}, {0}).*gamma({0}))".format(
  324. self._print(expr.args[0]), self._print(expr.args[1]))
  325. def _print_uppergamma(self, expr):
  326. return "(gammainc({1}, {0}, 'upper').*gamma({0}))".format(
  327. self._print(expr.args[0]), self._print(expr.args[1]))
  328. def _print_sinc(self, expr):
  329. #Note: Divide by pi because Octave implements normalized sinc function.
  330. return "sinc(%s)" % self._print(expr.args[0]/S.Pi)
  331. def _print_hankel1(self, expr):
  332. return "besselh(%s, 1, %s)" % (self._print(expr.order),
  333. self._print(expr.argument))
  334. def _print_hankel2(self, expr):
  335. return "besselh(%s, 2, %s)" % (self._print(expr.order),
  336. self._print(expr.argument))
  337. # Note: as of 2015, Octave doesn't have spherical Bessel functions
  338. def _print_jn(self, expr):
  339. from sympy.functions import sqrt, besselj
  340. x = expr.argument
  341. expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x)
  342. return self._print(expr2)
  343. def _print_yn(self, expr):
  344. from sympy.functions import sqrt, bessely
  345. x = expr.argument
  346. expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x)
  347. return self._print(expr2)
  348. def _print_airyai(self, expr):
  349. return "airy(0, %s)" % self._print(expr.args[0])
  350. def _print_airyaiprime(self, expr):
  351. return "airy(1, %s)" % self._print(expr.args[0])
  352. def _print_airybi(self, expr):
  353. return "airy(2, %s)" % self._print(expr.args[0])
  354. def _print_airybiprime(self, expr):
  355. return "airy(3, %s)" % self._print(expr.args[0])
  356. def _print_expint(self, expr):
  357. mu, x = expr.args
  358. if mu != 1:
  359. return self._print_not_supported(expr)
  360. return "expint(%s)" % self._print(x)
  361. def _one_or_two_reversed_args(self, expr):
  362. assert len(expr.args) <= 2
  363. return '{name}({args})'.format(
  364. name=self.known_functions[expr.__class__.__name__],
  365. args=", ".join([self._print(x) for x in reversed(expr.args)])
  366. )
  367. _print_DiracDelta = _print_LambertW = _one_or_two_reversed_args
  368. def _nested_binary_math_func(self, expr):
  369. return '{name}({arg1}, {arg2})'.format(
  370. name=self.known_functions[expr.__class__.__name__],
  371. arg1=self._print(expr.args[0]),
  372. arg2=self._print(expr.func(*expr.args[1:]))
  373. )
  374. _print_Max = _print_Min = _nested_binary_math_func
  375. def _print_Piecewise(self, expr):
  376. if expr.args[-1].cond != True:
  377. # We need the last conditional to be a True, otherwise the resulting
  378. # function may not return a result.
  379. raise ValueError("All Piecewise expressions must contain an "
  380. "(expr, True) statement to be used as a default "
  381. "condition. Without one, the generated "
  382. "expression may not evaluate to anything under "
  383. "some condition.")
  384. lines = []
  385. if self._settings["inline"]:
  386. # Express each (cond, expr) pair in a nested Horner form:
  387. # (condition) .* (expr) + (not cond) .* (<others>)
  388. # Expressions that result in multiple statements won't work here.
  389. ecpairs = ["({0}).*({1}) + (~({0})).*(".format
  390. (self._print(c), self._print(e))
  391. for e, c in expr.args[:-1]]
  392. elast = "%s" % self._print(expr.args[-1].expr)
  393. pw = " ...\n".join(ecpairs) + elast + ")"*len(ecpairs)
  394. # Note: current need these outer brackets for 2*pw. Would be
  395. # nicer to teach parenthesize() to do this for us when needed!
  396. return "(" + pw + ")"
  397. else:
  398. for i, (e, c) in enumerate(expr.args):
  399. if i == 0:
  400. lines.append("if (%s)" % self._print(c))
  401. elif i == len(expr.args) - 1 and c == True:
  402. lines.append("else")
  403. else:
  404. lines.append("elseif (%s)" % self._print(c))
  405. code0 = self._print(e)
  406. lines.append(code0)
  407. if i == len(expr.args) - 1:
  408. lines.append("end")
  409. return "\n".join(lines)
  410. def _print_zeta(self, expr):
  411. if len(expr.args) == 1:
  412. return "zeta(%s)" % self._print(expr.args[0])
  413. else:
  414. # Matlab two argument zeta is not equivalent to SymPy's
  415. return self._print_not_supported(expr)
  416. def indent_code(self, code):
  417. """Accepts a string of code or a list of code lines"""
  418. # code mostly copied from ccode
  419. if isinstance(code, str):
  420. code_lines = self.indent_code(code.splitlines(True))
  421. return ''.join(code_lines)
  422. tab = " "
  423. inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ')
  424. dec_regex = ('^end$', '^elseif ', '^else$')
  425. # pre-strip left-space from the code
  426. code = [ line.lstrip(' \t') for line in code ]
  427. increase = [ int(any(search(re, line) for re in inc_regex))
  428. for line in code ]
  429. decrease = [ int(any(search(re, line) for re in dec_regex))
  430. for line in code ]
  431. pretty = []
  432. level = 0
  433. for n, line in enumerate(code):
  434. if line in ('', '\n'):
  435. pretty.append(line)
  436. continue
  437. level -= decrease[n]
  438. pretty.append("%s%s" % (tab*level, line))
  439. level += increase[n]
  440. return pretty
  441. def octave_code(expr, assign_to=None, **settings):
  442. r"""Converts `expr` to a string of Octave (or Matlab) code.
  443. The string uses a subset of the Octave language for Matlab compatibility.
  444. Parameters
  445. ==========
  446. expr : Expr
  447. A SymPy expression to be converted.
  448. assign_to : optional
  449. When given, the argument is used as the name of the variable to which
  450. the expression is assigned. Can be a string, ``Symbol``,
  451. ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for
  452. expressions that generate multi-line statements.
  453. precision : integer, optional
  454. The precision for numbers such as pi [default=16].
  455. user_functions : dict, optional
  456. A dictionary where keys are ``FunctionClass`` instances and values are
  457. their string representations. Alternatively, the dictionary value can
  458. be a list of tuples i.e. [(argument_test, cfunction_string)]. See
  459. below for examples.
  460. human : bool, optional
  461. If True, the result is a single string that may contain some constant
  462. declarations for the number symbols. If False, the same information is
  463. returned in a tuple of (symbols_to_declare, not_supported_functions,
  464. code_text). [default=True].
  465. contract: bool, optional
  466. If True, ``Indexed`` instances are assumed to obey tensor contraction
  467. rules and the corresponding nested loops over indices are generated.
  468. Setting contract=False will not generate loops, instead the user is
  469. responsible to provide values for the indices in the code.
  470. [default=True].
  471. inline: bool, optional
  472. If True, we try to create single-statement code instead of multiple
  473. statements. [default=True].
  474. Examples
  475. ========
  476. >>> from sympy import octave_code, symbols, sin, pi
  477. >>> x = symbols('x')
  478. >>> octave_code(sin(x).series(x).removeO())
  479. 'x.^5/120 - x.^3/6 + x'
  480. >>> from sympy import Rational, ceiling
  481. >>> x, y, tau = symbols("x, y, tau")
  482. >>> octave_code((2*tau)**Rational(7, 2))
  483. '8*sqrt(2)*tau.^(7/2)'
  484. Note that element-wise (Hadamard) operations are used by default between
  485. symbols. This is because its very common in Octave to write "vectorized"
  486. code. It is harmless if the values are scalars.
  487. >>> octave_code(sin(pi*x*y), assign_to="s")
  488. 's = sin(pi*x.*y);'
  489. If you need a matrix product "*" or matrix power "^", you can specify the
  490. symbol as a ``MatrixSymbol``.
  491. >>> from sympy import Symbol, MatrixSymbol
  492. >>> n = Symbol('n', integer=True, positive=True)
  493. >>> A = MatrixSymbol('A', n, n)
  494. >>> octave_code(3*pi*A**3)
  495. '(3*pi)*A^3'
  496. This class uses several rules to decide which symbol to use a product.
  497. Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*".
  498. A HadamardProduct can be used to specify componentwise multiplication ".*"
  499. of two MatrixSymbols. There is currently there is no easy way to specify
  500. scalar symbols, so sometimes the code might have some minor cosmetic
  501. issues. For example, suppose x and y are scalars and A is a Matrix, then
  502. while a human programmer might write "(x^2*y)*A^3", we generate:
  503. >>> octave_code(x**2*y*A**3)
  504. '(x.^2.*y)*A^3'
  505. Matrices are supported using Octave inline notation. When using
  506. ``assign_to`` with matrices, the name can be specified either as a string
  507. or as a ``MatrixSymbol``. The dimensions must align in the latter case.
  508. >>> from sympy import Matrix, MatrixSymbol
  509. >>> mat = Matrix([[x**2, sin(x), ceiling(x)]])
  510. >>> octave_code(mat, assign_to='A')
  511. 'A = [x.^2 sin(x) ceil(x)];'
  512. ``Piecewise`` expressions are implemented with logical masking by default.
  513. Alternatively, you can pass "inline=False" to use if-else conditionals.
  514. Note that if the ``Piecewise`` lacks a default term, represented by
  515. ``(expr, True)`` then an error will be thrown. This is to prevent
  516. generating an expression that may not evaluate to anything.
  517. >>> from sympy import Piecewise
  518. >>> pw = Piecewise((x + 1, x > 0), (x, True))
  519. >>> octave_code(pw, assign_to=tau)
  520. 'tau = ((x > 0).*(x + 1) + (~(x > 0)).*(x));'
  521. Note that any expression that can be generated normally can also exist
  522. inside a Matrix:
  523. >>> mat = Matrix([[x**2, pw, sin(x)]])
  524. >>> octave_code(mat, assign_to='A')
  525. 'A = [x.^2 ((x > 0).*(x + 1) + (~(x > 0)).*(x)) sin(x)];'
  526. Custom printing can be defined for certain types by passing a dictionary of
  527. "type" : "function" to the ``user_functions`` kwarg. Alternatively, the
  528. dictionary value can be a list of tuples i.e., [(argument_test,
  529. cfunction_string)]. This can be used to call a custom Octave function.
  530. >>> from sympy import Function
  531. >>> f = Function('f')
  532. >>> g = Function('g')
  533. >>> custom_functions = {
  534. ... "f": "existing_octave_fcn",
  535. ... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"),
  536. ... (lambda x: not x.is_Matrix, "my_fcn")]
  537. ... }
  538. >>> mat = Matrix([[1, x]])
  539. >>> octave_code(f(x) + g(x) + g(mat), user_functions=custom_functions)
  540. 'existing_octave_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])'
  541. Support for loops is provided through ``Indexed`` types. With
  542. ``contract=True`` these expressions will be turned into loops, whereas
  543. ``contract=False`` will just print the assignment expression that should be
  544. looped over:
  545. >>> from sympy import Eq, IndexedBase, Idx
  546. >>> len_y = 5
  547. >>> y = IndexedBase('y', shape=(len_y,))
  548. >>> t = IndexedBase('t', shape=(len_y,))
  549. >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
  550. >>> i = Idx('i', len_y-1)
  551. >>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
  552. >>> octave_code(e.rhs, assign_to=e.lhs, contract=False)
  553. 'Dy(i) = (y(i + 1) - y(i))./(t(i + 1) - t(i));'
  554. """
  555. return OctaveCodePrinter(settings).doprint(expr, assign_to)
  556. def print_octave_code(expr, **settings):
  557. """Prints the Octave (or Matlab) representation of the given expression.
  558. See `octave_code` for the meaning of the optional arguments.
  559. """
  560. print(octave_code(expr, **settings))