theanocode.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. """
  2. .. deprecated:: 1.8
  3. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  4. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  5. :ref:`theanocode-deprecated` for more information.
  6. """
  7. from __future__ import annotations
  8. from typing import Any
  9. from sympy.external import import_module
  10. from sympy.printing.printer import Printer
  11. from sympy.utilities.iterables import is_sequence
  12. import sympy
  13. from functools import partial
  14. from sympy.utilities.decorator import doctest_depends_on
  15. from sympy.utilities.exceptions import sympy_deprecation_warning
# Theano is an optional dependency. ``import_module`` is expected to give a
# falsy value when it is unavailable (the public entry points below check
# ``if not theano`` and raise ImportError); in that case ``ts``, ``tt``,
# ``tlinalg`` and ``mapping`` are never defined.
theano = import_module('theano')

if theano:
    ts = theano.scalar
    tt = theano.tensor
    from theano.sandbox import linalg as tlinalg

    # Maps SymPy expression classes to the Theano op (or callable) producing
    # the equivalent graph node. TheanoPrinter._print_Basic looks up
    # ``type(expr)`` in this table and calls the value with the printed
    # arguments of the expression.
    mapping = {
        sympy.Add: tt.add,
        sympy.Mul: tt.mul,
        sympy.Abs: tt.abs_,
        sympy.sign: tt.sgn,
        sympy.ceiling: tt.ceil,
        sympy.floor: tt.floor,
        sympy.log: tt.log,
        sympy.exp: tt.exp,
        sympy.sqrt: tt.sqrt,
        sympy.cos: tt.cos,
        sympy.acos: tt.arccos,
        sympy.sin: tt.sin,
        sympy.asin: tt.arcsin,
        sympy.tan: tt.tan,
        sympy.atan: tt.arctan,
        sympy.atan2: tt.arctan2,
        sympy.cosh: tt.cosh,
        sympy.acosh: tt.arccosh,
        sympy.sinh: tt.sinh,
        sympy.asinh: tt.arcsinh,
        sympy.tanh: tt.tanh,
        sympy.atanh: tt.arctanh,
        sympy.re: tt.real,
        sympy.im: tt.imag,
        sympy.arg: tt.angle,
        sympy.erf: tt.erf,
        sympy.gamma: tt.gamma,
        sympy.loggamma: tt.gammaln,
        sympy.Pow: tt.pow,
        sympy.Eq: tt.eq,
        sympy.StrictGreaterThan: tt.gt,
        sympy.StrictLessThan: tt.lt,
        sympy.LessThan: tt.le,
        sympy.GreaterThan: tt.ge,
        sympy.And: tt.and_,
        sympy.Or: tt.or_,
        # NOTE: SymPy's Max/Min accept more than two arguments, but these
        # Theano ops take exactly two, so Max/Min with >2 args will fail here.
        sympy.Max: tt.maximum, # SymPy accept >2 inputs, Theano only 2
        sympy.Min: tt.minimum, # SymPy accept >2 inputs, Theano only 2
        sympy.conjugate: tt.conj,
        # Called with no arguments (ImaginaryUnit has no args), producing the
        # complex constant 0 + 1j.
        sympy.core.numbers.ImaginaryUnit: lambda:tt.complex(0,1),
        # Matrices
        sympy.MatAdd: tt.Elemwise(ts.add),
        sympy.HadamardProduct: tt.Elemwise(ts.mul),
        sympy.Trace: tlinalg.trace,
        sympy.Determinant : tlinalg.det,
        sympy.Inverse: tlinalg.matrix_inverse,
        # Appears to swap the two axes of a 2-D, non-broadcastable input.
        sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
    }
  70. class TheanoPrinter(Printer):
  71. """ Code printer which creates Theano symbolic expression graphs.
  72. Parameters
  73. ==========
  74. cache : dict
  75. Cache dictionary to use. If None (default) will use
  76. the global cache. To create a printer which does not depend on or alter
  77. global state pass an empty dictionary. Note: the dictionary is not
  78. copied on initialization of the printer and will be updated in-place,
  79. so using the same dict object when creating multiple printers or making
  80. multiple calls to :func:`.theano_code` or :func:`.theano_function` means
  81. the cache is shared between all these applications.
  82. Attributes
  83. ==========
  84. cache : dict
  85. A cache of Theano variables which have been created for SymPy
  86. symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
  87. :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
  88. ensure that all references to a given symbol in an expression (or
  89. multiple expressions) are printed as the same Theano variable, which is
  90. created only once. Symbols are differentiated only by name and type. The
  91. format of the cache's contents should be considered opaque to the user.
  92. """
  93. printmethod = "_theano"
  94. def __init__(self, *args, **kwargs):
  95. self.cache = kwargs.pop('cache', {})
  96. super().__init__(*args, **kwargs)
  97. def _get_key(self, s, name=None, dtype=None, broadcastable=None):
  98. """ Get the cache key for a SymPy object.
  99. Parameters
  100. ==========
  101. s : sympy.core.basic.Basic
  102. SymPy object to get key for.
  103. name : str
  104. Name of object, if it does not have a ``name`` attribute.
  105. """
  106. if name is None:
  107. name = s.name
  108. return (name, type(s), s.args, dtype, broadcastable)
  109. def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
  110. """
  111. Get the Theano variable for a SymPy symbol from the cache, or create it
  112. if it does not exist.
  113. """
  114. # Defaults
  115. if name is None:
  116. name = s.name
  117. if dtype is None:
  118. dtype = 'floatX'
  119. if broadcastable is None:
  120. broadcastable = ()
  121. key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
  122. if key in self.cache:
  123. return self.cache[key]
  124. value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
  125. self.cache[key] = value
  126. return value
  127. def _print_Symbol(self, s, **kwargs):
  128. dtype = kwargs.get('dtypes', {}).get(s)
  129. bc = kwargs.get('broadcastables', {}).get(s)
  130. return self._get_or_create(s, dtype=dtype, broadcastable=bc)
  131. def _print_AppliedUndef(self, s, **kwargs):
  132. name = str(type(s)) + '_' + str(s.args[0])
  133. dtype = kwargs.get('dtypes', {}).get(s)
  134. bc = kwargs.get('broadcastables', {}).get(s)
  135. return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
  136. def _print_Basic(self, expr, **kwargs):
  137. op = mapping[type(expr)]
  138. children = [self._print(arg, **kwargs) for arg in expr.args]
  139. return op(*children)
  140. def _print_Number(self, n, **kwargs):
  141. # Integers already taken care of below, interpret as float
  142. return float(n.evalf())
  143. def _print_MatrixSymbol(self, X, **kwargs):
  144. dtype = kwargs.get('dtypes', {}).get(X)
  145. return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
  146. def _print_DenseMatrix(self, X, **kwargs):
  147. if not hasattr(tt, 'stacklists'):
  148. raise NotImplementedError(
  149. "Matrix translation not yet supported in this version of Theano")
  150. return tt.stacklists([
  151. [self._print(arg, **kwargs) for arg in L]
  152. for L in X.tolist()
  153. ])
  154. _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
  155. def _print_MatMul(self, expr, **kwargs):
  156. children = [self._print(arg, **kwargs) for arg in expr.args]
  157. result = children[0]
  158. for child in children[1:]:
  159. result = tt.dot(result, child)
  160. return result
  161. def _print_MatPow(self, expr, **kwargs):
  162. children = [self._print(arg, **kwargs) for arg in expr.args]
  163. result = 1
  164. if isinstance(children[1], int) and children[1] > 0:
  165. for i in range(children[1]):
  166. result = tt.dot(result, children[0])
  167. else:
  168. raise NotImplementedError('''Only non-negative integer
  169. powers of matrices can be handled by Theano at the moment''')
  170. return result
  171. def _print_MatrixSlice(self, expr, **kwargs):
  172. parent = self._print(expr.parent, **kwargs)
  173. rowslice = self._print(slice(*expr.rowslice), **kwargs)
  174. colslice = self._print(slice(*expr.colslice), **kwargs)
  175. return parent[rowslice, colslice]
  176. def _print_BlockMatrix(self, expr, **kwargs):
  177. nrows, ncols = expr.blocks.shape
  178. blocks = [[self._print(expr.blocks[r, c], **kwargs)
  179. for c in range(ncols)]
  180. for r in range(nrows)]
  181. return tt.join(0, *[tt.join(1, *row) for row in blocks])
  182. def _print_slice(self, expr, **kwargs):
  183. return slice(*[self._print(i, **kwargs)
  184. if isinstance(i, sympy.Basic) else i
  185. for i in (expr.start, expr.stop, expr.step)])
  186. def _print_Pi(self, expr, **kwargs):
  187. return 3.141592653589793
  188. def _print_Exp1(self, expr, **kwargs):
  189. return ts.exp(1)
  190. def _print_Piecewise(self, expr, **kwargs):
  191. import numpy as np
  192. e, cond = expr.args[0].args # First condition and corresponding value
  193. # Print conditional expression and value for first condition
  194. p_cond = self._print(cond, **kwargs)
  195. p_e = self._print(e, **kwargs)
  196. # One condition only
  197. if len(expr.args) == 1:
  198. # Return value if condition else NaN
  199. return tt.switch(p_cond, p_e, np.nan)
  200. # Return value_1 if condition_1 else evaluate remaining conditions
  201. p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
  202. return tt.switch(p_cond, p_e, p_remaining)
  203. def _print_Rational(self, expr, **kwargs):
  204. return tt.true_div(self._print(expr.p, **kwargs),
  205. self._print(expr.q, **kwargs))
  206. def _print_Integer(self, expr, **kwargs):
  207. return expr.p
  208. def _print_factorial(self, expr, **kwargs):
  209. return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
  210. def _print_Derivative(self, deriv, **kwargs):
  211. rv = self._print(deriv.expr, **kwargs)
  212. for var in deriv.variables:
  213. var = self._print(var, **kwargs)
  214. rv = tt.Rop(rv, var, tt.ones_like(var))
  215. return rv
  216. def emptyPrinter(self, expr):
  217. return expr
  218. def doprint(self, expr, dtypes=None, broadcastables=None):
  219. """ Convert a SymPy expression to a Theano graph variable.
  220. The ``dtypes`` and ``broadcastables`` arguments are used to specify the
  221. data type, dimension, and broadcasting behavior of the Theano variables
  222. corresponding to the free symbols in ``expr``. Each is a mapping from
  223. SymPy symbols to the value of the corresponding argument to
  224. ``theano.tensor.Tensor``.
  225. See the corresponding `documentation page`__ for more information on
  226. broadcasting in Theano.
  227. .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html
  228. Parameters
  229. ==========
  230. expr : sympy.core.expr.Expr
  231. SymPy expression to print.
  232. dtypes : dict
  233. Mapping from SymPy symbols to Theano datatypes to use when creating
  234. new Theano variables for those symbols. Corresponds to the ``dtype``
  235. argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
  236. for symbols not included in the mapping.
  237. broadcastables : dict
  238. Mapping from SymPy symbols to the value of the ``broadcastable``
  239. argument to ``theano.tensor.Tensor`` to use when creating Theano
  240. variables for those symbols. Defaults to the empty tuple for symbols
  241. not included in the mapping (resulting in a scalar).
  242. Returns
  243. =======
  244. theano.gof.graph.Variable
  245. A variable corresponding to the expression's value in a Theano
  246. symbolic expression graph.
  247. """
  248. if dtypes is None:
  249. dtypes = {}
  250. if broadcastables is None:
  251. broadcastables = {}
  252. return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  253. global_cache: dict[Any, Any] = {}
  254. def theano_code(expr, cache=None, **kwargs):
  255. """
  256. Convert a SymPy expression into a Theano graph variable.
  257. .. deprecated:: 1.8
  258. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  259. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  260. :ref:`theanocode-deprecated` for more information.
  261. Parameters
  262. ==========
  263. expr : sympy.core.expr.Expr
  264. SymPy expression object to convert.
  265. cache : dict
  266. Cached Theano variables (see :class:`TheanoPrinter.cache
  267. <TheanoPrinter>`). Defaults to the module-level global cache.
  268. dtypes : dict
  269. Passed to :meth:`.TheanoPrinter.doprint`.
  270. broadcastables : dict
  271. Passed to :meth:`.TheanoPrinter.doprint`.
  272. Returns
  273. =======
  274. theano.gof.graph.Variable
  275. A variable corresponding to the expression's value in a Theano symbolic
  276. expression graph.
  277. """
  278. sympy_deprecation_warning(
  279. """
  280. sympy.printing.theanocode is deprecated. Theano has been renamed to
  281. Aesara. Use sympy.printing.aesaracode instead.""",
  282. deprecated_since_version="1.8",
  283. active_deprecations_target='theanocode-deprecated')
  284. if not theano:
  285. raise ImportError("theano is required for theano_code")
  286. if cache is None:
  287. cache = global_cache
  288. return TheanoPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
  289. def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
  290. r"""
  291. Get value of ``broadcastables`` argument to :func:`.theano_code` from
  292. keyword arguments to :func:`.theano_function`.
  293. Included for backwards compatibility.
  294. Parameters
  295. ==========
  296. inputs
  297. Sequence of input symbols.
  298. dim : int
  299. Common number of dimensions for all inputs. Overrides other arguments
  300. if given.
  301. dims : dict
  302. Mapping from input symbols to number of dimensions. Overrides
  303. ``broadcastables`` argument if given.
  304. broadcastables : dict
  305. Explicit value of ``broadcastables`` argument to
  306. :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged.
  307. Returns
  308. =======
  309. dict
  310. Dictionary mapping elements of ``inputs`` to their "broadcastable"
  311. values (tuple of ``bool``\ s).
  312. """
  313. if dim is not None:
  314. return {s: (False,) * dim for s in inputs}
  315. if dims is not None:
  316. maxdim = max(dims.values())
  317. return {
  318. s: (False,) * d + (True,) * (maxdim - d)
  319. for s, d in dims.items()
  320. }
  321. if broadcastables is not None:
  322. return broadcastables
  323. return {}
  324. @doctest_depends_on(modules=('theano',))
  325. def theano_function(inputs, outputs, scalar=False, *,
  326. dim=None, dims=None, broadcastables=None, **kwargs):
  327. """
  328. Create a Theano function from SymPy expressions.
  329. .. deprecated:: 1.8
  330. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  331. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  332. :ref:`theanocode-deprecated` for more information.
  333. The inputs and outputs are converted to Theano variables using
  334. :func:`.theano_code` and then passed to ``theano.function``.
  335. Parameters
  336. ==========
  337. inputs
  338. Sequence of symbols which constitute the inputs of the function.
  339. outputs
  340. Sequence of expressions which constitute the outputs(s) of the
  341. function. The free symbols of each expression must be a subset of
  342. ``inputs``.
  343. scalar : bool
  344. Convert 0-dimensional arrays in output to scalars. This will return a
  345. Python wrapper function around the Theano function object.
  346. cache : dict
  347. Cached Theano variables (see :class:`TheanoPrinter.cache
  348. <TheanoPrinter>`). Defaults to the module-level global cache.
  349. dtypes : dict
  350. Passed to :meth:`.TheanoPrinter.doprint`.
  351. broadcastables : dict
  352. Passed to :meth:`.TheanoPrinter.doprint`.
  353. dims : dict
  354. Alternative to ``broadcastables`` argument. Mapping from elements of
  355. ``inputs`` to integers indicating the dimension of their associated
  356. arrays/tensors. Overrides ``broadcastables`` argument if given.
  357. dim : int
  358. Another alternative to the ``broadcastables`` argument. Common number of
  359. dimensions to use for all arrays/tensors.
  360. ``theano_function([x, y], [...], dim=2)`` is equivalent to using
  361. ``broadcastables={x: (False, False), y: (False, False)}``.
  362. Returns
  363. =======
  364. callable
  365. A callable object which takes values of ``inputs`` as positional
  366. arguments and returns an output array for each of the expressions
  367. in ``outputs``. If ``outputs`` is a single expression the function will
  368. return a Numpy array, if it is a list of multiple expressions the
  369. function will return a list of arrays. See description of the ``squeeze``
  370. argument above for the behavior when a single output is passed in a list.
  371. The returned object will either be an instance of
  372. ``theano.compile.function_module.Function`` or a Python wrapper
  373. function around one. In both cases, the returned value will have a
  374. ``theano_function`` attribute which points to the return value of
  375. ``theano.function``.
  376. Examples
  377. ========
  378. >>> from sympy.abc import x, y, z
  379. >>> from sympy.printing.theanocode import theano_function
  380. A simple function with one input and one output:
  381. >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
  382. >>> f1(3)
  383. 8.0
  384. A function with multiple inputs and one output:
  385. >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
  386. >>> f2(3, 4, 2)
  387. 5.0
  388. A function with multiple inputs and multiple outputs:
  389. >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
  390. >>> f3(2, 3)
  391. [13.0, -5.0]
  392. See also
  393. ========
  394. dim_handling
  395. """
  396. sympy_deprecation_warning(
  397. """
  398. sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""",
  399. deprecated_since_version="1.8",
  400. active_deprecations_target='theanocode-deprecated')
  401. if not theano:
  402. raise ImportError("theano is required for theano_function")
  403. # Pop off non-theano keyword args
  404. cache = kwargs.pop('cache', {})
  405. dtypes = kwargs.pop('dtypes', {})
  406. broadcastables = dim_handling(
  407. inputs, dim=dim, dims=dims, broadcastables=broadcastables,
  408. )
  409. # Print inputs/outputs
  410. code = partial(theano_code, cache=cache, dtypes=dtypes,
  411. broadcastables=broadcastables)
  412. tinputs = list(map(code, inputs))
  413. toutputs = list(map(code, outputs))
  414. #fix constant expressions as variables
  415. toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs]
  416. if len(toutputs) == 1:
  417. toutputs = toutputs[0]
  418. # Compile theano func
  419. func = theano.function(tinputs, toutputs, **kwargs)
  420. is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
  421. # No wrapper required
  422. if not scalar or not any(is_0d):
  423. func.theano_function = func
  424. return func
  425. # Create wrapper to convert 0-dimensional outputs to scalars
  426. def wrapper(*args):
  427. out = func(*args)
  428. # out can be array(1.0) or [array(1.0), array(2.0)]
  429. if is_sequence(out):
  430. return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
  431. else:
  432. return out[()]
  433. wrapper.__wrapped__ = func
  434. wrapper.__doc__ = func.__doc__
  435. wrapper.theano_function = func
  436. return wrapper