ast.py 55 KB


  1. """
  2. Types used to represent a full function/module as an Abstract Syntax Tree.
  3. Most types are small, and are merely used as tokens in the AST. A tree diagram
  4. has been included below to illustrate the relationships between the AST types.
  5. AST Type Tree
  6. -------------
  7. ::
  8. *Basic*
  9. |
  10. |
  11. CodegenAST
  12. |
  13. |--->AssignmentBase
  14. | |--->Assignment
  15. | |--->AugmentedAssignment
  16. | | |--->AddAugmentedAssignment
  17. | | |--->SubAugmentedAssignment
  18. | | |--->MulAugmentedAssignment
  19. | | |--->DivAugmentedAssignment
  20. | | |--->ModAugmentedAssignment
  21. |
  22. |--->CodeBlock
  23. |
  24. |
  25. |--->Token
  26. |--->Attribute
  27. |--->For
  28. |--->String
  29. | |--->QuotedString
  30. | |--->Comment
  31. |--->Type
  32. | |--->IntBaseType
  33. | | |--->_SizedIntType
  34. | | |--->SignedIntType
  35. | | |--->UnsignedIntType
  36. | |--->FloatBaseType
  37. | |--->FloatType
  38. | |--->ComplexBaseType
  39. | |--->ComplexType
  40. |--->Node
  41. | |--->Variable
  42. | | |---> Pointer
  43. | |--->FunctionPrototype
  44. | |--->FunctionDefinition
  45. |--->Element
  46. |--->Declaration
  47. |--->While
  48. |--->Scope
  49. |--->Stream
  50. |--->Print
  51. |--->FunctionCall
  52. |--->BreakToken
  53. |--->ContinueToken
  54. |--->NoneToken
  55. |--->Return
  56. Predefined types
  57. ----------------
  58. A number of ``Type`` instances are provided in the ``sympy.codegen.ast`` module
  59. for convenience. Perhaps the two most common ones for code-generation (of numeric
  60. codes) are ``float32`` and ``float64`` (known as single and double precision respectively).
  61. There are also precision-generic versions of Types (for which the code printers select the
  62. underlying data type at time of printing): ``real``, ``integer``, ``complex_``, ``bool_``.
  63. The other ``Type`` instances defined are:
  64. - ``intc``: Integer type used by C's "int".
  65. - ``intp``: Integer type used by C's "unsigned".
  66. - ``int8``, ``int16``, ``int32``, ``int64``: n-bit integers.
  67. - ``uint8``, ``uint16``, ``uint32``, ``uint64``: n-bit unsigned integers.
  68. - ``float80``: known as "extended precision" on modern x86/amd64 hardware.
  69. - ``complex64``: Complex number represented by two ``float32`` numbers
  70. - ``complex128``: Complex number represented by two ``float64`` numbers
  71. Using the nodes
  72. ---------------
  73. It is possible to construct simple algorithms using the AST nodes. Let's construct a loop applying
  74. Newton's method::
  75. >>> from sympy import symbols, cos
  76. >>> from sympy.codegen.ast import While, Assignment, aug_assign, Print
  77. >>> t, dx, x = symbols('tol delta val')
  78. >>> expr = cos(x) - x**3
  79. >>> whl = While(abs(dx) > t, [
  80. ... Assignment(dx, -expr/expr.diff(x)),
  81. ... aug_assign(x, '+', dx),
  82. ... Print([x])
  83. ... ])
  84. >>> from sympy import pycode
  85. >>> py_str = pycode(whl)
  86. >>> print(py_str)
  87. while (abs(delta) > tol):
  88.     delta = (val**3 - math.cos(val))/(-3*val**2 - math.sin(val))
  89.     val += delta
  90.     print(val)
  91. >>> import math
  92. >>> tol, val, delta = 1e-5, 0.5, float('inf')
  93. >>> exec(py_str)
  94. 1.1121416371
  95. 0.909672693737
  96. 0.867263818209
  97. 0.865477135298
  98. 0.865474033111
  99. >>> print('%3.1g' % (math.cos(val) - val**3))
  100. -3e-11
  101. If we want to generate Fortran code for the same while loop we simply call ``fcode``::
  102. >>> from sympy import fcode
  103. >>> print(fcode(whl, standard=2003, source_format='free'))
  104. do while (abs(delta) > tol)
  105.    delta = (val**3 - cos(val))/(-3*val**2 - sin(val))
  106.    val = val + delta
  107.    print *, val
  108. end do
  109. There is a function constructing a loop (or a complete function) like this in
  110. :mod:`sympy.codegen.algorithms`.
  111. """
  112. from __future__ import annotations
  113. from typing import Any
  114. from collections import defaultdict
  115. from sympy.core.relational import (Ge, Gt, Le, Lt)
  116. from sympy.core import Symbol, Tuple, Dummy
  117. from sympy.core.basic import Basic
  118. from sympy.core.expr import Expr, Atom
  119. from sympy.core.numbers import Float, Integer, oo
  120. from sympy.core.sympify import _sympify, sympify, SympifyError
  121. from sympy.utilities.iterables import (iterable, topological_sort,
  122. numbered_symbols, filter_symbols)
  123. def _mk_Tuple(args):
  124. """
  125. Create a SymPy Tuple object from an iterable, converting Python strings to
  126. AST strings.
  127. Parameters
  128. ==========
  129. args: iterable
  130. Arguments to :class:`sympy.Tuple`.
  131. Returns
  132. =======
  133. sympy.Tuple
  134. """
  135. args = [String(arg) if isinstance(arg, str) else arg for arg in args]
  136. return Tuple(*args)
  137. class CodegenAST(Basic):
  138. __slots__ = ()
class Token(CodegenAST):
    """ Base class for the AST types.

    Explanation
    ===========

    Defining fields are set in ``_fields``. Attributes (defined in _fields)
    are only allowed to contain instances of Basic (unless atomic, see
    ``String``). The arguments to ``__new__()`` correspond to the attributes in
    the order defined in ``_fields``. The ``defaults`` class attribute is a
    dictionary mapping attribute names to their default values.

    Subclasses should not need to override the ``__new__()`` method. They may
    define a class or static method named ``_construct_<attr>`` for each
    attribute to process the value passed to ``__new__()``. Attributes listed
    in the class attribute ``not_in_args`` are not passed to :class:`~.Basic`.
    """

    __slots__: tuple[str, ...] = ()
    _fields = __slots__
    defaults: dict[str, Any] = {}
    not_in_args: list[str] = []
    # Attributes whose printed form is laid out on indented lines (see
    # ``_joiner`` / ``_indented`` / ``_sympyrepr``).
    indented_args = ['body']

    @property
    def is_Atom(self):
        # A token without any fields is treated as atomic.
        return len(self._fields) == 0

    @classmethod
    def _get_constructor(cls, attr):
        """ Get the constructor function for an attribute by name. """
        # Falls back to the identity function when the subclass defines no
        # ``_construct_<attr>`` hook.
        return getattr(cls, '_construct_%s' % attr, lambda x: x)

    @classmethod
    def _construct(cls, attr, arg):
        """ Construct an attribute value from argument passed to ``__new__()``. """
        # arg may be ``NoneToken()``, so comparison is done using == instead of ``is`` operator
        if arg == None:
            return cls.defaults.get(attr, none)
        else:
            if isinstance(arg, Dummy):  # SymPy's replace uses Dummy instances
                return arg
            else:
                return cls._get_constructor(attr)(arg)

    def __new__(cls, *args, **kwargs):
        # Pass through existing instances when given as sole argument
        if len(args) == 1 and not kwargs and isinstance(args[0], cls):
            return args[0]

        if len(args) > len(cls._fields):
            raise ValueError("Too many arguments (%d), expected at most %d" % (len(args), len(cls._fields)))

        attrvals = []

        # Process positional arguments
        # (they fill ``_fields`` in order; a field may not also be given by keyword)
        for attrname, argval in zip(cls._fields, args):
            if attrname in kwargs:
                raise TypeError('Got multiple values for attribute %r' % attrname)

            attrvals.append(cls._construct(attrname, argval))

        # Process keyword arguments
        # (remaining fields come from kwargs, else from ``defaults``, else error)
        for attrname in cls._fields[len(args):]:
            if attrname in kwargs:
                argval = kwargs.pop(attrname)

            elif attrname in cls.defaults:
                argval = cls.defaults[attrname]

            else:
                raise TypeError('No value for %r given and attribute has no default' % attrname)

            attrvals.append(cls._construct(attrname, argval))

        # Anything left in kwargs did not match a field name.
        if kwargs:
            raise ValueError("Unknown keyword arguments: %s" % ' '.join(kwargs))

        # Parent constructor
        # (fields listed in ``not_in_args`` are stored as attributes only)
        basic_args = [
            val for attr, val in zip(cls._fields, attrvals)
            if attr not in cls.not_in_args
        ]
        obj = CodegenAST.__new__(cls, *basic_args)

        # Set attributes
        for attr, arg in zip(cls._fields, attrvals):
            setattr(obj, attr, arg)

        return obj

    def __eq__(self, other):
        # Equal when ``other`` is an instance of this class and every field
        # compares equal.
        if not isinstance(other, self.__class__):
            return False
        for attr in self._fields:
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def _hashable_content(self):
        # Hash on the field values, mirroring ``__eq__``.
        return tuple([getattr(self, attr) for attr in self._fields])

    def __hash__(self):
        # Overriding ``__eq__`` would otherwise set ``__hash__`` to None.
        return super().__hash__()

    def _joiner(self, k, indent_level):
        # Indented attributes get one item per line; others are comma-joined.
        return (',\n' + ' '*indent_level) if k in self.indented_args else ', '

    def _indented(self, printer, k, v, *args, **kwargs):
        # Print the value of attribute ``k``; Tuples are printed item by item
        # using the joiner chosen for ``k``.
        il = printer._context['indent_level']
        def _print(arg):
            if isinstance(arg, Token):
                return printer._print(arg, *args, joiner=self._joiner(k, il), **kwargs)
            else:
                return printer._print(arg, *args, **kwargs)

        if isinstance(v, Tuple):
            joined = self._joiner(k, il).join([_print(arg) for arg in v.args])
            if k in self.indented_args:
                return '(\n' + ' '*il + joined + ',\n' + ' '*(il - 4) + ')'
            else:
                return ('({0},)' if len(v.args) == 1 else '({0})').format(joined)
        else:
            return _print(v)

    def _sympyrepr(self, printer, *args, joiner=', ', **kwargs):
        # Render as ``ClassName(first, attr=value, ...)``: the first field is
        # positional, the rest are keyword-style; defaults are omitted.
        from sympy.printing.printer import printer_context
        exclude = kwargs.get('exclude', ())
        values = [getattr(self, k) for k in self._fields]
        indent_level = printer._context.get('indent_level', 0)

        arg_reprs = []

        for i, (attr, value) in enumerate(zip(self._fields, values)):
            if attr in exclude:
                continue

            # Skip attributes which have the default value
            if attr in self.defaults and value == self.defaults[attr]:
                continue

            # Indented attributes print one level (4 spaces) deeper.
            ilvl = indent_level + 4 if attr in self.indented_args else 0
            with printer_context(printer, indent_level=ilvl):
                indented = self._indented(printer, attr, value, *args, **kwargs)
            arg_reprs.append(('{1}' if i == 0 else '{0}={1}').format(attr, indented.lstrip()))

        return "{}({})".format(self.__class__.__name__, joiner.join(arg_reprs))

    _sympystr = _sympyrepr

    def __repr__(self):  # sympy.core.Basic.__repr__ uses sstr
        from sympy.printing import srepr
        return srepr(self)

    def kwargs(self, exclude=(), apply=None):
        """ Get instance's attributes as dict of keyword arguments.

        Parameters
        ==========

        exclude : collection of str
            Collection of keywords to exclude.

        apply : callable, optional
            Function to apply to all values.
        """
        kwargs = {k: getattr(self, k) for k in self._fields if k not in exclude}
        if apply is not None:
            return {k: apply(v) for k, v in kwargs.items()}
        else:
            return kwargs
class BreakToken(Token):
    """ Represents 'break' in C/Python ('exit' in Fortran).

    Use the premade instance ``break_`` or instantiate manually.

    Examples
    ========

    >>> from sympy import ccode, fcode
    >>> from sympy.codegen.ast import break_
    >>> ccode(break_)
    'break'
    >>> fcode(break_, source_format='free')
    'exit'
    """


# Premade instance of BreakToken (the class defines no fields).
break_ = BreakToken()
class ContinueToken(Token):
    """ Represents 'continue' in C/Python ('cycle' in Fortran).

    Use the premade instance ``continue_`` or instantiate manually.

    Examples
    ========

    >>> from sympy import ccode, fcode
    >>> from sympy.codegen.ast import continue_
    >>> ccode(continue_)
    'continue'
    >>> fcode(continue_, source_format='free')
    'cycle'
    """


# Premade instance of ContinueToken (the class defines no fields).
continue_ = ContinueToken()
  298. class NoneToken(Token):
  299. """ The AST equivalence of Python's NoneType
  300. The corresponding instance of Python's ``None`` is ``none``.
  301. Examples
  302. ========
  303. >>> from sympy.codegen.ast import none, Variable
  304. >>> from sympy import pycode
  305. >>> print(pycode(Variable('x').as_Declaration(value=none)))
  306. x = None
  307. """
  308. def __eq__(self, other):
  309. return other is None or isinstance(other, NoneToken)
  310. def _hashable_content(self):
  311. return ()
  312. def __hash__(self):
  313. return super().__hash__()
  314. none = NoneToken()
  315. class AssignmentBase(CodegenAST):
  316. """ Abstract base class for Assignment and AugmentedAssignment.
  317. Attributes:
  318. ===========
  319. op : str
  320. Symbol for assignment operator, e.g. "=", "+=", etc.
  321. """
  322. def __new__(cls, lhs, rhs):
  323. lhs = _sympify(lhs)
  324. rhs = _sympify(rhs)
  325. cls._check_args(lhs, rhs)
  326. return super().__new__(cls, lhs, rhs)
  327. @property
  328. def lhs(self):
  329. return self.args[0]
  330. @property
  331. def rhs(self):
  332. return self.args[1]
  333. @classmethod
  334. def _check_args(cls, lhs, rhs):
  335. """ Check arguments to __new__ and raise exception if any problems found.
  336. Derived classes may wish to override this.
  337. """
  338. from sympy.matrices.expressions.matexpr import (
  339. MatrixElement, MatrixSymbol)
  340. from sympy.tensor.indexed import Indexed
  341. from sympy.tensor.array.expressions import ArrayElement
  342. # Tuple of things that can be on the lhs of an assignment
  343. assignable = (Symbol, MatrixSymbol, MatrixElement, Indexed, Element, Variable,
  344. ArrayElement)
  345. if not isinstance(lhs, assignable):
  346. raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
  347. # Indexed types implement shape, but don't define it until later. This
  348. # causes issues in assignment validation. For now, matrices are defined
  349. # as anything with a shape that is not an Indexed
  350. lhs_is_mat = hasattr(lhs, 'shape') and not isinstance(lhs, Indexed)
  351. rhs_is_mat = hasattr(rhs, 'shape') and not isinstance(rhs, Indexed)
  352. # If lhs and rhs have same structure, then this assignment is ok
  353. if lhs_is_mat:
  354. if not rhs_is_mat:
  355. raise ValueError("Cannot assign a scalar to a matrix.")
  356. elif lhs.shape != rhs.shape:
  357. raise ValueError("Dimensions of lhs and rhs do not align.")
  358. elif rhs_is_mat and not lhs_is_mat:
  359. raise ValueError("Cannot assign a matrix to a scalar.")
class Assignment(AssignmentBase):
    """
    Represents variable assignment for code generation.

    Parameters
    ==========

    lhs : Expr
        SymPy object representing the lhs of the expression. These should be
        singular objects, such as one would use in writing code. Notable types
        include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that
        subclass these types are also supported.

    rhs : Expr
        SymPy object representing the rhs of the expression. This can be any
        type, provided its shape corresponds to that of the lhs. For example,
        a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as
        the dimensions will not align.

    Raises
    ======

    TypeError
        If ``lhs`` is not of an assignable type.
    ValueError
        If the shapes of ``lhs`` and ``rhs`` are incompatible.

    Examples
    ========

    >>> from sympy import symbols, MatrixSymbol, Matrix
    >>> from sympy.codegen.ast import Assignment
    >>> x, y, z = symbols('x, y, z')
    >>> Assignment(x, y)
    Assignment(x, y)
    >>> Assignment(x, 0)
    Assignment(x, 0)
    >>> A = MatrixSymbol('A', 1, 3)
    >>> mat = Matrix([x, y, z]).T
    >>> Assignment(A, mat)
    Assignment(A, Matrix([[x, y, z]]))
    >>> Assignment(A[0, 1], x)
    Assignment(A[0, 1], x)
    """

    # Operator symbol exposed as ``op`` for a plain (non-augmented) assignment.
    op = ':='
class AugmentedAssignment(AssignmentBase):
    """
    Base class for augmented assignments.

    Attributes:
    ===========

    binop : str
       Symbol for binary operation being applied in the assignment, such as "+",
       "*", etc.
    """
    # Set by the concrete subclasses. It is None on this base class, so
    # accessing ``op`` here raises TypeError (``None + '='``).
    binop = None  # type: str

    @property
    def op(self):
        """ The augmented operator symbol: ``binop`` followed by ``'='`` (e.g. ``'+='``). """
        return self.binop + '='
class AddAugmentedAssignment(AugmentedAssignment):
    """ Represents ``lhs += rhs``. """
    binop = '+'


class SubAugmentedAssignment(AugmentedAssignment):
    """ Represents ``lhs -= rhs``. """
    binop = '-'


class MulAugmentedAssignment(AugmentedAssignment):
    """ Represents ``lhs *= rhs``. """
    binop = '*'


class DivAugmentedAssignment(AugmentedAssignment):
    """ Represents ``lhs /= rhs``. """
    binop = '/'


class ModAugmentedAssignment(AugmentedAssignment):
    """ Represents ``lhs %= rhs``. """
    binop = '%'
  415. # Mapping from binary op strings to AugmentedAssignment subclasses
  416. augassign_classes = {
  417. cls.binop: cls for cls in [
  418. AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
  419. DivAugmentedAssignment, ModAugmentedAssignment
  420. ]
  421. }
  422. def aug_assign(lhs, op, rhs):
  423. """
  424. Create 'lhs op= rhs'.
  425. Explanation
  426. ===========
  427. Represents augmented variable assignment for code generation. This is a
  428. convenience function. You can also use the AugmentedAssignment classes
  429. directly, like AddAugmentedAssignment(x, y).
  430. Parameters
  431. ==========
  432. lhs : Expr
  433. SymPy object representing the lhs of the expression. These should be
  434. singular objects, such as one would use in writing code. Notable types
  435. include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that
  436. subclass these types are also supported.
  437. op : str
  438. Operator (+, -, /, \\*, %).
  439. rhs : Expr
  440. SymPy object representing the rhs of the expression. This can be any
  441. type, provided its shape corresponds to that of the lhs. For example,
  442. a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as
  443. the dimensions will not align.
  444. Examples
  445. ========
  446. >>> from sympy import symbols
  447. >>> from sympy.codegen.ast import aug_assign
  448. >>> x, y = symbols('x, y')
  449. >>> aug_assign(x, '+', y)
  450. AddAugmentedAssignment(x, y)
  451. """
  452. if op not in augassign_classes:
  453. raise ValueError("Unrecognized operator %s" % op)
  454. return augassign_classes[op](lhs, rhs)
  455. class CodeBlock(CodegenAST):
  456. """
  457. Represents a block of code.
  458. Explanation
  459. ===========
  460. For now only assignments are supported. This restriction will be lifted in
  461. the future.
  462. Useful attributes on this object are:
  463. ``left_hand_sides``:
  464. Tuple of left-hand sides of assignments, in order.
  465. ``left_hand_sides``:
  466. Tuple of right-hand sides of assignments, in order.
  467. ``free_symbols``: Free symbols of the expressions in the right-hand sides
  468. which do not appear in the left-hand side of an assignment.
  469. Useful methods on this object are:
  470. ``topological_sort``:
  471. Class method. Return a CodeBlock with assignments
  472. sorted so that variables are assigned before they
  473. are used.
  474. ``cse``:
  475. Return a new CodeBlock with common subexpressions eliminated and
  476. pulled out as assignments.
  477. Examples
  478. ========
  479. >>> from sympy import symbols, ccode
  480. >>> from sympy.codegen.ast import CodeBlock, Assignment
  481. >>> x, y = symbols('x y')
  482. >>> c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
  483. >>> print(ccode(c))
  484. x = 1;
  485. y = x + 1;
  486. """
  487. def __new__(cls, *args):
  488. left_hand_sides = []
  489. right_hand_sides = []
  490. for i in args:
  491. if isinstance(i, Assignment):
  492. lhs, rhs = i.args
  493. left_hand_sides.append(lhs)
  494. right_hand_sides.append(rhs)
  495. obj = CodegenAST.__new__(cls, *args)
  496. obj.left_hand_sides = Tuple(*left_hand_sides)
  497. obj.right_hand_sides = Tuple(*right_hand_sides)
  498. return obj
  499. def __iter__(self):
  500. return iter(self.args)
  501. def _sympyrepr(self, printer, *args, **kwargs):
  502. il = printer._context.get('indent_level', 0)
  503. joiner = ',\n' + ' '*il
  504. joined = joiner.join(map(printer._print, self.args))
  505. return ('{}(\n'.format(' '*(il-4) + self.__class__.__name__,) +
  506. ' '*il + joined + '\n' + ' '*(il - 4) + ')')
  507. _sympystr = _sympyrepr
  508. @property
  509. def free_symbols(self):
  510. return super().free_symbols - set(self.left_hand_sides)
  511. @classmethod
  512. def topological_sort(cls, assignments):
  513. """
  514. Return a CodeBlock with topologically sorted assignments so that
  515. variables are assigned before they are used.
  516. Examples
  517. ========
  518. The existing order of assignments is preserved as much as possible.
  519. This function assumes that variables are assigned to only once.
  520. This is a class constructor so that the default constructor for
  521. CodeBlock can error when variables are used before they are assigned.
  522. >>> from sympy import symbols
  523. >>> from sympy.codegen.ast import CodeBlock, Assignment
  524. >>> x, y, z = symbols('x y z')
  525. >>> assignments = [
  526. ... Assignment(x, y + z),
  527. ... Assignment(y, z + 1),
  528. ... Assignment(z, 2),
  529. ... ]
  530. >>> CodeBlock.topological_sort(assignments)
  531. CodeBlock(
  532. Assignment(z, 2),
  533. Assignment(y, z + 1),
  534. Assignment(x, y + z)
  535. )
  536. """
  537. if not all(isinstance(i, Assignment) for i in assignments):
  538. # Will support more things later
  539. raise NotImplementedError("CodeBlock.topological_sort only supports Assignments")
  540. if any(isinstance(i, AugmentedAssignment) for i in assignments):
  541. raise NotImplementedError("CodeBlock.topological_sort does not yet work with AugmentedAssignments")
  542. # Create a graph where the nodes are assignments and there is a directed edge
  543. # between nodes that use a variable and nodes that assign that
  544. # variable, like
  545. # [(x := 1, y := x + 1), (x := 1, z := y + z), (y := x + 1, z := y + z)]
  546. # If we then topologically sort these nodes, they will be in
  547. # assignment order, like
  548. # x := 1
  549. # y := x + 1
  550. # z := y + z
  551. # A = The nodes
  552. #
  553. # enumerate keeps nodes in the same order they are already in if
  554. # possible. It will also allow us to handle duplicate assignments to
  555. # the same variable when those are implemented.
  556. A = list(enumerate(assignments))
  557. # var_map = {variable: [nodes for which this variable is assigned to]}
  558. # like {x: [(1, x := y + z), (4, x := 2 * w)], ...}
  559. var_map = defaultdict(list)
  560. for node in A:
  561. i, a = node
  562. var_map[a.lhs].append(node)
  563. # E = Edges in the graph
  564. E = []
  565. for dst_node in A:
  566. i, a = dst_node
  567. for s in a.rhs.free_symbols:
  568. for src_node in var_map[s]:
  569. E.append((src_node, dst_node))
  570. ordered_assignments = topological_sort([A, E])
  571. # De-enumerate the result
  572. return cls(*[a for i, a in ordered_assignments])
  573. def cse(self, symbols=None, optimizations=None, postprocess=None,
  574. order='canonical'):
  575. """
  576. Return a new code block with common subexpressions eliminated.
  577. Explanation
  578. ===========
  579. See the docstring of :func:`sympy.simplify.cse_main.cse` for more
  580. information.
  581. Examples
  582. ========
  583. >>> from sympy import symbols, sin
  584. >>> from sympy.codegen.ast import CodeBlock, Assignment
  585. >>> x, y, z = symbols('x y z')
  586. >>> c = CodeBlock(
  587. ... Assignment(x, 1),
  588. ... Assignment(y, sin(x) + 1),
  589. ... Assignment(z, sin(x) - 1),
  590. ... )
  591. ...
  592. >>> c.cse()
  593. CodeBlock(
  594. Assignment(x, 1),
  595. Assignment(x0, sin(x)),
  596. Assignment(y, x0 + 1),
  597. Assignment(z, x0 - 1)
  598. )
  599. """
  600. from sympy.simplify.cse_main import cse
  601. # Check that the CodeBlock only contains assignments to unique variables
  602. if not all(isinstance(i, Assignment) for i in self.args):
  603. # Will support more things later
  604. raise NotImplementedError("CodeBlock.cse only supports Assignments")
  605. if any(isinstance(i, AugmentedAssignment) for i in self.args):
  606. raise NotImplementedError("CodeBlock.cse does not yet work with AugmentedAssignments")
  607. for i, lhs in enumerate(self.left_hand_sides):
  608. if lhs in self.left_hand_sides[:i]:
  609. raise NotImplementedError("Duplicate assignments to the same "
  610. "variable are not yet supported (%s)" % lhs)
  611. # Ensure new symbols for subexpressions do not conflict with existing
  612. existing_symbols = self.atoms(Symbol)
  613. if symbols is None:
  614. symbols = numbered_symbols()
  615. symbols = filter_symbols(symbols, existing_symbols)
  616. replacements, reduced_exprs = cse(list(self.right_hand_sides),
  617. symbols=symbols, optimizations=optimizations, postprocess=postprocess,
  618. order=order)
  619. new_block = [Assignment(var, expr) for var, expr in
  620. zip(self.left_hand_sides, reduced_exprs)]
  621. new_assignments = [Assignment(var, expr) for var, expr in replacements]
  622. return self.topological_sort(new_assignments + new_block)
  623. class For(Token):
  624. """Represents a 'for-loop' in the code.
  625. Expressions are of the form:
  626. "for target in iter:
  627. body..."
  628. Parameters
  629. ==========
  630. target : symbol
  631. iter : iterable
  632. body : CodeBlock or iterable
  633. ! When passed an iterable it is used to instantiate a CodeBlock.
  634. Examples
  635. ========
  636. >>> from sympy import symbols, Range
  637. >>> from sympy.codegen.ast import aug_assign, For
  638. >>> x, i, j, k = symbols('x i j k')
  639. >>> for_i = For(i, Range(10), [aug_assign(x, '+', i*j*k)])
  640. >>> for_i # doctest: -NORMALIZE_WHITESPACE
  641. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  642. AddAugmentedAssignment(x, i*j*k)
  643. ))
  644. >>> for_ji = For(j, Range(7), [for_i])
  645. >>> for_ji # doctest: -NORMALIZE_WHITESPACE
  646. For(j, iterable=Range(0, 7, 1), body=CodeBlock(
  647. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  648. AddAugmentedAssignment(x, i*j*k)
  649. ))
  650. ))
  651. >>> for_kji =For(k, Range(5), [for_ji])
  652. >>> for_kji # doctest: -NORMALIZE_WHITESPACE
  653. For(k, iterable=Range(0, 5, 1), body=CodeBlock(
  654. For(j, iterable=Range(0, 7, 1), body=CodeBlock(
  655. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  656. AddAugmentedAssignment(x, i*j*k)
  657. ))
  658. ))
  659. ))
  660. """
  661. __slots__ = _fields = ('target', 'iterable', 'body')
  662. _construct_target = staticmethod(_sympify)
  663. @classmethod
  664. def _construct_body(cls, itr):
  665. if isinstance(itr, CodeBlock):
  666. return itr
  667. else:
  668. return CodeBlock(*itr)
  669. @classmethod
  670. def _construct_iterable(cls, itr):
  671. if not iterable(itr):
  672. raise TypeError("iterable must be an iterable")
  673. if isinstance(itr, list): # _sympify errors on lists because they are mutable
  674. itr = tuple(itr)
  675. return _sympify(itr)
  676. class String(Atom, Token):
  677. """ SymPy object representing a string.
  678. Atomic object which is not an expression (as opposed to Symbol).
  679. Parameters
  680. ==========
  681. text : str
  682. Examples
  683. ========
  684. >>> from sympy.codegen.ast import String
  685. >>> f = String('foo')
  686. >>> f
  687. foo
  688. >>> str(f)
  689. 'foo'
  690. >>> f.text
  691. 'foo'
  692. >>> print(repr(f))
  693. String('foo')
  694. """
  695. __slots__ = _fields = ('text',)
  696. not_in_args = ['text']
  697. is_Atom = True
  698. @classmethod
  699. def _construct_text(cls, text):
  700. if not isinstance(text, str):
  701. raise TypeError("Argument text is not a string type.")
  702. return text
  703. def _sympystr(self, printer, *args, **kwargs):
  704. return self.text
  705. def kwargs(self, exclude = (), apply = None):
  706. return {}
  707. #to be removed when Atom is given a suitable func
  708. @property
  709. def func(self):
  710. return lambda: self
  711. def _latex(self, printer):
  712. from sympy.printing.latex import latex_escape
  713. return r'\texttt{{"{}"}}'.format(latex_escape(self.text))
class QuotedString(String):
    """ Represents a string which should be printed with quotes. """
    # No behaviour is overridden here; only the class differs from String.


class Comment(String):
    """ Represents a comment. """
    # No behaviour is overridden here; only the class differs from String.
  718. class Node(Token):
  719. """ Subclass of Token, carrying the attribute 'attrs' (Tuple)
  720. Examples
  721. ========
  722. >>> from sympy.codegen.ast import Node, value_const, pointer_const
  723. >>> n1 = Node([value_const])
  724. >>> n1.attr_params('value_const') # get the parameters of attribute (by name)
  725. ()
  726. >>> from sympy.codegen.fnodes import dimension
  727. >>> n2 = Node([value_const, dimension(5, 3)])
  728. >>> n2.attr_params(value_const) # get the parameters of attribute (by Attribute instance)
  729. ()
  730. >>> n2.attr_params('dimension') # get the parameters of attribute (by name)
  731. (5, 3)
  732. >>> n2.attr_params(pointer_const) is None
  733. True
  734. """
  735. __slots__: tuple[str, ...] = ('attrs',)
  736. _fields = __slots__
  737. defaults: dict[str, Any] = {'attrs': Tuple()}
  738. _construct_attrs = staticmethod(_mk_Tuple)
  739. def attr_params(self, looking_for):
  740. """ Returns the parameters of the Attribute with name ``looking_for`` in self.attrs """
  741. for attr in self.attrs:
  742. if str(attr.name) == str(looking_for):
  743. return attr.parameters
  744. class Type(Token):
  745. """ Represents a type.
  746. Explanation
  747. ===========
  748. The naming is a super-set of NumPy naming. Type has a classmethod
  749. ``from_expr`` which offer type deduction. It also has a method
  750. ``cast_check`` which casts the argument to its type, possibly raising an
  751. exception if rounding error is not within tolerances, or if the value is not
  752. representable by the underlying data type (e.g. unsigned integers).
  753. Parameters
  754. ==========
  755. name : str
  756. Name of the type, e.g. ``object``, ``int16``, ``float16`` (where the latter two
  757. would use the ``Type`` sub-classes ``IntType`` and ``FloatType`` respectively).
  758. If a ``Type`` instance is given, the said instance is returned.
  759. Examples
  760. ========
  761. >>> from sympy.codegen.ast import Type
  762. >>> t = Type.from_expr(42)
  763. >>> t
  764. integer
  765. >>> print(repr(t))
  766. IntBaseType(String('integer'))
  767. >>> from sympy.codegen.ast import uint8
  768. >>> uint8.cast_check(-1) # doctest: +ELLIPSIS
  769. Traceback (most recent call last):
  770. ...
  771. ValueError: Minimum value for data type bigger than new value.
  772. >>> from sympy.codegen.ast import float32
  773. >>> v6 = 0.123456
  774. >>> float32.cast_check(v6)
  775. 0.123456
  776. >>> v10 = 12345.67894
  777. >>> float32.cast_check(v10) # doctest: +ELLIPSIS
  778. Traceback (most recent call last):
  779. ...
  780. ValueError: Casting gives a significantly different value.
  781. >>> boost_mp50 = Type('boost::multiprecision::cpp_dec_float_50')
  782. >>> from sympy import cxxcode
  783. >>> from sympy.codegen.ast import Declaration, Variable
  784. >>> cxxcode(Declaration(Variable('x', type=boost_mp50)))
  785. 'boost::multiprecision::cpp_dec_float_50 x'
  786. References
  787. ==========
  788. .. [1] https://numpy.org/doc/stable/user/basics.types.html
  789. """
  790. __slots__: tuple[str, ...] = ('name',)
  791. _fields = __slots__
  792. _construct_name = String
  793. def _sympystr(self, printer, *args, **kwargs):
  794. return str(self.name)
  795. @classmethod
  796. def from_expr(cls, expr):
  797. """ Deduces type from an expression or a ``Symbol``.
  798. Parameters
  799. ==========
  800. expr : number or SymPy object
  801. The type will be deduced from type or properties.
  802. Examples
  803. ========
  804. >>> from sympy.codegen.ast import Type, integer, complex_
  805. >>> Type.from_expr(2) == integer
  806. True
  807. >>> from sympy import Symbol
  808. >>> Type.from_expr(Symbol('z', complex=True)) == complex_
  809. True
  810. >>> Type.from_expr(sum) # doctest: +ELLIPSIS
  811. Traceback (most recent call last):
  812. ...
  813. ValueError: Could not deduce type from expr.
  814. Raises
  815. ======
  816. ValueError when type deduction fails.
  817. """
  818. if isinstance(expr, (float, Float)):
  819. return real
  820. if isinstance(expr, (int, Integer)) or getattr(expr, 'is_integer', False):
  821. return integer
  822. if getattr(expr, 'is_real', False):
  823. return real
  824. if isinstance(expr, complex) or getattr(expr, 'is_complex', False):
  825. return complex_
  826. if isinstance(expr, bool) or getattr(expr, 'is_Relational', False):
  827. return bool_
  828. else:
  829. raise ValueError("Could not deduce type from expr.")
  830. def _check(self, value):
  831. pass
  832. def cast_check(self, value, rtol=None, atol=0, precision_targets=None):
  833. """ Casts a value to the data type of the instance.
  834. Parameters
  835. ==========
  836. value : number
  837. rtol : floating point number
  838. Relative tolerance. (will be deduced if not given).
  839. atol : floating point number
  840. Absolute tolerance (in addition to ``rtol``).
  841. type_aliases : dict
  842. Maps substitutions for Type, e.g. {integer: int64, real: float32}
  843. Examples
  844. ========
  845. >>> from sympy.codegen.ast import integer, float32, int8
  846. >>> integer.cast_check(3.0) == 3
  847. True
  848. >>> float32.cast_check(1e-40) # doctest: +ELLIPSIS
  849. Traceback (most recent call last):
  850. ...
  851. ValueError: Minimum value for data type bigger than new value.
  852. >>> int8.cast_check(256) # doctest: +ELLIPSIS
  853. Traceback (most recent call last):
  854. ...
  855. ValueError: Maximum value for data type smaller than new value.
  856. >>> v10 = 12345.67894
  857. >>> float32.cast_check(v10) # doctest: +ELLIPSIS
  858. Traceback (most recent call last):
  859. ...
  860. ValueError: Casting gives a significantly different value.
  861. >>> from sympy.codegen.ast import float64
  862. >>> float64.cast_check(v10)
  863. 12345.67894
  864. >>> from sympy import Float
  865. >>> v18 = Float('0.123456789012345646')
  866. >>> float64.cast_check(v18)
  867. Traceback (most recent call last):
  868. ...
  869. ValueError: Casting gives a significantly different value.
  870. >>> from sympy.codegen.ast import float80
  871. >>> float80.cast_check(v18)
  872. 0.123456789012345649
  873. """
  874. val = sympify(value)
  875. ten = Integer(10)
  876. exp10 = getattr(self, 'decimal_dig', None)
  877. if rtol is None:
  878. rtol = 1e-15 if exp10 is None else 2.0*ten**(-exp10)
  879. def tol(num):
  880. return atol + rtol*abs(num)
  881. new_val = self.cast_nocheck(value)
  882. self._check(new_val)
  883. delta = new_val - val
  884. if abs(delta) > tol(val): # rounding, e.g. int(3.5) != 3.5
  885. raise ValueError("Casting gives a significantly different value.")
  886. return new_val
  887. def _latex(self, printer):
  888. from sympy.printing.latex import latex_escape
  889. type_name = latex_escape(self.__class__.__name__)
  890. name = latex_escape(self.name.text)
  891. return r"\text{{{}}}\left(\texttt{{{}}}\right)".format(type_name, name)
  892. class IntBaseType(Type):
  893. """ Integer base type, contains no size information. """
  894. __slots__ = ()
  895. cast_nocheck = lambda self, i: Integer(int(i))
  896. class _SizedIntType(IntBaseType):
  897. __slots__ = ('nbits',)
  898. _fields = Type._fields + __slots__
  899. _construct_nbits = Integer
  900. def _check(self, value):
  901. if value < self.min:
  902. raise ValueError("Value is too small: %d < %d" % (value, self.min))
  903. if value > self.max:
  904. raise ValueError("Value is too big: %d > %d" % (value, self.max))
  905. class SignedIntType(_SizedIntType):
  906. """ Represents a signed integer type. """
  907. __slots__ = ()
  908. @property
  909. def min(self):
  910. return -2**(self.nbits-1)
  911. @property
  912. def max(self):
  913. return 2**(self.nbits-1) - 1
  914. class UnsignedIntType(_SizedIntType):
  915. """ Represents an unsigned integer type. """
  916. __slots__ = ()
  917. @property
  918. def min(self):
  919. return 0
  920. @property
  921. def max(self):
  922. return 2**self.nbits - 1
# SymPy Integer 2: the radix used by FloatType's limit properties. Keeping it a
# SymPy Integer makes expressions such as ``two**-n`` exact rationals rather
# than Python floats.
two = Integer(2)
  924. class FloatBaseType(Type):
  925. """ Represents a floating point number type. """
  926. __slots__ = ()
  927. cast_nocheck = Float
  928. class FloatType(FloatBaseType):
  929. """ Represents a floating point type with fixed bit width.
  930. Base 2 & one sign bit is assumed.
  931. Parameters
  932. ==========
  933. name : str
  934. Name of the type.
  935. nbits : integer
  936. Number of bits used (storage).
  937. nmant : integer
  938. Number of bits used to represent the mantissa.
  939. nexp : integer
  940. Number of bits used to represent the mantissa.
  941. Examples
  942. ========
  943. >>> from sympy import S
  944. >>> from sympy.codegen.ast import FloatType
  945. >>> half_precision = FloatType('f16', nbits=16, nmant=10, nexp=5)
  946. >>> half_precision.max
  947. 65504
  948. >>> half_precision.tiny == S(2)**-14
  949. True
  950. >>> half_precision.eps == S(2)**-10
  951. True
  952. >>> half_precision.dig == 3
  953. True
  954. >>> half_precision.decimal_dig == 5
  955. True
  956. >>> half_precision.cast_check(1.0)
  957. 1.0
  958. >>> half_precision.cast_check(1e5) # doctest: +ELLIPSIS
  959. Traceback (most recent call last):
  960. ...
  961. ValueError: Maximum value for data type smaller than new value.
  962. """
  963. __slots__ = ('nbits', 'nmant', 'nexp',)
  964. _fields = Type._fields + __slots__
  965. _construct_nbits = _construct_nmant = _construct_nexp = Integer
  966. @property
  967. def max_exponent(self):
  968. """ The largest positive number n, such that 2**(n - 1) is a representable finite value. """
  969. # cf. C++'s ``std::numeric_limits::max_exponent``
  970. return two**(self.nexp - 1)
  971. @property
  972. def min_exponent(self):
  973. """ The lowest negative number n, such that 2**(n - 1) is a valid normalized number. """
  974. # cf. C++'s ``std::numeric_limits::min_exponent``
  975. return 3 - self.max_exponent
  976. @property
  977. def max(self):
  978. """ Maximum value representable. """
  979. return (1 - two**-(self.nmant+1))*two**self.max_exponent
  980. @property
  981. def tiny(self):
  982. """ The minimum positive normalized value. """
  983. # See C macros: FLT_MIN, DBL_MIN, LDBL_MIN
  984. # or C++'s ``std::numeric_limits::min``
  985. # or numpy.finfo(dtype).tiny
  986. return two**(self.min_exponent - 1)
  987. @property
  988. def eps(self):
  989. """ Difference between 1.0 and the next representable value. """
  990. return two**(-self.nmant)
  991. @property
  992. def dig(self):
  993. """ Number of decimal digits that are guaranteed to be preserved in text.
  994. When converting text -> float -> text, you are guaranteed that at least ``dig``
  995. number of digits are preserved with respect to rounding or overflow.
  996. """
  997. from sympy.functions import floor, log
  998. return floor(self.nmant * log(2)/log(10))
  999. @property
  1000. def decimal_dig(self):
  1001. """ Number of digits needed to store & load without loss.
  1002. Explanation
  1003. ===========
  1004. Number of decimal digits needed to guarantee that two consecutive conversions
  1005. (float -> text -> float) to be idempotent. This is useful when one do not want
  1006. to loose precision due to rounding errors when storing a floating point value
  1007. as text.
  1008. """
  1009. from sympy.functions import ceiling, log
  1010. return ceiling((self.nmant + 1) * log(2)/log(10) + 1)
  1011. def cast_nocheck(self, value):
  1012. """ Casts without checking if out of bounds or subnormal. """
  1013. if value == oo: # float(oo) or oo
  1014. return float(oo)
  1015. elif value == -oo: # float(-oo) or -oo
  1016. return float(-oo)
  1017. return Float(str(sympify(value).evalf(self.decimal_dig)), self.decimal_dig)
  1018. def _check(self, value):
  1019. if value < -self.max:
  1020. raise ValueError("Value is too small: %d < %d" % (value, -self.max))
  1021. if value > self.max:
  1022. raise ValueError("Value is too big: %d > %d" % (value, self.max))
  1023. if abs(value) < self.tiny:
  1024. raise ValueError("Smallest (absolute) value for data type bigger than new value.")
  1025. class ComplexBaseType(FloatBaseType):
  1026. __slots__ = ()
  1027. def cast_nocheck(self, value):
  1028. """ Casts without checking if out of bounds or subnormal. """
  1029. from sympy.functions import re, im
  1030. return (
  1031. super().cast_nocheck(re(value)) +
  1032. super().cast_nocheck(im(value))*1j
  1033. )
  1034. def _check(self, value):
  1035. from sympy.functions import re, im
  1036. super()._check(re(value))
  1037. super()._check(im(value))
  1038. class ComplexType(ComplexBaseType, FloatType):
  1039. """ Represents a complex floating point number. """
  1040. __slots__ = ()
# NumPy types:
# Size-less integer types (IntBaseType carries no ``nbits``).
# NOTE(review): presumably resolved to concrete C types by the code printers — confirm.
intc = IntBaseType('intc')
intp = IntBaseType('intp')
# Fixed-width integers; the second argument is ``nbits``.
int8 = SignedIntType('int8', 8)
int16 = SignedIntType('int16', 16)
int32 = SignedIntType('int32', 32)
int64 = SignedIntType('int64', 64)
uint8 = UnsignedIntType('uint8', 8)
uint16 = UnsignedIntType('uint16', 16)
uint32 = UnsignedIntType('uint32', 32)
uint64 = UnsignedIntType('uint64', 64)
# Binary floating point formats: ``nexp`` exponent bits, ``nmant`` stored mantissa
# bits (``FloatType.max`` works with ``nmant + 1`` significant bits).
float16 = FloatType('float16', 16, nexp=5, nmant=10)  # IEEE 754 binary16, Half precision
float32 = FloatType('float32', 32, nexp=8, nmant=23)  # IEEE 754 binary32, Single precision
float64 = FloatType('float64', 64, nexp=11, nmant=52)  # IEEE 754 binary64, Double precision
float80 = FloatType('float80', 80, nexp=15, nmant=63)  # x86 extended precision (1 integer part bit), "long double"
float128 = FloatType('float128', 128, nexp=15, nmant=112)  # IEEE 754 binary128, Quadruple precision
float256 = FloatType('float256', 256, nexp=19, nmant=236)  # IEEE 754 binary256, Octuple precision
# Complex types reuse nexp/nmant of the matching real type (all fields except
# name and nbits); nbits is the total size of both components.
complex64 = ComplexType('complex64', nbits=64, **float32.kwargs(exclude=('name', 'nbits')))
complex128 = ComplexType('complex128', nbits=128, **float64.kwargs(exclude=('name', 'nbits')))
# Generic types (precision may be chosen by code printers):
# ``untyped`` is the default ``type`` of ``Variable``.
untyped = Type('untyped')
real = FloatBaseType('real')
integer = IntBaseType('integer')
complex_ = ComplexBaseType('complex')
# Returned by ``Type.from_expr`` for relational expressions.
bool_ = Type('bool')
  1066. class Attribute(Token):
  1067. """ Attribute (possibly parametrized)
  1068. For use with :class:`sympy.codegen.ast.Node` (which takes instances of
  1069. ``Attribute`` as ``attrs``).
  1070. Parameters
  1071. ==========
  1072. name : str
  1073. parameters : Tuple
  1074. Examples
  1075. ========
  1076. >>> from sympy.codegen.ast import Attribute
  1077. >>> volatile = Attribute('volatile')
  1078. >>> volatile
  1079. volatile
  1080. >>> print(repr(volatile))
  1081. Attribute(String('volatile'))
  1082. >>> a = Attribute('foo', [1, 2, 3])
  1083. >>> a
  1084. foo(1, 2, 3)
  1085. >>> a.parameters == (1, 2, 3)
  1086. True
  1087. """
  1088. __slots__ = _fields = ('name', 'parameters')
  1089. defaults = {'parameters': Tuple()}
  1090. _construct_name = String
  1091. _construct_parameters = staticmethod(_mk_Tuple)
  1092. def _sympystr(self, printer, *args, **kwargs):
  1093. result = str(self.name)
  1094. if self.parameters:
  1095. result += '(%s)' % ', '.join((printer._print(
  1096. arg, *args, **kwargs) for arg in self.parameters))
  1097. return result
# Predefined attributes, for use in the ``attrs`` of e.g. a ``Variable``.
value_const = Attribute('value_const')
# NOTE(review): presumably marks the pointer itself (not the pointee) as const
# for ``Pointer`` variables — confirm against the C printers.
pointer_const = Attribute('pointer_const')
  1100. class Variable(Node):
  1101. """ Represents a variable.
  1102. Parameters
  1103. ==========
  1104. symbol : Symbol
  1105. type : Type (optional)
  1106. Type of the variable.
  1107. attrs : iterable of Attribute instances
  1108. Will be stored as a Tuple.
  1109. Examples
  1110. ========
  1111. >>> from sympy import Symbol
  1112. >>> from sympy.codegen.ast import Variable, float32, integer
  1113. >>> x = Symbol('x')
  1114. >>> v = Variable(x, type=float32)
  1115. >>> v.attrs
  1116. ()
  1117. >>> v == Variable('x')
  1118. False
  1119. >>> v == Variable('x', type=float32)
  1120. True
  1121. >>> v
  1122. Variable(x, type=float32)
  1123. One may also construct a ``Variable`` instance with the type deduced from
  1124. assumptions about the symbol using the ``deduced`` classmethod:
  1125. >>> i = Symbol('i', integer=True)
  1126. >>> v = Variable.deduced(i)
  1127. >>> v.type == integer
  1128. True
  1129. >>> v == Variable('i')
  1130. False
  1131. >>> from sympy.codegen.ast import value_const
  1132. >>> value_const in v.attrs
  1133. False
  1134. >>> w = Variable('w', attrs=[value_const])
  1135. >>> w
  1136. Variable(w, attrs=(value_const,))
  1137. >>> value_const in w.attrs
  1138. True
  1139. >>> w.as_Declaration(value=42)
  1140. Declaration(Variable(w, value=42, attrs=(value_const,)))
  1141. """
  1142. __slots__ = ('symbol', 'type', 'value')
  1143. _fields = __slots__ + Node._fields
  1144. defaults = Node.defaults.copy()
  1145. defaults.update({'type': untyped, 'value': none})
  1146. _construct_symbol = staticmethod(sympify)
  1147. _construct_value = staticmethod(sympify)
  1148. @classmethod
  1149. def deduced(cls, symbol, value=None, attrs=Tuple(), cast_check=True):
  1150. """ Alt. constructor with type deduction from ``Type.from_expr``.
  1151. Deduces type primarily from ``symbol``, secondarily from ``value``.
  1152. Parameters
  1153. ==========
  1154. symbol : Symbol
  1155. value : expr
  1156. (optional) value of the variable.
  1157. attrs : iterable of Attribute instances
  1158. cast_check : bool
  1159. Whether to apply ``Type.cast_check`` on ``value``.
  1160. Examples
  1161. ========
  1162. >>> from sympy import Symbol
  1163. >>> from sympy.codegen.ast import Variable, complex_
  1164. >>> n = Symbol('n', integer=True)
  1165. >>> str(Variable.deduced(n).type)
  1166. 'integer'
  1167. >>> x = Symbol('x', real=True)
  1168. >>> v = Variable.deduced(x)
  1169. >>> v.type
  1170. real
  1171. >>> z = Symbol('z', complex=True)
  1172. >>> Variable.deduced(z).type == complex_
  1173. True
  1174. """
  1175. if isinstance(symbol, Variable):
  1176. return symbol
  1177. try:
  1178. type_ = Type.from_expr(symbol)
  1179. except ValueError:
  1180. type_ = Type.from_expr(value)
  1181. if value is not None and cast_check:
  1182. value = type_.cast_check(value)
  1183. return cls(symbol, type=type_, value=value, attrs=attrs)
  1184. def as_Declaration(self, **kwargs):
  1185. """ Convenience method for creating a Declaration instance.
  1186. Explanation
  1187. ===========
  1188. If the variable of the Declaration need to wrap a modified
  1189. variable keyword arguments may be passed (overriding e.g.
  1190. the ``value`` of the Variable instance).
  1191. Examples
  1192. ========
  1193. >>> from sympy.codegen.ast import Variable, NoneToken
  1194. >>> x = Variable('x')
  1195. >>> decl1 = x.as_Declaration()
  1196. >>> # value is special NoneToken() which must be tested with == operator
  1197. >>> decl1.variable.value is None # won't work
  1198. False
  1199. >>> decl1.variable.value == None # not PEP-8 compliant
  1200. True
  1201. >>> decl1.variable.value == NoneToken() # OK
  1202. True
  1203. >>> decl2 = x.as_Declaration(value=42.0)
  1204. >>> decl2.variable.value == 42.0
  1205. True
  1206. """
  1207. kw = self.kwargs()
  1208. kw.update(kwargs)
  1209. return Declaration(self.func(**kw))
  1210. def _relation(self, rhs, op):
  1211. try:
  1212. rhs = _sympify(rhs)
  1213. except SympifyError:
  1214. raise TypeError("Invalid comparison %s < %s" % (self, rhs))
  1215. return op(self, rhs, evaluate=False)
  1216. __lt__ = lambda self, other: self._relation(other, Lt)
  1217. __le__ = lambda self, other: self._relation(other, Le)
  1218. __ge__ = lambda self, other: self._relation(other, Ge)
  1219. __gt__ = lambda self, other: self._relation(other, Gt)
  1220. class Pointer(Variable):
  1221. """ Represents a pointer. See ``Variable``.
  1222. Examples
  1223. ========
  1224. Can create instances of ``Element``:
  1225. >>> from sympy import Symbol
  1226. >>> from sympy.codegen.ast import Pointer
  1227. >>> i = Symbol('i', integer=True)
  1228. >>> p = Pointer('x')
  1229. >>> p[i+1]
  1230. Element(x, indices=(i + 1,))
  1231. """
  1232. __slots__ = ()
  1233. def __getitem__(self, key):
  1234. try:
  1235. return Element(self.symbol, key)
  1236. except TypeError:
  1237. return Element(self.symbol, (key,))
  1238. class Element(Token):
  1239. """ Element in (a possibly N-dimensional) array.
  1240. Examples
  1241. ========
  1242. >>> from sympy.codegen.ast import Element
  1243. >>> elem = Element('x', 'ijk')
  1244. >>> elem.symbol.name == 'x'
  1245. True
  1246. >>> elem.indices
  1247. (i, j, k)
  1248. >>> from sympy import ccode
  1249. >>> ccode(elem)
  1250. 'x[i][j][k]'
  1251. >>> ccode(Element('x', 'ijk', strides='lmn', offset='o'))
  1252. 'x[i*l + j*m + k*n + o]'
  1253. """
  1254. __slots__ = _fields = ('symbol', 'indices', 'strides', 'offset')
  1255. defaults = {'strides': none, 'offset': none}
  1256. _construct_symbol = staticmethod(sympify)
  1257. _construct_indices = staticmethod(lambda arg: Tuple(*arg))
  1258. _construct_strides = staticmethod(lambda arg: Tuple(*arg))
  1259. _construct_offset = staticmethod(sympify)
  1260. class Declaration(Token):
  1261. """ Represents a variable declaration
  1262. Parameters
  1263. ==========
  1264. variable : Variable
  1265. Examples
  1266. ========
  1267. >>> from sympy.codegen.ast import Declaration, NoneToken, untyped
  1268. >>> z = Declaration('z')
  1269. >>> z.variable.type == untyped
  1270. True
  1271. >>> # value is special NoneToken() which must be tested with == operator
  1272. >>> z.variable.value is None # won't work
  1273. False
  1274. >>> z.variable.value == None # not PEP-8 compliant
  1275. True
  1276. >>> z.variable.value == NoneToken() # OK
  1277. True
  1278. """
  1279. __slots__ = _fields = ('variable',)
  1280. _construct_variable = Variable
  1281. class While(Token):
  1282. """ Represents a 'for-loop' in the code.
  1283. Expressions are of the form:
  1284. "while condition:
  1285. body..."
  1286. Parameters
  1287. ==========
  1288. condition : expression convertible to Boolean
  1289. body : CodeBlock or iterable
  1290. When passed an iterable it is used to instantiate a CodeBlock.
  1291. Examples
  1292. ========
  1293. >>> from sympy import symbols, Gt, Abs
  1294. >>> from sympy.codegen import aug_assign, Assignment, While
  1295. >>> x, dx = symbols('x dx')
  1296. >>> expr = 1 - x**2
  1297. >>> whl = While(Gt(Abs(dx), 1e-9), [
  1298. ... Assignment(dx, -expr/expr.diff(x)),
  1299. ... aug_assign(x, '+', dx)
  1300. ... ])
  1301. """
  1302. __slots__ = _fields = ('condition', 'body')
  1303. _construct_condition = staticmethod(lambda cond: _sympify(cond))
  1304. @classmethod
  1305. def _construct_body(cls, itr):
  1306. if isinstance(itr, CodeBlock):
  1307. return itr
  1308. else:
  1309. return CodeBlock(*itr)
  1310. class Scope(Token):
  1311. """ Represents a scope in the code.
  1312. Parameters
  1313. ==========
  1314. body : CodeBlock or iterable
  1315. When passed an iterable it is used to instantiate a CodeBlock.
  1316. """
  1317. __slots__ = _fields = ('body',)
  1318. @classmethod
  1319. def _construct_body(cls, itr):
  1320. if isinstance(itr, CodeBlock):
  1321. return itr
  1322. else:
  1323. return CodeBlock(*itr)
  1324. class Stream(Token):
  1325. """ Represents a stream.
  1326. There are two predefined Stream instances ``stdout`` & ``stderr``.
  1327. Parameters
  1328. ==========
  1329. name : str
  1330. Examples
  1331. ========
  1332. >>> from sympy import pycode, Symbol
  1333. >>> from sympy.codegen.ast import Print, stderr, QuotedString
  1334. >>> print(pycode(Print(['x'], file=stderr)))
  1335. print(x, file=sys.stderr)
  1336. >>> x = Symbol('x')
  1337. >>> print(pycode(Print([QuotedString('x')], file=stderr))) # print literally "x"
  1338. print("x", file=sys.stderr)
  1339. """
  1340. __slots__ = _fields = ('name',)
  1341. _construct_name = String
# Predefined streams, e.g. for the ``file`` argument of ``Print``.
stdout = Stream('stdout')
stderr = Stream('stderr')
  1344. class Print(Token):
  1345. """ Represents print command in the code.
  1346. Parameters
  1347. ==========
  1348. formatstring : str
  1349. *args : Basic instances (or convertible to such through sympify)
  1350. Examples
  1351. ========
  1352. >>> from sympy.codegen.ast import Print
  1353. >>> from sympy import pycode
  1354. >>> print(pycode(Print('x y'.split(), "coordinate: %12.5g %12.5g")))
  1355. print("coordinate: %12.5g %12.5g" % (x, y))
  1356. """
  1357. __slots__ = _fields = ('print_args', 'format_string', 'file')
  1358. defaults = {'format_string': none, 'file': none}
  1359. _construct_print_args = staticmethod(_mk_Tuple)
  1360. _construct_format_string = QuotedString
  1361. _construct_file = Stream
  1362. class FunctionPrototype(Node):
  1363. """ Represents a function prototype
  1364. Allows the user to generate forward declaration in e.g. C/C++.
  1365. Parameters
  1366. ==========
  1367. return_type : Type
  1368. name : str
  1369. parameters: iterable of Variable instances
  1370. attrs : iterable of Attribute instances
  1371. Examples
  1372. ========
  1373. >>> from sympy import ccode, symbols
  1374. >>> from sympy.codegen.ast import real, FunctionPrototype
  1375. >>> x, y = symbols('x y', real=True)
  1376. >>> fp = FunctionPrototype(real, 'foo', [x, y])
  1377. >>> ccode(fp)
  1378. 'double foo(double x, double y)'
  1379. """
  1380. __slots__ = ('return_type', 'name', 'parameters')
  1381. _fields: tuple[str, ...] = __slots__ + Node._fields
  1382. _construct_return_type = Type
  1383. _construct_name = String
  1384. @staticmethod
  1385. def _construct_parameters(args):
  1386. def _var(arg):
  1387. if isinstance(arg, Declaration):
  1388. return arg.variable
  1389. elif isinstance(arg, Variable):
  1390. return arg
  1391. else:
  1392. return Variable.deduced(arg)
  1393. return Tuple(*map(_var, args))
  1394. @classmethod
  1395. def from_FunctionDefinition(cls, func_def):
  1396. if not isinstance(func_def, FunctionDefinition):
  1397. raise TypeError("func_def is not an instance of FunctionDefinition")
  1398. return cls(**func_def.kwargs(exclude=('body',)))
  1399. class FunctionDefinition(FunctionPrototype):
  1400. """ Represents a function definition in the code.
  1401. Parameters
  1402. ==========
  1403. return_type : Type
  1404. name : str
  1405. parameters: iterable of Variable instances
  1406. body : CodeBlock or iterable
  1407. attrs : iterable of Attribute instances
  1408. Examples
  1409. ========
  1410. >>> from sympy import ccode, symbols
  1411. >>> from sympy.codegen.ast import real, FunctionPrototype
  1412. >>> x, y = symbols('x y', real=True)
  1413. >>> fp = FunctionPrototype(real, 'foo', [x, y])
  1414. >>> ccode(fp)
  1415. 'double foo(double x, double y)'
  1416. >>> from sympy.codegen.ast import FunctionDefinition, Return
  1417. >>> body = [Return(x*y)]
  1418. >>> fd = FunctionDefinition.from_FunctionPrototype(fp, body)
  1419. >>> print(ccode(fd))
  1420. double foo(double x, double y){
  1421. return x*y;
  1422. }
  1423. """
  1424. __slots__ = ('body', )
  1425. _fields = FunctionPrototype._fields[:-1] + __slots__ + Node._fields
  1426. @classmethod
  1427. def _construct_body(cls, itr):
  1428. if isinstance(itr, CodeBlock):
  1429. return itr
  1430. else:
  1431. return CodeBlock(*itr)
  1432. @classmethod
  1433. def from_FunctionPrototype(cls, func_proto, body):
  1434. if not isinstance(func_proto, FunctionPrototype):
  1435. raise TypeError("func_proto is not an instance of FunctionPrototype")
  1436. return cls(body=body, **func_proto.kwargs())
  1437. class Return(Token):
  1438. """ Represents a return command in the code.
  1439. Parameters
  1440. ==========
  1441. return : Basic
  1442. Examples
  1443. ========
  1444. >>> from sympy.codegen.ast import Return
  1445. >>> from sympy.printing.pycode import pycode
  1446. >>> from sympy import Symbol
  1447. >>> x = Symbol('x')
  1448. >>> print(pycode(Return(x)))
  1449. return x
  1450. """
  1451. __slots__ = _fields = ('return',)
  1452. _construct_return=staticmethod(_sympify)
  1453. class FunctionCall(Token, Expr):
  1454. """ Represents a call to a function in the code.
  1455. Parameters
  1456. ==========
  1457. name : str
  1458. function_args : Tuple
  1459. Examples
  1460. ========
  1461. >>> from sympy.codegen.ast import FunctionCall
  1462. >>> from sympy import pycode
  1463. >>> fcall = FunctionCall('foo', 'bar baz'.split())
  1464. >>> print(pycode(fcall))
  1465. foo(bar, baz)
  1466. """
  1467. __slots__ = _fields = ('name', 'function_args')
  1468. _construct_name = String
  1469. _construct_function_args = staticmethod(lambda args: Tuple(*args))