numpy.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. from sympy.core import S
  2. from sympy.core.function import Lambda
  3. from sympy.core.power import Pow
  4. from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter
  5. from .codeprinter import CodePrinter
  6. _not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
  7. _in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
  8. _known_functions_numpy = dict(_in_numpy, **{
  9. 'acos': 'arccos',
  10. 'acosh': 'arccosh',
  11. 'asin': 'arcsin',
  12. 'asinh': 'arcsinh',
  13. 'atan': 'arctan',
  14. 'atan2': 'arctan2',
  15. 'atanh': 'arctanh',
  16. 'exp2': 'exp2',
  17. 'sign': 'sign',
  18. 'logaddexp': 'logaddexp',
  19. 'logaddexp2': 'logaddexp2',
  20. })
  21. _known_constants_numpy = {
  22. 'Exp1': 'e',
  23. 'Pi': 'pi',
  24. 'EulerGamma': 'euler_gamma',
  25. 'NaN': 'nan',
  26. 'Infinity': 'PINF',
  27. 'NegativeInfinity': 'NINF'
  28. }
  29. _numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()}
  30. _numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()}
  31. class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
  32. """
  33. Numpy printer which handles vectorized piecewise functions,
  34. logical operators, etc.
  35. """
  36. _module = 'numpy'
  37. _kf = _numpy_known_functions
  38. _kc = _numpy_known_constants
  39. def __init__(self, settings=None):
  40. """
  41. `settings` is passed to CodePrinter.__init__()
  42. `module` specifies the array module to use, currently 'NumPy', 'CuPy'
  43. or 'JAX'.
  44. """
  45. self.language = "Python with {}".format(self._module)
  46. self.printmethod = "_{}code".format(self._module)
  47. self._kf = {**PythonCodePrinter._kf, **self._kf}
  48. super().__init__(settings=settings)
  49. def _print_seq(self, seq):
  50. "General sequence printer: converts to tuple"
  51. # Print tuples here instead of lists because numba supports
  52. # tuples in nopython mode.
  53. delimiter=', '
  54. return '({},)'.format(delimiter.join(self._print(item) for item in seq))
  55. def _print_MatMul(self, expr):
  56. "Matrix multiplication printer"
  57. if expr.as_coeff_matrices()[0] is not S.One:
  58. expr_list = expr.as_coeff_matrices()[1]+[(expr.as_coeff_matrices()[0])]
  59. return '({})'.format(').dot('.join(self._print(i) for i in expr_list))
  60. return '({})'.format(').dot('.join(self._print(i) for i in expr.args))
  61. def _print_MatPow(self, expr):
  62. "Matrix power printer"
  63. return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'),
  64. self._print(expr.args[0]), self._print(expr.args[1]))
  65. def _print_Inverse(self, expr):
  66. "Matrix inverse printer"
  67. return '{}({})'.format(self._module_format(self._module + '.linalg.inv'),
  68. self._print(expr.args[0]))
  69. def _print_DotProduct(self, expr):
  70. # DotProduct allows any shape order, but numpy.dot does matrix
  71. # multiplication, so we have to make sure it gets 1 x n by n x 1.
  72. arg1, arg2 = expr.args
  73. if arg1.shape[0] != 1:
  74. arg1 = arg1.T
  75. if arg2.shape[1] != 1:
  76. arg2 = arg2.T
  77. return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
  78. self._print(arg1),
  79. self._print(arg2))
  80. def _print_MatrixSolve(self, expr):
  81. return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
  82. self._print(expr.matrix),
  83. self._print(expr.vector))
  84. def _print_ZeroMatrix(self, expr):
  85. return '{}({})'.format(self._module_format(self._module + '.zeros'),
  86. self._print(expr.shape))
  87. def _print_OneMatrix(self, expr):
  88. return '{}({})'.format(self._module_format(self._module + '.ones'),
  89. self._print(expr.shape))
  90. def _print_FunctionMatrix(self, expr):
  91. from sympy.abc import i, j
  92. lamda = expr.lamda
  93. if not isinstance(lamda, Lambda):
  94. lamda = Lambda((i, j), lamda(i, j))
  95. return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'),
  96. ', '.join(self._print(arg) for arg in lamda.args[0]),
  97. self._print(lamda.args[1]), self._print(expr.shape))
  98. def _print_HadamardProduct(self, expr):
  99. func = self._module_format(self._module + '.multiply')
  100. return ''.join('{}({}, '.format(func, self._print(arg)) \
  101. for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
  102. ')' * (len(expr.args) - 1))
  103. def _print_KroneckerProduct(self, expr):
  104. func = self._module_format(self._module + '.kron')
  105. return ''.join('{}({}, '.format(func, self._print(arg)) \
  106. for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
  107. ')' * (len(expr.args) - 1))
  108. def _print_Adjoint(self, expr):
  109. return '{}({}({}))'.format(
  110. self._module_format(self._module + '.conjugate'),
  111. self._module_format(self._module + '.transpose'),
  112. self._print(expr.args[0]))
  113. def _print_DiagonalOf(self, expr):
  114. vect = '{}({})'.format(
  115. self._module_format(self._module + '.diag'),
  116. self._print(expr.arg))
  117. return '{}({}, (-1, 1))'.format(
  118. self._module_format(self._module + '.reshape'), vect)
  119. def _print_DiagMatrix(self, expr):
  120. return '{}({})'.format(self._module_format(self._module + '.diagflat'),
  121. self._print(expr.args[0]))
  122. def _print_DiagonalMatrix(self, expr):
  123. return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'),
  124. self._print(expr.arg), self._module_format(self._module + '.eye'),
  125. self._print(expr.shape[0]), self._print(expr.shape[1]))
  126. def _print_Piecewise(self, expr):
  127. "Piecewise function printer"
  128. from sympy.logic.boolalg import ITE, simplify_logic
  129. def print_cond(cond):
  130. """ Problem having an ITE in the cond. """
  131. if cond.has(ITE):
  132. return self._print(simplify_logic(cond))
  133. else:
  134. return self._print(cond)
  135. exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
  136. conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
  137. # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
  138. # it will behave the same as passing the 'default' kwarg to select()
  139. # *as long as* it is the last element in expr.args.
  140. # If this is not the case, it may be triggered prematurely.
  141. return '{}({}, {}, default={})'.format(
  142. self._module_format(self._module + '.select'), conds, exprs,
  143. self._print(S.NaN))
  144. def _print_Relational(self, expr):
  145. "Relational printer for Equality and Unequality"
  146. op = {
  147. '==' :'equal',
  148. '!=' :'not_equal',
  149. '<' :'less',
  150. '<=' :'less_equal',
  151. '>' :'greater',
  152. '>=' :'greater_equal',
  153. }
  154. if expr.rel_op in op:
  155. lhs = self._print(expr.lhs)
  156. rhs = self._print(expr.rhs)
  157. return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]),
  158. lhs=lhs, rhs=rhs)
  159. return super()._print_Relational(expr)
  160. def _print_And(self, expr):
  161. "Logical And printer"
  162. # We have to override LambdaPrinter because it uses Python 'and' keyword.
  163. # If LambdaPrinter didn't define it, we could use StrPrinter's
  164. # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
  165. return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args))
  166. def _print_Or(self, expr):
  167. "Logical Or printer"
  168. # We have to override LambdaPrinter because it uses Python 'or' keyword.
  169. # If LambdaPrinter didn't define it, we could use StrPrinter's
  170. # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
  171. return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args))
  172. def _print_Not(self, expr):
  173. "Logical Not printer"
  174. # We have to override LambdaPrinter because it uses Python 'not' keyword.
  175. # If LambdaPrinter didn't define it, we would still have to define our
  176. # own because StrPrinter doesn't define it.
  177. return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args))
  178. def _print_Pow(self, expr, rational=False):
  179. # XXX Workaround for negative integer power error
  180. if expr.exp.is_integer and expr.exp.is_negative:
  181. expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
  182. return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')
  183. def _print_Min(self, expr):
  184. return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join(self._print(i) for i in expr.args))
  185. def _print_Max(self, expr):
  186. return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join(self._print(i) for i in expr.args))
  187. def _print_arg(self, expr):
  188. return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))
  189. def _print_im(self, expr):
  190. return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))
  191. def _print_Mod(self, expr):
  192. return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
  193. (self._print(arg) for arg in expr.args)))
  194. def _print_re(self, expr):
  195. return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))
  196. def _print_sinc(self, expr):
  197. return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))
  198. def _print_MatrixBase(self, expr):
  199. func = self.known_functions.get(expr.__class__.__name__, None)
  200. if func is None:
  201. func = self._module_format(self._module + '.array')
  202. return "%s(%s)" % (func, self._print(expr.tolist()))
  203. def _print_Identity(self, expr):
  204. shape = expr.shape
  205. if all(dim.is_Integer for dim in shape):
  206. return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
  207. else:
  208. raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")
  209. def _print_BlockMatrix(self, expr):
  210. return '{}({})'.format(self._module_format(self._module + '.block'),
  211. self._print(expr.args[0].tolist()))
  212. def _print_NDimArray(self, expr):
  213. if len(expr.shape) == 1:
  214. return self._module + '.array(' + self._print(expr.args[0]) + ')'
  215. if len(expr.shape) == 2:
  216. return self._print(expr.tomatrix())
  217. # Should be possible to extend to more dimensions
  218. return CodePrinter._print_not_supported(self, expr)
  219. _add = "add"
  220. _einsum = "einsum"
  221. _transpose = "transpose"
  222. _ones = "ones"
  223. _zeros = "zeros"
  224. _print_lowergamma = CodePrinter._print_not_supported
  225. _print_uppergamma = CodePrinter._print_not_supported
  226. _print_fresnelc = CodePrinter._print_not_supported
  227. _print_fresnels = CodePrinter._print_not_supported
  228. for func in _numpy_known_functions:
  229. setattr(NumPyPrinter, f'_print_{func}', _print_known_func)
  230. for const in _numpy_known_constants:
  231. setattr(NumPyPrinter, f'_print_{const}', _print_known_const)
  232. _known_functions_scipy_special = {
  233. 'Ei': 'expi',
  234. 'erf': 'erf',
  235. 'erfc': 'erfc',
  236. 'besselj': 'jv',
  237. 'bessely': 'yv',
  238. 'besseli': 'iv',
  239. 'besselk': 'kv',
  240. 'cosm1': 'cosm1',
  241. 'powm1': 'powm1',
  242. 'factorial': 'factorial',
  243. 'gamma': 'gamma',
  244. 'loggamma': 'gammaln',
  245. 'digamma': 'psi',
  246. 'polygamma': 'polygamma',
  247. 'RisingFactorial': 'poch',
  248. 'jacobi': 'eval_jacobi',
  249. 'gegenbauer': 'eval_gegenbauer',
  250. 'chebyshevt': 'eval_chebyt',
  251. 'chebyshevu': 'eval_chebyu',
  252. 'legendre': 'eval_legendre',
  253. 'hermite': 'eval_hermite',
  254. 'laguerre': 'eval_laguerre',
  255. 'assoc_laguerre': 'eval_genlaguerre',
  256. 'beta': 'beta',
  257. 'LambertW' : 'lambertw',
  258. }
  259. _known_constants_scipy_constants = {
  260. 'GoldenRatio': 'golden_ratio',
  261. 'Pi': 'pi',
  262. }
  263. _scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()}
  264. _scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()}
  265. class SciPyPrinter(NumPyPrinter):
  266. _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
  267. _kc = {**NumPyPrinter._kc, **_scipy_known_constants}
  268. def __init__(self, settings=None):
  269. super().__init__(settings=settings)
  270. self.language = "Python with SciPy and NumPy"
  271. def _print_SparseRepMatrix(self, expr):
  272. i, j, data = [], [], []
  273. for (r, c), v in expr.todok().items():
  274. i.append(r)
  275. j.append(c)
  276. data.append(v)
  277. return "{name}(({data}, ({i}, {j})), shape={shape})".format(
  278. name=self._module_format('scipy.sparse.coo_matrix'),
  279. data=data, i=i, j=j, shape=expr.shape
  280. )
  281. _print_ImmutableSparseMatrix = _print_SparseRepMatrix
  282. # SciPy's lpmv has a different order of arguments from assoc_legendre
  283. def _print_assoc_legendre(self, expr):
  284. return "{0}({2}, {1}, {3})".format(
  285. self._module_format('scipy.special.lpmv'),
  286. self._print(expr.args[0]),
  287. self._print(expr.args[1]),
  288. self._print(expr.args[2]))
  289. def _print_lowergamma(self, expr):
  290. return "{0}({2})*{1}({2}, {3})".format(
  291. self._module_format('scipy.special.gamma'),
  292. self._module_format('scipy.special.gammainc'),
  293. self._print(expr.args[0]),
  294. self._print(expr.args[1]))
  295. def _print_uppergamma(self, expr):
  296. return "{0}({2})*{1}({2}, {3})".format(
  297. self._module_format('scipy.special.gamma'),
  298. self._module_format('scipy.special.gammaincc'),
  299. self._print(expr.args[0]),
  300. self._print(expr.args[1]))
  301. def _print_betainc(self, expr):
  302. betainc = self._module_format('scipy.special.betainc')
  303. beta = self._module_format('scipy.special.beta')
  304. args = [self._print(arg) for arg in expr.args]
  305. return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \
  306. * {beta}({args[0]}, {args[1]})"
  307. def _print_betainc_regularized(self, expr):
  308. return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
  309. self._module_format('scipy.special.betainc'),
  310. self._print(expr.args[0]),
  311. self._print(expr.args[1]),
  312. self._print(expr.args[2]),
  313. self._print(expr.args[3]))
  314. def _print_fresnels(self, expr):
  315. return "{}({})[0]".format(
  316. self._module_format("scipy.special.fresnel"),
  317. self._print(expr.args[0]))
  318. def _print_fresnelc(self, expr):
  319. return "{}({})[1]".format(
  320. self._module_format("scipy.special.fresnel"),
  321. self._print(expr.args[0]))
  322. def _print_airyai(self, expr):
  323. return "{}({})[0]".format(
  324. self._module_format("scipy.special.airy"),
  325. self._print(expr.args[0]))
  326. def _print_airyaiprime(self, expr):
  327. return "{}({})[1]".format(
  328. self._module_format("scipy.special.airy"),
  329. self._print(expr.args[0]))
  330. def _print_airybi(self, expr):
  331. return "{}({})[2]".format(
  332. self._module_format("scipy.special.airy"),
  333. self._print(expr.args[0]))
  334. def _print_airybiprime(self, expr):
  335. return "{}({})[3]".format(
  336. self._module_format("scipy.special.airy"),
  337. self._print(expr.args[0]))
  338. def _print_bernoulli(self, expr):
  339. # scipy's bernoulli is inconsistent with SymPy's so rewrite
  340. return self._print(expr._eval_rewrite_as_zeta(*expr.args))
  341. def _print_harmonic(self, expr):
  342. return self._print(expr._eval_rewrite_as_zeta(*expr.args))
  343. def _print_Integral(self, e):
  344. integration_vars, limits = _unpack_integral_limits(e)
  345. if len(limits) == 1:
  346. # nicer (but not necessary) to prefer quad over nquad for 1D case
  347. module_str = self._module_format("scipy.integrate.quad")
  348. limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
  349. else:
  350. module_str = self._module_format("scipy.integrate.nquad")
  351. limit_str = "({})".format(", ".join(
  352. "(%s, %s)" % tuple(map(self._print, l)) for l in limits))
  353. return "{}(lambda {}: {}, {})[0]".format(
  354. module_str,
  355. ", ".join(map(self._print, integration_vars)),
  356. self._print(e.args[0]),
  357. limit_str)
  358. def _print_Si(self, expr):
  359. return "{}({})[0]".format(
  360. self._module_format("scipy.special.sici"),
  361. self._print(expr.args[0]))
  362. def _print_Ci(self, expr):
  363. return "{}({})[1]".format(
  364. self._module_format("scipy.special.sici"),
  365. self._print(expr.args[0]))
  366. for func in _scipy_known_functions:
  367. setattr(SciPyPrinter, f'_print_{func}', _print_known_func)
  368. for const in _scipy_known_constants:
  369. setattr(SciPyPrinter, f'_print_{const}', _print_known_const)
  370. _cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()}
  371. _cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()}
  372. class CuPyPrinter(NumPyPrinter):
  373. """
  374. CuPy printer which handles vectorized piecewise functions,
  375. logical operators, etc.
  376. """
  377. _module = 'cupy'
  378. _kf = _cupy_known_functions
  379. _kc = _cupy_known_constants
  380. def __init__(self, settings=None):
  381. super().__init__(settings=settings)
  382. for func in _cupy_known_functions:
  383. setattr(CuPyPrinter, f'_print_{func}', _print_known_func)
  384. for const in _cupy_known_constants:
  385. setattr(CuPyPrinter, f'_print_{const}', _print_known_const)
  386. _jax_known_functions = {k: 'jax.numpy.' + v for k, v in _known_functions_numpy.items()}
  387. _jax_known_constants = {k: 'jax.numpy.' + v for k, v in _known_constants_numpy.items()}
  388. class JaxPrinter(NumPyPrinter):
  389. """
  390. JAX printer which handles vectorized piecewise functions,
  391. logical operators, etc.
  392. """
  393. _module = "jax.numpy"
  394. _kf = _jax_known_functions
  395. _kc = _jax_known_constants
  396. def __init__(self, settings=None):
  397. super().__init__(settings=settings)
  398. # These need specific override to allow for the lack of "jax.numpy.reduce"
  399. def _print_And(self, expr):
  400. "Logical And printer"
  401. return "{}({}.asarray([{}]), axis=0)".format(
  402. self._module_format(self._module + ".all"),
  403. self._module_format(self._module),
  404. ",".join(self._print(i) for i in expr.args),
  405. )
  406. def _print_Or(self, expr):
  407. "Logical Or printer"
  408. return "{}({}.asarray([{}]), axis=0)".format(
  409. self._module_format(self._module + ".any"),
  410. self._module_format(self._module),
  411. ",".join(self._print(i) for i in expr.args),
  412. )
  413. for func in _jax_known_functions:
  414. setattr(JaxPrinter, f'_print_{func}', _print_known_func)
  415. for const in _jax_known_constants:
  416. setattr(JaxPrinter, f'_print_{const}', _print_known_const)