aesaracode.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. from __future__ import annotations
  2. from typing import Any
  3. from sympy.external import import_module
  4. from sympy.printing.printer import Printer
  5. from sympy.utilities.iterables import is_sequence
  6. import sympy
  7. from functools import partial
  aesara = import_module('aesara')

  if aesara:
      aes = aesara.scalar
      aet = aesara.tensor
      from aesara.tensor import nlinalg
      from aesara.tensor.elemwise import DimShuffle, Elemwise

      # Aesara 2.8.11 (released 2023) renamed ``true_div`` to ``true_divide``
      # to match NumPy; fall back to the old name on older releases.
      # XXX: Drop the fallback once older Aesara versions are unsupported.
      true_divide = getattr(aet, 'true_divide', None)
      if true_divide is None:
          true_divide = aet.true_div

      # Lookup table used by AesaraPrinter._print_Basic: maps a SymPy class to
      # the Aesara callable applied to the printed arguments of an expression.
      mapping = {
          # Arithmetic and elementary functions
          sympy.Add: aet.add,
          sympy.Mul: aet.mul,
          sympy.Abs: aet.abs,
          sympy.sign: aet.sgn,
          sympy.ceiling: aet.ceil,
          sympy.floor: aet.floor,
          sympy.log: aet.log,
          sympy.exp: aet.exp,
          sympy.sqrt: aet.sqrt,
          # Trigonometric and hyperbolic functions
          sympy.cos: aet.cos,
          sympy.acos: aet.arccos,
          sympy.sin: aet.sin,
          sympy.asin: aet.arcsin,
          sympy.tan: aet.tan,
          sympy.atan: aet.arctan,
          sympy.atan2: aet.arctan2,
          sympy.cosh: aet.cosh,
          sympy.acosh: aet.arccosh,
          sympy.sinh: aet.sinh,
          sympy.asinh: aet.arcsinh,
          sympy.tanh: aet.tanh,
          sympy.atanh: aet.arctanh,
          # Complex parts
          sympy.re: aet.real,
          sympy.im: aet.imag,
          sympy.arg: aet.angle,
          # Special functions
          sympy.erf: aet.erf,
          sympy.gamma: aet.gamma,
          sympy.loggamma: aet.gammaln,
          sympy.Pow: aet.pow,
          # Relations
          sympy.Eq: aet.eq,
          sympy.StrictGreaterThan: aet.gt,
          sympy.StrictLessThan: aet.lt,
          sympy.LessThan: aet.le,
          sympy.GreaterThan: aet.ge,
          # Logic, translated to *bitwise* Aesara operations
          sympy.And: aet.bitwise_and,
          sympy.Or: aet.bitwise_or,
          sympy.Not: aet.invert,
          sympy.Xor: aet.bitwise_xor,
          # SymPy accepts more than two arguments here; Aesara only two
          sympy.Max: aet.maximum,
          sympy.Min: aet.minimum,
          sympy.conjugate: aet.conj,
          # Nullary: printed with no children, yields the constant 0 + 1j
          sympy.core.numbers.ImaginaryUnit: lambda: aet.complex(0, 1),
          # Matrices
          sympy.MatAdd: Elemwise(aes.add),
          sympy.HadamardProduct: Elemwise(aes.mul),
          sympy.Trace: nlinalg.trace,
          sympy.Determinant: nlinalg.det,
          sympy.Inverse: nlinalg.matrix_inverse,
          sympy.Transpose: DimShuffle((False, False), [1, 0]),
      }
  class AesaraPrinter(Printer):
      """ Code printer which creates Aesara symbolic expression graphs.

      Parameters
      ==========

      cache : dict
          Cache dictionary to use. If omitted or None, a new empty dictionary
          is created, so the printer neither reads nor alters state shared
          with anything else. Note: a dictionary that is passed in is not
          copied on initialization of the printer and will be updated
          in-place, so using the same dict object when creating multiple
          printers or making multiple calls to :func:`.aesara_code` or
          :func:`.aesara_function` means the cache is shared between all
          these applications.

      Attributes
      ==========

      cache : dict
          A cache of Aesara variables which have been created for SymPy
          symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
          :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
          ensure that all references to a given symbol in an expression (or
          multiple expressions) are printed as the same Aesara variable, which
          is created only once. Symbols are distinguished by name, type,
          arguments, dtype and broadcastable pattern. The format of the
          cache's contents should be considered opaque to the user.
      """
      printmethod = "_aesara"

      def __init__(self, *args, **kwargs):
          # Treat an explicit ``cache=None`` like an omitted argument; storing
          # None would make every later cache lookup raise TypeError.
          cache = kwargs.pop('cache', None)
          self.cache = {} if cache is None else cache
          super().__init__(*args, **kwargs)

      def _get_key(self, s, name=None, dtype=None, broadcastable=None):
          """ Get the cache key for a SymPy object.

          Parameters
          ==========

          s : sympy.core.basic.Basic
              SymPy object to get key for.

          name : str
              Name of object, if it does not have a ``name`` attribute.

          dtype, broadcastable
              Included verbatim in the key so that variables requested with
              different dtypes or broadcast patterns are cached separately.
          """
          if name is None:
              name = s.name

          return (name, type(s), s.args, dtype, broadcastable)

      def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
          """
          Get the Aesara variable for a SymPy symbol from the cache, or create
          it if it does not exist.

          Defaults: ``name`` is ``s.name``, ``dtype`` is ``'floatX'`` and
          ``broadcastable`` is ``()`` (a scalar).
          """
          # Defaults
          if name is None:
              name = s.name
          if dtype is None:
              dtype = 'floatX'
          if broadcastable is None:
              broadcastable = ()

          key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)

          if key in self.cache:
              return self.cache[key]

          value = aet.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
          self.cache[key] = value
          return value

      def _print_Symbol(self, s, **kwargs):
          dtype = kwargs.get('dtypes', {}).get(s)
          bc = kwargs.get('broadcastables', {}).get(s)
          return self._get_or_create(s, dtype=dtype, broadcastable=bc)

      def _print_AppliedUndef(self, s, **kwargs):
          # Undefined functions have no ``name`` attribute; build one from the
          # function and its first argument.
          name = str(type(s)) + '_' + str(s.args[0])
          dtype = kwargs.get('dtypes', {}).get(s)
          bc = kwargs.get('broadcastables', {}).get(s)
          return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)

      def _print_Basic(self, expr, **kwargs):
          op = mapping[type(expr)]
          children = [self._print(arg, **kwargs) for arg in expr.args]
          # SymPy's Max/Min accept any number of arguments, but the Aesara ops
          # they map to take exactly two, so fold the arguments pairwise.
          if type(expr) in (sympy.Max, sympy.Min) and len(children) > 2:
              from functools import reduce
              return reduce(op, children)
          return op(*children)

      def _print_Number(self, n, **kwargs):
          # Integers are handled by _print_Integer; interpret as float
          return float(n.evalf())

      def _print_MatrixSymbol(self, X, **kwargs):
          dtype = kwargs.get('dtypes', {}).get(X)
          return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))

      def _print_DenseMatrix(self, X, **kwargs):
          if not hasattr(aet, 'stacklists'):
              raise NotImplementedError(
                 "Matrix translation not yet supported in this version of Aesara")

          return aet.stacklists([
              [self._print(arg, **kwargs) for arg in L]
              for L in X.tolist()
          ])

      _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix

      def _print_MatMul(self, expr, **kwargs):
          children = [self._print(arg, **kwargs) for arg in expr.args]
          result = children[0]
          for child in children[1:]:
              result = aet.dot(result, child)
          return result

      def _print_MatPow(self, expr, **kwargs):
          base, exponent = [self._print(arg, **kwargs) for arg in expr.args]
          # Integer exponents print as plain Python ints (_print_Integer);
          # only positive ones can be unrolled into repeated matrix products.
          if not (isinstance(exponent, int) and exponent > 0):
              raise NotImplementedError(
                  'Only positive integer powers of matrices can be handled '
                  'by Aesara at the moment')
          result = base
          for _ in range(exponent - 1):
              result = aet.dot(result, base)
          return result

      def _print_MatrixSlice(self, expr, **kwargs):
          parent = self._print(expr.parent, **kwargs)
          rowslice = self._print(slice(*expr.rowslice), **kwargs)
          colslice = self._print(slice(*expr.colslice), **kwargs)
          return parent[rowslice, colslice]

      def _print_BlockMatrix(self, expr, **kwargs):
          nrows, ncols = expr.blocks.shape
          blocks = [[self._print(expr.blocks[r, c], **kwargs)
                          for c in range(ncols)]
                          for r in range(nrows)]
          # Join each block row horizontally, then stack the rows vertically.
          return aet.join(0, *[aet.join(1, *row) for row in blocks])

      def _print_slice(self, expr, **kwargs):
          # Only SymPy objects need printing; None and plain ints pass through.
          return slice(*[self._print(i, **kwargs)
                          if isinstance(i, sympy.Basic) else i
                          for i in (expr.start, expr.stop, expr.step)])

      def _print_Pi(self, expr, **kwargs):
          return 3.141592653589793

      def _print_Piecewise(self, expr, **kwargs):
          import numpy as np
          e, cond = expr.args[0].args  # First condition and corresponding value

          # Print conditional expression and value for first condition
          p_cond = self._print(cond, **kwargs)
          p_e = self._print(e, **kwargs)

          # One condition only
          if len(expr.args) == 1:
              # Return value if condition else NaN
              return aet.switch(p_cond, p_e, np.nan)

          # Return value_1 if condition_1 else evaluate remaining conditions
          p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
          return aet.switch(p_cond, p_e, p_remaining)

      def _print_Rational(self, expr, **kwargs):
          return true_divide(self._print(expr.p, **kwargs),
                             self._print(expr.q, **kwargs))

      def _print_Integer(self, expr, **kwargs):
          return expr.p

      def _print_factorial(self, expr, **kwargs):
          return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)

      def _print_Derivative(self, deriv, **kwargs):
          from aesara.gradient import Rop

          rv = self._print(deriv.expr, **kwargs)
          for var in deriv.variables:
              var = self._print(var, **kwargs)
              rv = Rop(rv, var, aet.ones_like(var))
          return rv

      def emptyPrinter(self, expr):
          return expr

      def doprint(self, expr, dtypes=None, broadcastables=None):
          """ Convert a SymPy expression to an Aesara graph variable.

          The ``dtypes`` and ``broadcastables`` arguments are used to specify
          the data type, dimension, and broadcasting behavior of the Aesara
          variables corresponding to the free symbols in ``expr``. Each is a
          mapping from SymPy symbols to the value of the corresponding
          argument to ``aesara.tensor.var.TensorVariable``.

          See the corresponding `documentation page`__ for more information on
          broadcasting in Aesara.

          .. __: https://aesara.readthedocs.io/en/latest/tutorial/broadcasting.html

          Parameters
          ==========

          expr : sympy.core.expr.Expr
              SymPy expression to print.

          dtypes : dict
              Mapping from SymPy symbols to Aesara datatypes to use when
              creating new Aesara variables for those symbols. Corresponds to
              the ``dtype`` argument to ``aesara.tensor.var.TensorVariable``.
              Defaults to ``'floatX'`` for symbols not included in the mapping.

          broadcastables : dict
              Mapping from SymPy symbols to the value of the ``broadcastable``
              argument to ``aesara.tensor.var.TensorVariable`` to use when
              creating Aesara variables for those symbols. Defaults to the
              empty tuple for symbols not included in the mapping (resulting
              in a scalar).

          Returns
          =======

          aesara.graph.basic.Variable
              A variable corresponding to the expression's value in an Aesara
              symbolic expression graph.
          """
          if dtypes is None:
              dtypes = {}
          if broadcastables is None:
              broadcastables = {}

          return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  global_cache: dict[Any, Any] = {}


  def aesara_code(expr, cache=None, **kwargs):
      """
      Convert a SymPy expression into an Aesara graph variable.

      Parameters
      ==========

      expr : sympy.core.expr.Expr
          SymPy expression object to convert.

      cache : dict
          Cached Aesara variables (see :class:`AesaraPrinter.cache
          <AesaraPrinter>`). Defaults to the module-level global cache.

      dtypes : dict
          Passed to :meth:`.AesaraPrinter.doprint`.

      broadcastables : dict
          Passed to :meth:`.AesaraPrinter.doprint`.

      Returns
      =======

      aesara.graph.basic.Variable
          A variable corresponding to the expression's value in an Aesara
          symbolic expression graph.

      """
      if not aesara:
          raise ImportError("aesara is required for aesara_code")

      # Without an explicit cache, share ``global_cache`` so that repeated
      # calls map the same symbol to the same Aesara variable.
      shared = global_cache if cache is None else cache
      printer = AesaraPrinter(cache=shared, settings={})
      return printer.doprint(expr, **kwargs)
  def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
      r"""
      Get value of ``broadcastables`` argument to :func:`.aesara_code` from
      keyword arguments to :func:`.aesara_function`.

      Included for backwards compatibility.

      Parameters
      ==========

      inputs
          Sequence of input symbols.

      dim : int
          Common number of dimensions for all inputs. Overrides other
          arguments if given.

      dims : dict
          Mapping from input symbols to number of dimensions. Overrides
          ``broadcastables`` argument if given. Symbols with fewer dimensions
          than the largest entry are padded with trailing broadcastable
          (``True``) dimensions. An empty mapping yields an empty dictionary.

      broadcastables : dict
          Explicit value of ``broadcastables`` argument to
          :meth:`.AesaraPrinter.doprint`. If not None function will return
          this value unchanged.

      Returns
      =======
      dict
          Dictionary mapping elements of ``inputs`` to their "broadcastable"
          values (tuple of ``bool``\ s).
      """
      if dim is not None:
          return {s: (False,) * dim for s in inputs}

      if dims is not None:
          # ``default=0`` keeps an empty mapping from raising ValueError.
          maxdim = max(dims.values(), default=0)
          return {
              s: (False,) * d + (True,) * (maxdim - d)
              for s, d in dims.items()
          }

      if broadcastables is not None:
          return broadcastables

      return {}
  def aesara_function(inputs, outputs, scalar=False, *,
          dim=None, dims=None, broadcastables=None, **kwargs):
      """
      Create an Aesara function from SymPy expressions.

      The inputs and outputs are converted to Aesara variables using
      :func:`.aesara_code` and then passed to ``aesara.function``.

      Parameters
      ==========

      inputs
          Sequence of symbols which constitute the inputs of the function.

      outputs
          Sequence of expressions which constitute the output(s) of the
          function, or a single SymPy expression. The free symbols of each
          expression must be a subset of ``inputs``.

      scalar : bool
          Convert 0-dimensional arrays in output to scalars. This will return
          a Python wrapper function around the Aesara function object.

      cache : dict
          Cached Aesara variables (see :class:`AesaraPrinter.cache
          <AesaraPrinter>`). Defaults to a new empty dictionary, so variables
          are not shared with other calls unless a cache is passed explicitly.

      dtypes : dict
          Passed to :meth:`.AesaraPrinter.doprint`.

      broadcastables : dict
          Passed to :meth:`.AesaraPrinter.doprint`.

      dims : dict
          Alternative to ``broadcastables`` argument. Mapping from elements of
          ``inputs`` to integers indicating the dimension of their associated
          arrays/tensors. Overrides ``broadcastables`` argument if given.

      dim : int
          Another alternative to the ``broadcastables`` argument. Common
          number of dimensions to use for all arrays/tensors.
          ``aesara_function([x, y], [...], dim=2)`` is equivalent to using
          ``broadcastables={x: (False, False), y: (False, False)}``.

      Any remaining keyword arguments are forwarded to ``aesara.function``.

      Returns
      =======
      callable
          A callable object which takes values of ``inputs`` as positional
          arguments and returns an output array for each of the expressions
          in ``outputs``. If ``outputs`` is a single expression, or a sequence
          containing exactly one expression, the function returns a single
          array; otherwise it returns a list of arrays.

          The returned object will either be an instance of
          ``aesara.compile.function.types.Function`` or a Python wrapper
          function around one. In both cases, the returned value will have a
          ``aesara_function`` attribute which points to the return value of
          ``aesara.function``.

      Examples
      ========

      >>> from sympy.abc import x, y, z
      >>> from sympy.printing.aesaracode import aesara_function

      A simple function with one input and one output:

      >>> f1 = aesara_function([x], [x**2 - 1], scalar=True)
      >>> f1(3)
      8.0

      A function with multiple inputs and one output:

      >>> f2 = aesara_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
      >>> f2(3, 4, 2)
      5.0

      A function with multiple inputs and multiple outputs:

      >>> f3 = aesara_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
      >>> f3(2, 3)
      [13.0, -5.0]

      See also
      ========

      dim_handling

      """
      if not aesara:
          raise ImportError("Aesara is required for aesara_function")

      # A bare SymPy expression is treated as a one-element list of outputs;
      # previously ``map`` below would fail on it.
      if isinstance(outputs, sympy.Basic) and not is_sequence(outputs):
          outputs = [outputs]

      # Pop off non-aesara keyword args
      cache = kwargs.pop('cache', {})
      dtypes = kwargs.pop('dtypes', {})

      broadcastables = dim_handling(
          inputs, dim=dim, dims=dims, broadcastables=broadcastables,
      )

      # Print inputs/outputs
      code = partial(aesara_code, cache=cache, dtypes=dtypes,
                     broadcastables=broadcastables)
      tinputs = list(map(code, inputs))
      toutputs = list(map(code, outputs))

      # Constant expressions print as plain Python numbers; wrap them so
      # aesara.function receives graph variables.
      toutputs = [output if isinstance(output, aesara.graph.basic.Variable)
                  else aet.as_tensor_variable(output)
                  for output in toutputs]

      if len(toutputs) == 1:
          toutputs = toutputs[0]

      # Compile aesara func
      func = aesara.function(tinputs, toutputs, **kwargs)

      is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]

      # No wrapper required
      if not scalar or not any(is_0d):
          func.aesara_function = func
          return func

      # Create wrapper to convert 0-dimensional outputs to scalars
      def wrapper(*args):
          out = func(*args)
          # out can be array(1.0) or [array(1.0), array(2.0)]

          if is_sequence(out):
              return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
          else:
              return out[()]

      wrapper.__wrapped__ = func
      wrapper.__doc__ = func.__doc__
      wrapper.aesara_function = func
      return wrapper