cse_main.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946
  1. """ Tools for doing common subexpression elimination.
  2. """
  3. from collections import defaultdict
  4. from sympy.core import Basic, Mul, Add, Pow, sympify
  5. from sympy.core.containers import Tuple, OrderedSet
  6. from sympy.core.exprtools import factor_terms
  7. from sympy.core.singleton import S
  8. from sympy.core.sorting import ordered
  9. from sympy.core.symbol import symbols, Symbol
  10. from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix,
  11. SparseMatrix, ImmutableSparseMatrix)
  12. from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul,
  13. MatAdd, MatPow, Inverse)
  14. from sympy.matrices.expressions.matexpr import MatrixElement
  15. from sympy.polys.rootoftools import RootOf
  16. from sympy.utilities.iterables import numbered_symbols, sift, \
  17. topological_sort, iterable
  18. from . import cse_opts
  19. # (preprocessor, postprocessor) pairs which are commonly useful. They should
  20. # each take a SymPy expression and return a possibly transformed expression.
  21. # When used in the function ``cse()``, the target expressions will be transformed
  22. # by each of the preprocessor functions in order. After the common
  23. # subexpressions are eliminated, each resulting expression will have the
  24. # postprocessor functions transform them in *reverse* order in order to undo the
  25. # transformation if necessary. This allows the algorithm to operate on
  26. # a representation of the expressions that allows for more optimization
  27. # opportunities.
  28. # ``None`` can be used to specify no transformation for either the preprocessor or
  29. # postprocessor.
  30. basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
  31. (factor_terms, None)]
  32. # sometimes we want the output in a different format; non-trivial
  33. # transformations can be put here for users
  34. # ===============================================================
  35. def reps_toposort(r):
  36. """Sort replacements ``r`` so (k1, v1) appears before (k2, v2)
  37. if k2 is in v1's free symbols. This orders items in the
  38. way that cse returns its results (hence, in order to use the
  39. replacements in a substitution option it would make sense
  40. to reverse the order).
  41. Examples
  42. ========
  43. >>> from sympy.simplify.cse_main import reps_toposort
  44. >>> from sympy.abc import x, y
  45. >>> from sympy import Eq
  46. >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
  47. ... print(Eq(l, r))
  48. ...
  49. Eq(y, 2)
  50. Eq(x, y + 1)
  51. """
  52. r = sympify(r)
  53. E = []
  54. for c1, (k1, v1) in enumerate(r):
  55. for c2, (k2, v2) in enumerate(r):
  56. if k1 in v2.free_symbols:
  57. E.append((c1, c2))
  58. return [r[i] for i in topological_sort((range(len(r)), E))]
  59. def cse_separate(r, e):
  60. """Move expressions that are in the form (symbol, expr) out of the
  61. expressions and sort them into the replacements using the reps_toposort.
  62. Examples
  63. ========
  64. >>> from sympy.simplify.cse_main import cse_separate
  65. >>> from sympy.abc import x, y, z
  66. >>> from sympy import cos, exp, cse, Eq, symbols
  67. >>> x0, x1 = symbols('x:2')
  68. >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
  69. >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
  70. ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
  71. ... [x1 + exp(x1/x0) + cos(x0), z - 2]],
  72. ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
  73. ... [x0 + exp(x0/x1) + cos(x1), z - 2]]]
  74. ...
  75. True
  76. """
  77. d = sift(e, lambda w: w.is_Equality and w.lhs.is_Symbol)
  78. r = r + [w.args for w in d[True]]
  79. e = d[False]
  80. return [reps_toposort(r), e]
  81. def cse_release_variables(r, e):
  82. """
  83. Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
  84. either an expression or None. The value of None is used when a
  85. symbol is no longer needed for subsequent expressions.
  86. Use of such output can reduce the memory footprint of lambdified
  87. expressions that contain large, repeated subexpressions.
  88. Examples
  89. ========
  90. >>> from sympy import cse
  91. >>> from sympy.simplify.cse_main import cse_release_variables
  92. >>> from sympy.abc import x, y
  93. >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
  94. >>> defs, rvs = cse_release_variables(*cse(eqs))
  95. >>> for i in defs:
  96. ... print(i)
  97. ...
  98. (x0, x + y)
  99. (x1, (x0 - 1)**2)
  100. (x2, 2*x + 1)
  101. (_3, x0/x2 + x1)
  102. (_4, x2**x0)
  103. (x2, None)
  104. (_0, x1)
  105. (x1, None)
  106. (_2, x0)
  107. (x0, None)
  108. (_1, x)
  109. >>> print(rvs)
  110. (_0, _1, _2, _3, _4)
  111. """
  112. if not r:
  113. return r, e
  114. s, p = zip(*r)
  115. esyms = symbols('_:%d' % len(e))
  116. syms = list(esyms)
  117. s = list(s)
  118. in_use = set(s)
  119. p = list(p)
  120. # sort e so those with most sub-expressions appear first
  121. e = [(e[i], syms[i]) for i in range(len(e))]
  122. e, syms = zip(*sorted(e,
  123. key=lambda x: -sum([p[s.index(i)].count_ops()
  124. for i in x[0].free_symbols & in_use])))
  125. syms = list(syms)
  126. p += e
  127. rv = []
  128. i = len(p) - 1
  129. while i >= 0:
  130. _p = p.pop()
  131. c = in_use & _p.free_symbols
  132. if c: # sorting for canonical results
  133. rv.extend([(s, None) for s in sorted(c, key=str)])
  134. if i >= len(r):
  135. rv.append((syms.pop(), _p))
  136. else:
  137. rv.append((s[i], _p))
  138. in_use -= c
  139. i -= 1
  140. rv.reverse()
  141. return rv, esyms
  142. # ====end of cse postprocess idioms===========================
  143. def preprocess_for_cse(expr, optimizations):
  144. """ Preprocess an expression to optimize for common subexpression
  145. elimination.
  146. Parameters
  147. ==========
  148. expr : SymPy expression
  149. The target expression to optimize.
  150. optimizations : list of (callable, callable) pairs
  151. The (preprocessor, postprocessor) pairs.
  152. Returns
  153. =======
  154. expr : SymPy expression
  155. The transformed expression.
  156. """
  157. for pre, post in optimizations:
  158. if pre is not None:
  159. expr = pre(expr)
  160. return expr
  161. def postprocess_for_cse(expr, optimizations):
  162. """Postprocess an expression after common subexpression elimination to
  163. return the expression to canonical SymPy form.
  164. Parameters
  165. ==========
  166. expr : SymPy expression
  167. The target expression to transform.
  168. optimizations : list of (callable, callable) pairs, optional
  169. The (preprocessor, postprocessor) pairs. The postprocessors will be
  170. applied in reversed order to undo the effects of the preprocessors
  171. correctly.
  172. Returns
  173. =======
  174. expr : SymPy expression
  175. The transformed expression.
  176. """
  177. for pre, post in reversed(optimizations):
  178. if post is not None:
  179. expr = post(expr)
  180. return expr
  181. class FuncArgTracker:
  182. """
  183. A class which manages a mapping from functions to arguments and an inverse
  184. mapping from arguments to functions.
  185. """
  186. def __init__(self, funcs):
  187. # To minimize the number of symbolic comparisons, all function arguments
  188. # get assigned a value number.
  189. self.value_numbers = {}
  190. self.value_number_to_value = []
  191. # Both of these maps use integer indices for arguments / functions.
  192. self.arg_to_funcset = []
  193. self.func_to_argset = []
  194. for func_i, func in enumerate(funcs):
  195. func_argset = OrderedSet()
  196. for func_arg in func.args:
  197. arg_number = self.get_or_add_value_number(func_arg)
  198. func_argset.add(arg_number)
  199. self.arg_to_funcset[arg_number].add(func_i)
  200. self.func_to_argset.append(func_argset)
  201. def get_args_in_value_order(self, argset):
  202. """
  203. Return the list of arguments in sorted order according to their value
  204. numbers.
  205. """
  206. return [self.value_number_to_value[argn] for argn in sorted(argset)]
  207. def get_or_add_value_number(self, value):
  208. """
  209. Return the value number for the given argument.
  210. """
  211. nvalues = len(self.value_numbers)
  212. value_number = self.value_numbers.setdefault(value, nvalues)
  213. if value_number == nvalues:
  214. self.value_number_to_value.append(value)
  215. self.arg_to_funcset.append(OrderedSet())
  216. return value_number
  217. def stop_arg_tracking(self, func_i):
  218. """
  219. Remove the function func_i from the argument to function mapping.
  220. """
  221. for arg in self.func_to_argset[func_i]:
  222. self.arg_to_funcset[arg].remove(func_i)
  223. def get_common_arg_candidates(self, argset, min_func_i=0):
  224. """Return a dict whose keys are function numbers. The entries of the dict are
  225. the number of arguments said function has in common with
  226. ``argset``. Entries have at least 2 items in common. All keys have
  227. value at least ``min_func_i``.
  228. """
  229. count_map = defaultdict(lambda: 0)
  230. if not argset:
  231. return count_map
  232. funcsets = [self.arg_to_funcset[arg] for arg in argset]
  233. # As an optimization below, we handle the largest funcset separately from
  234. # the others.
  235. largest_funcset = max(funcsets, key=len)
  236. for funcset in funcsets:
  237. if largest_funcset is funcset:
  238. continue
  239. for func_i in funcset:
  240. if func_i >= min_func_i:
  241. count_map[func_i] += 1
  242. # We pick the smaller of the two containers (count_map, largest_funcset)
  243. # to iterate over to reduce the number of iterations needed.
  244. (smaller_funcs_container,
  245. larger_funcs_container) = sorted(
  246. [largest_funcset, count_map],
  247. key=len)
  248. for func_i in smaller_funcs_container:
  249. # Not already in count_map? It can't possibly be in the output, so
  250. # skip it.
  251. if count_map[func_i] < 1:
  252. continue
  253. if func_i in larger_funcs_container:
  254. count_map[func_i] += 1
  255. return {k: v for k, v in count_map.items() if v >= 2}
  256. def get_subset_candidates(self, argset, restrict_to_funcset=None):
  257. """
  258. Return a set of functions each of which whose argument list contains
  259. ``argset``, optionally filtered only to contain functions in
  260. ``restrict_to_funcset``.
  261. """
  262. iarg = iter(argset)
  263. indices = OrderedSet(
  264. fi for fi in self.arg_to_funcset[next(iarg)])
  265. if restrict_to_funcset is not None:
  266. indices &= restrict_to_funcset
  267. for arg in iarg:
  268. indices &= self.arg_to_funcset[arg]
  269. return indices
  270. def update_func_argset(self, func_i, new_argset):
  271. """
  272. Update a function with a new set of arguments.
  273. """
  274. new_args = OrderedSet(new_argset)
  275. old_args = self.func_to_argset[func_i]
  276. for deleted_arg in old_args - new_args:
  277. self.arg_to_funcset[deleted_arg].remove(func_i)
  278. for added_arg in new_args - old_args:
  279. self.arg_to_funcset[added_arg].add(func_i)
  280. self.func_to_argset[func_i].clear()
  281. self.func_to_argset[func_i].update(new_args)
  282. class Unevaluated:
  283. def __init__(self, func, args):
  284. self.func = func
  285. self.args = args
  286. def __str__(self):
  287. return "Uneval<{}>({})".format(
  288. self.func, ", ".join(str(a) for a in self.args))
  289. def as_unevaluated_basic(self):
  290. return self.func(*self.args, evaluate=False)
  291. @property
  292. def free_symbols(self):
  293. return set().union(*[a.free_symbols for a in self.args])
  294. __repr__ = __str__
  295. def match_common_args(func_class, funcs, opt_subs):
  296. """
  297. Recognize and extract common subexpressions of function arguments within a
  298. set of function calls. For instance, for the following function calls::
  299. x + z + y
  300. sin(x + y)
  301. this will extract a common subexpression of `x + y`::
  302. w = x + y
  303. w + z
  304. sin(w)
  305. The function we work with is assumed to be associative and commutative.
  306. Parameters
  307. ==========
  308. func_class: class
  309. The function class (e.g. Add, Mul)
  310. funcs: list of functions
  311. A list of function calls.
  312. opt_subs: dict
  313. A dictionary of substitutions which this function may update.
  314. """
  315. # Sort to ensure that whole-function subexpressions come before the items
  316. # that use them.
  317. funcs = sorted(funcs, key=lambda f: len(f.args))
  318. arg_tracker = FuncArgTracker(funcs)
  319. changed = OrderedSet()
  320. for i in range(len(funcs)):
  321. common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
  322. arg_tracker.func_to_argset[i], min_func_i=i + 1)
  323. # Sort the candidates in order of match size.
  324. # This makes us try combining smaller matches first.
  325. common_arg_candidates = OrderedSet(sorted(
  326. common_arg_candidates_counts.keys(),
  327. key=lambda k: (common_arg_candidates_counts[k], k)))
  328. while common_arg_candidates:
  329. j = common_arg_candidates.pop(last=False)
  330. com_args = arg_tracker.func_to_argset[i].intersection(
  331. arg_tracker.func_to_argset[j])
  332. if len(com_args) <= 1:
  333. # This may happen if a set of common arguments was already
  334. # combined in a previous iteration.
  335. continue
  336. # For all sets, replace the common symbols by the function
  337. # over them, to allow recursive matches.
  338. diff_i = arg_tracker.func_to_argset[i].difference(com_args)
  339. if diff_i:
  340. # com_func needs to be unevaluated to allow for recursive matches.
  341. com_func = Unevaluated(
  342. func_class, arg_tracker.get_args_in_value_order(com_args))
  343. com_func_number = arg_tracker.get_or_add_value_number(com_func)
  344. arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
  345. changed.add(i)
  346. else:
  347. # Treat the whole expression as a CSE.
  348. #
  349. # The reason this needs to be done is somewhat subtle. Within
  350. # tree_cse(), to_eliminate only contains expressions that are
  351. # seen more than once. The problem is unevaluated expressions
  352. # do not compare equal to the evaluated equivalent. So
  353. # tree_cse() won't mark funcs[i] as a CSE if we use an
  354. # unevaluated version.
  355. com_func_number = arg_tracker.get_or_add_value_number(funcs[i])
  356. diff_j = arg_tracker.func_to_argset[j].difference(com_args)
  357. arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
  358. changed.add(j)
  359. for k in arg_tracker.get_subset_candidates(
  360. com_args, common_arg_candidates):
  361. diff_k = arg_tracker.func_to_argset[k].difference(com_args)
  362. arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
  363. changed.add(k)
  364. if i in changed:
  365. opt_subs[funcs[i]] = Unevaluated(func_class,
  366. arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))
  367. arg_tracker.stop_arg_tracking(i)
  368. def opt_cse(exprs, order='canonical'):
  369. """Find optimization opportunities in Adds, Muls, Pows and negative
  370. coefficient Muls.
  371. Parameters
  372. ==========
  373. exprs : list of SymPy expressions
  374. The expressions to optimize.
  375. order : string, 'none' or 'canonical'
  376. The order by which Mul and Add arguments are processed. For large
  377. expressions where speed is a concern, use the setting order='none'.
  378. Returns
  379. =======
  380. opt_subs : dictionary of expression substitutions
  381. The expression substitutions which can be useful to optimize CSE.
  382. Examples
  383. ========
  384. >>> from sympy.simplify.cse_main import opt_cse
  385. >>> from sympy.abc import x
  386. >>> opt_subs = opt_cse([x**-2])
  387. >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
  388. >>> print((k, v.as_unevaluated_basic()))
  389. (x**(-2), 1/(x**2))
  390. """
  391. opt_subs = {}
  392. adds = OrderedSet()
  393. muls = OrderedSet()
  394. seen_subexp = set()
  395. collapsible_subexp = set()
  396. def _find_opts(expr):
  397. if not isinstance(expr, (Basic, Unevaluated)):
  398. return
  399. if expr.is_Atom or expr.is_Order:
  400. return
  401. if iterable(expr):
  402. list(map(_find_opts, expr))
  403. return
  404. if expr in seen_subexp:
  405. return expr
  406. seen_subexp.add(expr)
  407. list(map(_find_opts, expr.args))
  408. if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign():
  409. # XXX -expr does not always work rigorously for some expressions
  410. # containing UnevaluatedExpr.
  411. # https://github.com/sympy/sympy/issues/24818
  412. if isinstance(expr, Add):
  413. neg_expr = Add(*(-i for i in expr.args))
  414. else:
  415. neg_expr = -expr
  416. if not neg_expr.is_Atom:
  417. opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
  418. seen_subexp.add(neg_expr)
  419. expr = neg_expr
  420. if isinstance(expr, (Mul, MatMul)):
  421. if len(expr.args) == 1:
  422. collapsible_subexp.add(expr)
  423. else:
  424. muls.add(expr)
  425. elif isinstance(expr, (Add, MatAdd)):
  426. if len(expr.args) == 1:
  427. collapsible_subexp.add(expr)
  428. else:
  429. adds.add(expr)
  430. elif isinstance(expr, Inverse):
  431. # Do not want to treat `Inverse` as a `MatPow`
  432. pass
  433. elif isinstance(expr, (Pow, MatPow)):
  434. base, exp = expr.base, expr.exp
  435. if exp.could_extract_minus_sign():
  436. opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))
  437. for e in exprs:
  438. if isinstance(e, (Basic, Unevaluated)):
  439. _find_opts(e)
  440. # Handle collapsing of multinary operations with single arguments
  441. edges = [(s, s.args[0]) for s in collapsible_subexp
  442. if s.args[0] in collapsible_subexp]
  443. for e in reversed(topological_sort((collapsible_subexp, edges))):
  444. opt_subs[e] = opt_subs.get(e.args[0], e.args[0])
  445. # split muls into commutative
  446. commutative_muls = OrderedSet()
  447. for m in muls:
  448. c, nc = m.args_cnc(cset=False)
  449. if c:
  450. c_mul = m.func(*c)
  451. if nc:
  452. if c_mul == 1:
  453. new_obj = m.func(*nc)
  454. else:
  455. if isinstance(m, MatMul):
  456. new_obj = m.func(c_mul, *nc, evaluate=False)
  457. else:
  458. new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
  459. opt_subs[m] = new_obj
  460. if len(c) > 1:
  461. commutative_muls.add(c_mul)
  462. match_common_args(Add, adds, opt_subs)
  463. match_common_args(Mul, commutative_muls, opt_subs)
  464. return opt_subs
  465. def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
  466. """Perform raw CSE on expression tree, taking opt_subs into account.
  467. Parameters
  468. ==========
  469. exprs : list of SymPy expressions
  470. The expressions to reduce.
  471. symbols : infinite iterator yielding unique Symbols
  472. The symbols used to label the common subexpressions which are pulled
  473. out.
  474. opt_subs : dictionary of expression substitutions
  475. The expressions to be substituted before any CSE action is performed.
  476. order : string, 'none' or 'canonical'
  477. The order by which Mul and Add arguments are processed. For large
  478. expressions where speed is a concern, use the setting order='none'.
  479. ignore : iterable of Symbols
  480. Substitutions containing any Symbol from ``ignore`` will be ignored.
  481. """
  482. if opt_subs is None:
  483. opt_subs = {}
  484. ## Find repeated sub-expressions
  485. to_eliminate = set()
  486. seen_subexp = set()
  487. excluded_symbols = set()
  488. def _find_repeated(expr):
  489. if not isinstance(expr, (Basic, Unevaluated)):
  490. return
  491. if isinstance(expr, RootOf):
  492. return
  493. if isinstance(expr, Basic) and (
  494. expr.is_Atom or
  495. expr.is_Order or
  496. isinstance(expr, (MatrixSymbol, MatrixElement))):
  497. if expr.is_Symbol:
  498. excluded_symbols.add(expr)
  499. return
  500. if iterable(expr):
  501. args = expr
  502. else:
  503. if expr in seen_subexp:
  504. for ign in ignore:
  505. if ign in expr.free_symbols:
  506. break
  507. else:
  508. to_eliminate.add(expr)
  509. return
  510. seen_subexp.add(expr)
  511. if expr in opt_subs:
  512. expr = opt_subs[expr]
  513. args = expr.args
  514. list(map(_find_repeated, args))
  515. for e in exprs:
  516. if isinstance(e, Basic):
  517. _find_repeated(e)
  518. ## Rebuild tree
  519. # Remove symbols from the generator that conflict with names in the expressions.
  520. symbols = (symbol for symbol in symbols if symbol not in excluded_symbols)
  521. replacements = []
  522. subs = {}
  523. def _rebuild(expr):
  524. if not isinstance(expr, (Basic, Unevaluated)):
  525. return expr
  526. if not expr.args:
  527. return expr
  528. if iterable(expr):
  529. new_args = [_rebuild(arg) for arg in expr.args]
  530. return expr.func(*new_args)
  531. if expr in subs:
  532. return subs[expr]
  533. orig_expr = expr
  534. if expr in opt_subs:
  535. expr = opt_subs[expr]
  536. # If enabled, parse Muls and Adds arguments by order to ensure
  537. # replacement order independent from hashes
  538. if order != 'none':
  539. if isinstance(expr, (Mul, MatMul)):
  540. c, nc = expr.args_cnc()
  541. if c == [1]:
  542. args = nc
  543. else:
  544. args = list(ordered(c)) + nc
  545. elif isinstance(expr, (Add, MatAdd)):
  546. args = list(ordered(expr.args))
  547. else:
  548. args = expr.args
  549. else:
  550. args = expr.args
  551. new_args = list(map(_rebuild, args))
  552. if isinstance(expr, Unevaluated) or new_args != args:
  553. new_expr = expr.func(*new_args)
  554. else:
  555. new_expr = expr
  556. if orig_expr in to_eliminate:
  557. try:
  558. sym = next(symbols)
  559. except StopIteration:
  560. raise ValueError("Symbols iterator ran out of symbols.")
  561. if isinstance(orig_expr, MatrixExpr):
  562. sym = MatrixSymbol(sym.name, orig_expr.rows,
  563. orig_expr.cols)
  564. subs[orig_expr] = sym
  565. replacements.append((sym, new_expr))
  566. return sym
  567. else:
  568. return new_expr
  569. reduced_exprs = []
  570. for e in exprs:
  571. if isinstance(e, Basic):
  572. reduced_e = _rebuild(e)
  573. else:
  574. reduced_e = e
  575. reduced_exprs.append(reduced_e)
  576. return replacements, reduced_exprs
  577. def cse(exprs, symbols=None, optimizations=None, postprocess=None,
  578. order='canonical', ignore=(), list=True):
  579. """ Perform common subexpression elimination on an expression.
  580. Parameters
  581. ==========
  582. exprs : list of SymPy expressions, or a single SymPy expression
  583. The expressions to reduce.
  584. symbols : infinite iterator yielding unique Symbols
  585. The symbols used to label the common subexpressions which are pulled
  586. out. The ``numbered_symbols`` generator is useful. The default is a
  587. stream of symbols of the form "x0", "x1", etc. This must be an
  588. infinite iterator.
  589. optimizations : list of (callable, callable) pairs
  590. The (preprocessor, postprocessor) pairs of external optimization
  591. functions. Optionally 'basic' can be passed for a set of predefined
  592. basic optimizations. Such 'basic' optimizations were used by default
  593. in old implementation, however they can be really slow on larger
  594. expressions. Now, no pre or post optimizations are made by default.
  595. postprocess : a function which accepts the two return values of cse and
  596. returns the desired form of output from cse, e.g. if you want the
  597. replacements reversed the function might be the following lambda:
  598. lambda r, e: return reversed(r), e
  599. order : string, 'none' or 'canonical'
  600. The order by which Mul and Add arguments are processed. If set to
  601. 'canonical', arguments will be canonically ordered. If set to 'none',
  602. ordering will be faster but dependent on expressions hashes, thus
  603. machine dependent and variable. For large expressions where speed is a
  604. concern, use the setting order='none'.
  605. ignore : iterable of Symbols
  606. Substitutions containing any Symbol from ``ignore`` will be ignored.
  607. list : bool, (default True)
  608. Returns expression in list or else with same type as input (when False).
  609. Returns
  610. =======
  611. replacements : list of (Symbol, expression) pairs
  612. All of the common subexpressions that were replaced. Subexpressions
  613. earlier in this list might show up in subexpressions later in this
  614. list.
  615. reduced_exprs : list of SymPy expressions
  616. The reduced expressions with all of the replacements above.
  617. Examples
  618. ========
  619. >>> from sympy import cse, SparseMatrix
  620. >>> from sympy.abc import x, y, z, w
  621. >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
  622. ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])
  623. List of expressions with recursive substitutions:
  624. >>> m = SparseMatrix([x + y, x + y + z])
  625. >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
  626. ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
  627. [x0],
  628. [x1]])])
  629. Note: the type and mutability of input matrices is retained.
  630. >>> isinstance(_[1][-1], SparseMatrix)
  631. True
  632. The user may disallow substitutions containing certain symbols:
  633. >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
  634. ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])
  635. The default return value for the reduced expression(s) is a list, even if there is only
  636. one expression. The `list` flag preserves the type of the input in the output:
  637. >>> cse(x)
  638. ([], [x])
  639. >>> cse(x, list=False)
  640. ([], x)
  641. """
  642. if not list:
  643. return _cse_homogeneous(exprs,
  644. symbols=symbols, optimizations=optimizations,
  645. postprocess=postprocess, order=order, ignore=ignore)
  646. if isinstance(exprs, (int, float)):
  647. exprs = sympify(exprs)
  648. # Handle the case if just one expression was passed.
  649. if isinstance(exprs, (Basic, MatrixBase)):
  650. exprs = [exprs]
  651. copy = exprs
  652. temp = []
  653. for e in exprs:
  654. if isinstance(e, (Matrix, ImmutableMatrix)):
  655. temp.append(Tuple(*e.flat()))
  656. elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
  657. temp.append(Tuple(*e.todok().items()))
  658. else:
  659. temp.append(e)
  660. exprs = temp
  661. del temp
  662. if optimizations is None:
  663. optimizations = []
  664. elif optimizations == 'basic':
  665. optimizations = basic_optimizations
  666. # Preprocess the expressions to give us better optimization opportunities.
  667. reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]
  668. if symbols is None:
  669. symbols = numbered_symbols(cls=Symbol)
  670. else:
  671. # In case we get passed an iterable with an __iter__ method instead of
  672. # an actual iterator.
  673. symbols = iter(symbols)
  674. # Find other optimization opportunities.
  675. opt_subs = opt_cse(reduced_exprs, order)
  676. # Main CSE algorithm.
  677. replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
  678. order, ignore)
  679. # Postprocess the expressions to return the expressions to canonical form.
  680. exprs = copy
  681. for i, (sym, subtree) in enumerate(replacements):
  682. subtree = postprocess_for_cse(subtree, optimizations)
  683. replacements[i] = (sym, subtree)
  684. reduced_exprs = [postprocess_for_cse(e, optimizations)
  685. for e in reduced_exprs]
  686. # Get the matrices back
  687. for i, e in enumerate(exprs):
  688. if isinstance(e, (Matrix, ImmutableMatrix)):
  689. reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
  690. if isinstance(e, ImmutableMatrix):
  691. reduced_exprs[i] = reduced_exprs[i].as_immutable()
  692. elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
  693. m = SparseMatrix(e.rows, e.cols, {})
  694. for k, v in reduced_exprs[i]:
  695. m[k] = v
  696. if isinstance(e, ImmutableSparseMatrix):
  697. m = m.as_immutable()
  698. reduced_exprs[i] = m
  699. if postprocess is None:
  700. return replacements, reduced_exprs
  701. return postprocess(replacements, reduced_exprs)
  702. def _cse_homogeneous(exprs, **kwargs):
  703. """
  704. Same as ``cse`` but the ``reduced_exprs`` are returned
  705. with the same type as ``exprs`` or a sympified version of the same.
  706. Parameters
  707. ==========
  708. exprs : an Expr, iterable of Expr or dictionary with Expr values
  709. the expressions in which repeated subexpressions will be identified
  710. kwargs : additional arguments for the ``cse`` function
  711. Returns
  712. =======
  713. replacements : list of (Symbol, expression) pairs
  714. All of the common subexpressions that were replaced. Subexpressions
  715. earlier in this list might show up in subexpressions later in this
  716. list.
  717. reduced_exprs : list of SymPy expressions
  718. The reduced expressions with all of the replacements above.
  719. Examples
  720. ========
  721. >>> from sympy.simplify.cse_main import cse
  722. >>> from sympy import cos, Tuple, Matrix
  723. >>> from sympy.abc import x
  724. >>> output = lambda x: type(cse(x, list=False)[1])
  725. >>> output(1)
  726. <class 'sympy.core.numbers.One'>
  727. >>> output('cos(x)')
  728. <class 'str'>
  729. >>> output(cos(x))
  730. cos
  731. >>> output(Tuple(1, x))
  732. <class 'sympy.core.containers.Tuple'>
  733. >>> output(Matrix([[1,0], [0,1]]))
  734. <class 'sympy.matrices.dense.MutableDenseMatrix'>
  735. >>> output([1, x])
  736. <class 'list'>
  737. >>> output((1, x))
  738. <class 'tuple'>
  739. >>> output({1, x})
  740. <class 'set'>
  741. """
  742. if isinstance(exprs, str):
  743. replacements, reduced_exprs = _cse_homogeneous(
  744. sympify(exprs), **kwargs)
  745. return replacements, repr(reduced_exprs)
  746. if isinstance(exprs, (list, tuple, set)):
  747. replacements, reduced_exprs = cse(exprs, **kwargs)
  748. return replacements, type(exprs)(reduced_exprs)
  749. if isinstance(exprs, dict):
  750. keys = list(exprs.keys()) # In order to guarantee the order of the elements.
  751. replacements, values = cse([exprs[k] for k in keys], **kwargs)
  752. reduced_exprs = dict(zip(keys, values))
  753. return replacements, reduced_exprs
  754. try:
  755. replacements, (reduced_exprs,) = cse(exprs, **kwargs)
  756. except TypeError: # For example 'mpf' objects
  757. return [], exprs
  758. else:
  759. return replacements, reduced_exprs