manualintegrate.py 74 KB


  1. """Integration method that emulates by-hand techniques.
  2. This module also provides functionality to get the steps used to evaluate a
  3. particular integral, in the ``integral_steps`` function. This will return
  4. nested ``Rule`` s representing the integration rules used.
  5. Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g.
  6. ``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule``
  7. for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the
  8. integration result.
  9. The ``manualintegrate`` function computes the integral by calling ``eval``
  10. on the rule returned by ``integral_steps``.
  11. The integrator can be extended with new heuristics and evaluation
  12. techniques. To do so, extend the ``Rule`` class, implement ``eval`` method,
  13. then write a function that accepts an ``IntegralInfo`` object and returns
  14. either a ``Rule`` instance or ``None``. If the new technique requires a new
  15. match, add the key and call to the antiderivative function to integral_steps.
  16. To enable simple substitutions, add the match to find_substitutions.
  17. """
  18. from __future__ import annotations
  19. from typing import NamedTuple, Type, Callable, Sequence
  20. from abc import ABC, abstractmethod
  21. from dataclasses import dataclass
  22. from collections import defaultdict
  23. from collections.abc import Mapping
  24. from sympy.core.add import Add
  25. from sympy.core.cache import cacheit
  26. from sympy.core.containers import Dict
  27. from sympy.core.expr import Expr
  28. from sympy.core.function import Derivative
  29. from sympy.core.logic import fuzzy_not
  30. from sympy.core.mul import Mul
  31. from sympy.core.numbers import Integer, Number, E
  32. from sympy.core.power import Pow
  33. from sympy.core.relational import Eq, Ne, Boolean
  34. from sympy.core.singleton import S
  35. from sympy.core.symbol import Dummy, Symbol, Wild
  36. from sympy.functions.elementary.complexes import Abs
  37. from sympy.functions.elementary.exponential import exp, log
  38. from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch,
  39. cosh, coth, sech, sinh, tanh, asinh)
  40. from sympy.functions.elementary.miscellaneous import sqrt
  41. from sympy.functions.elementary.piecewise import Piecewise
  42. from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
  43. cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec)
  44. from sympy.functions.special.delta_functions import Heaviside, DiracDelta
  45. from sympy.functions.special.error_functions import (erf, erfi, fresnelc,
  46. fresnels, Ci, Chi, Si, Shi, Ei, li)
  47. from sympy.functions.special.gamma_functions import uppergamma
  48. from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f
  49. from sympy.functions.special.polynomials import (chebyshevt, chebyshevu,
  50. legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi,
  51. OrthogonalPolynomial)
  52. from sympy.functions.special.zeta_functions import polylog
  53. from .integrals import Integral
  54. from sympy.logic.boolalg import And
  55. from sympy.ntheory.factor_ import primefactors
  56. from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly
  57. from sympy.simplify.radsimp import fraction
  58. from sympy.simplify.simplify import simplify
  59. from sympy.solvers.solvers import solve
  60. from sympy.strategies.core import switch, do_one, null_safe, condition
  61. from sympy.utilities.iterables import iterable
  62. from sympy.utilities.misc import debug
  63. @dataclass
  64. class Rule(ABC):
  65. integrand: Expr
  66. variable: Symbol
  67. @abstractmethod
  68. def eval(self) -> Expr:
  69. pass
  70. @abstractmethod
  71. def contains_dont_know(self) -> bool:
  72. pass
  73. @dataclass
  74. class AtomicRule(Rule, ABC):
  75. """A simple rule that does not depend on other rules"""
  76. def contains_dont_know(self) -> bool:
  77. return False
  78. @dataclass
  79. class ConstantRule(AtomicRule):
  80. """integrate(a, x) -> a*x"""
  81. def eval(self) -> Expr:
  82. return self.integrand * self.variable
  83. @dataclass
  84. class ConstantTimesRule(Rule):
  85. """integrate(a*f(x), x) -> a*integrate(f(x), x)"""
  86. constant: Expr
  87. other: Expr
  88. substep: Rule
  89. def eval(self) -> Expr:
  90. return self.constant * self.substep.eval()
  91. def contains_dont_know(self) -> bool:
  92. return self.substep.contains_dont_know()
  93. @dataclass
  94. class PowerRule(AtomicRule):
  95. """integrate(x**a, x)"""
  96. base: Expr
  97. exp: Expr
  98. def eval(self) -> Expr:
  99. return Piecewise(
  100. ((self.base**(self.exp + 1))/(self.exp + 1), Ne(self.exp, -1)),
  101. (log(self.base), True),
  102. )
  103. @dataclass
  104. class NestedPowRule(AtomicRule):
  105. """integrate((x**a)**b, x)"""
  106. base: Expr
  107. exp: Expr
  108. def eval(self) -> Expr:
  109. m = self.base * self.integrand
  110. return Piecewise((m / (self.exp + 1), Ne(self.exp, -1)),
  111. (m * log(self.base), True))
  112. @dataclass
  113. class AddRule(Rule):
  114. """integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)"""
  115. substeps: list[Rule]
  116. def eval(self) -> Expr:
  117. return Add(*(substep.eval() for substep in self.substeps))
  118. def contains_dont_know(self) -> bool:
  119. return any(substep.contains_dont_know() for substep in self.substeps)
  120. @dataclass
  121. class URule(Rule):
  122. """integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)"""
  123. u_var: Symbol
  124. u_func: Expr
  125. substep: Rule
  126. def eval(self) -> Expr:
  127. result = self.substep.eval()
  128. if self.u_func.is_Pow:
  129. base, exp_ = self.u_func.as_base_exp()
  130. if exp_ == -1:
  131. # avoid needless -log(1/x) from substitution
  132. result = result.subs(log(self.u_var), -log(base))
  133. return result.subs(self.u_var, self.u_func)
  134. def contains_dont_know(self) -> bool:
  135. return self.substep.contains_dont_know()
  136. @dataclass
  137. class PartsRule(Rule):
  138. """integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)"""
  139. u: Symbol
  140. dv: Expr
  141. v_step: Rule
  142. second_step: Rule | None # None when is a substep of CyclicPartsRule
  143. def eval(self) -> Expr:
  144. assert self.second_step is not None
  145. v = self.v_step.eval()
  146. return self.u * v - self.second_step.eval()
  147. def contains_dont_know(self) -> bool:
  148. return self.v_step.contains_dont_know() or (
  149. self.second_step is not None and self.second_step.contains_dont_know())
  150. @dataclass
  151. class CyclicPartsRule(Rule):
  152. """Apply PartsRule multiple times to integrate exp(x)*sin(x)"""
  153. parts_rules: list[PartsRule]
  154. coefficient: Expr
  155. def eval(self) -> Expr:
  156. result = []
  157. sign = 1
  158. for rule in self.parts_rules:
  159. result.append(sign * rule.u * rule.v_step.eval())
  160. sign *= -1
  161. return Add(*result) / (1 - self.coefficient)
  162. def contains_dont_know(self) -> bool:
  163. return any(substep.contains_dont_know() for substep in self.parts_rules)
  164. @dataclass
  165. class TrigRule(AtomicRule, ABC):
  166. pass
  167. @dataclass
  168. class SinRule(TrigRule):
  169. """integrate(sin(x), x) -> -cos(x)"""
  170. def eval(self) -> Expr:
  171. return -cos(self.variable)
  172. @dataclass
  173. class CosRule(TrigRule):
  174. """integrate(cos(x), x) -> sin(x)"""
  175. def eval(self) -> Expr:
  176. return sin(self.variable)
  177. @dataclass
  178. class SecTanRule(TrigRule):
  179. """integrate(sec(x)*tan(x), x) -> sec(x)"""
  180. def eval(self) -> Expr:
  181. return sec(self.variable)
  182. @dataclass
  183. class CscCotRule(TrigRule):
  184. """integrate(csc(x)*cot(x), x) -> -csc(x)"""
  185. def eval(self) -> Expr:
  186. return -csc(self.variable)
  187. @dataclass
  188. class Sec2Rule(TrigRule):
  189. """integrate(sec(x)**2, x) -> tan(x)"""
  190. def eval(self) -> Expr:
  191. return tan(self.variable)
  192. @dataclass
  193. class Csc2Rule(TrigRule):
  194. """integrate(csc(x)**2, x) -> -cot(x)"""
  195. def eval(self) -> Expr:
  196. return -cot(self.variable)
  197. @dataclass
  198. class HyperbolicRule(AtomicRule, ABC):
  199. pass
  200. @dataclass
  201. class SinhRule(HyperbolicRule):
  202. """integrate(sinh(x), x) -> cosh(x)"""
  203. def eval(self) -> Expr:
  204. return cosh(self.variable)
  205. @dataclass
  206. class CoshRule(HyperbolicRule):
  207. """integrate(cosh(x), x) -> sinh(x)"""
  208. def eval(self):
  209. return sinh(self.variable)
  210. @dataclass
  211. class ExpRule(AtomicRule):
  212. """integrate(a**x, x) -> a**x/ln(a)"""
  213. base: Expr
  214. exp: Expr
  215. def eval(self) -> Expr:
  216. return self.integrand / log(self.base)
  217. @dataclass
  218. class ReciprocalRule(AtomicRule):
  219. """integrate(1/x, x) -> ln(x)"""
  220. base: Expr
  221. def eval(self) -> Expr:
  222. return log(self.base)
  223. @dataclass
  224. class ArcsinRule(AtomicRule):
  225. """integrate(1/sqrt(1-x**2), x) -> asin(x)"""
  226. def eval(self) -> Expr:
  227. return asin(self.variable)
  228. @dataclass
  229. class ArcsinhRule(AtomicRule):
  230. """integrate(1/sqrt(1+x**2), x) -> asin(x)"""
  231. def eval(self) -> Expr:
  232. return asinh(self.variable)
  233. @dataclass
  234. class ReciprocalSqrtQuadraticRule(AtomicRule):
  235. """integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)"""
  236. a: Expr
  237. b: Expr
  238. c: Expr
  239. def eval(self) -> Expr:
  240. a, b, c, x = self.a, self.b, self.c, self.variable
  241. return log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)
  242. @dataclass
  243. class SqrtQuadraticDenomRule(AtomicRule):
  244. """integrate(poly(x)/sqrt(a+b*x+c*x**2), x)"""
  245. a: Expr
  246. b: Expr
  247. c: Expr
  248. coeffs: list[Expr]
  249. def eval(self) -> Expr:
  250. a, b, c, coeffs, x = self.a, self.b, self.c, self.coeffs.copy(), self.variable
  251. # Integrate poly/sqrt(a+b*x+c*x**2) using recursion.
  252. # coeffs are coefficients of the polynomial.
  253. # Let I_n = x**n/sqrt(a+b*x+c*x**2), then
  254. # I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2}
  255. # where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c)
  256. # See https://github.com/sympy/sympy/pull/23608 for proof.
  257. result_coeffs = []
  258. coeffs = coeffs.copy()
  259. for i in range(len(coeffs)-2):
  260. n = len(coeffs)-1-i
  261. coeff = coeffs[i]/(c*n)
  262. result_coeffs.append(coeff)
  263. coeffs[i+1] -= (2*n-1)*b/2*coeff
  264. coeffs[i+2] -= (n-1)*a*coeff
  265. d, e = coeffs[-1], coeffs[-2]
  266. s = sqrt(a+b*x+c*x**2)
  267. constant = d-b*e/(2*c)
  268. if constant == 0:
  269. I0 = 0
  270. else:
  271. step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False)
  272. I0 = constant*step.eval()
  273. return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i)
  274. for i in range(len(result_coeffs))), e/c)*s + I0
  275. @dataclass
  276. class SqrtQuadraticRule(AtomicRule):
  277. """integrate(sqrt(a+b*x+c*x**2), x)"""
  278. a: Expr
  279. b: Expr
  280. c: Expr
  281. def eval(self) -> Expr:
  282. step = sqrt_quadratic_rule(IntegralInfo(self.integrand, self.variable), degenerate=False)
  283. return step.eval()
  284. @dataclass
  285. class AlternativeRule(Rule):
  286. """Multiple ways to do integration."""
  287. alternatives: list[Rule]
  288. def eval(self) -> Expr:
  289. return self.alternatives[0].eval()
  290. def contains_dont_know(self) -> bool:
  291. return any(substep.contains_dont_know() for substep in self.alternatives)
  292. @dataclass
  293. class DontKnowRule(Rule):
  294. """Leave the integral as is."""
  295. def eval(self) -> Expr:
  296. return Integral(self.integrand, self.variable)
  297. def contains_dont_know(self) -> bool:
  298. return True
  299. @dataclass
  300. class DerivativeRule(AtomicRule):
  301. """integrate(f'(x), x) -> f(x)"""
  302. def eval(self) -> Expr:
  303. assert isinstance(self.integrand, Derivative)
  304. variable_count = list(self.integrand.variable_count)
  305. for i, (var, count) in enumerate(variable_count):
  306. if var == self.variable:
  307. variable_count[i] = (var, count - 1)
  308. break
  309. return Derivative(self.integrand.expr, *variable_count)
  310. @dataclass
  311. class RewriteRule(Rule):
  312. """Rewrite integrand to another form that is easier to handle."""
  313. rewritten: Expr
  314. substep: Rule
  315. def eval(self) -> Expr:
  316. return self.substep.eval()
  317. def contains_dont_know(self) -> bool:
  318. return self.substep.contains_dont_know()
  319. @dataclass
  320. class CompleteSquareRule(RewriteRule):
  321. """Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2"""
  322. pass
  323. @dataclass
  324. class PiecewiseRule(Rule):
  325. subfunctions: Sequence[tuple[Rule, bool | Boolean]]
  326. def eval(self) -> Expr:
  327. return Piecewise(*[(substep.eval(), cond)
  328. for substep, cond in self.subfunctions])
  329. def contains_dont_know(self) -> bool:
  330. return any(substep.contains_dont_know() for substep, _ in self.subfunctions)
  331. @dataclass
  332. class HeavisideRule(Rule):
  333. harg: Expr
  334. ibnd: Expr
  335. substep: Rule
  336. def eval(self) -> Expr:
  337. # If we are integrating over x and the integrand has the form
  338. # Heaviside(m*x+b)*g(x) == Heaviside(harg)*g(symbol)
  339. # then there needs to be continuity at -b/m == ibnd,
  340. # so we subtract the appropriate term.
  341. result = self.substep.eval()
  342. return Heaviside(self.harg) * (result - result.subs(self.variable, self.ibnd))
  343. def contains_dont_know(self) -> bool:
  344. return self.substep.contains_dont_know()
  345. @dataclass
  346. class DiracDeltaRule(AtomicRule):
  347. n: Expr
  348. a: Expr
  349. b: Expr
  350. def eval(self) -> Expr:
  351. n, a, b, x = self.n, self.a, self.b, self.variable
  352. if n == 0:
  353. return Heaviside(a+b*x)/b
  354. return DiracDelta(a+b*x, n-1)/b
  355. @dataclass
  356. class TrigSubstitutionRule(Rule):
  357. theta: Expr
  358. func: Expr
  359. rewritten: Expr
  360. substep: Rule
  361. restriction: bool | Boolean
  362. def eval(self) -> Expr:
  363. theta, func, x = self.theta, self.func, self.variable
  364. func = func.subs(sec(theta), 1/cos(theta))
  365. func = func.subs(csc(theta), 1/sin(theta))
  366. func = func.subs(cot(theta), 1/tan(theta))
  367. trig_function = list(func.find(TrigonometricFunction))
  368. assert len(trig_function) == 1
  369. trig_function = trig_function[0]
  370. relation = solve(x - func, trig_function)
  371. assert len(relation) == 1
  372. numer, denom = fraction(relation[0])
  373. if isinstance(trig_function, sin):
  374. opposite = numer
  375. hypotenuse = denom
  376. adjacent = sqrt(denom**2 - numer**2)
  377. inverse = asin(relation[0])
  378. elif isinstance(trig_function, cos):
  379. adjacent = numer
  380. hypotenuse = denom
  381. opposite = sqrt(denom**2 - numer**2)
  382. inverse = acos(relation[0])
  383. else: # tan
  384. opposite = numer
  385. adjacent = denom
  386. hypotenuse = sqrt(denom**2 + numer**2)
  387. inverse = atan(relation[0])
  388. substitution = [
  389. (sin(theta), opposite/hypotenuse),
  390. (cos(theta), adjacent/hypotenuse),
  391. (tan(theta), opposite/adjacent),
  392. (theta, inverse)
  393. ]
  394. return Piecewise(
  395. (self.substep.eval().subs(substitution).trigsimp(), self.restriction)
  396. )
  397. def contains_dont_know(self) -> bool:
  398. return self.substep.contains_dont_know()
  399. @dataclass
  400. class ArctanRule(AtomicRule):
  401. """integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))"""
  402. a: Expr
  403. b: Expr
  404. c: Expr
  405. def eval(self) -> Expr:
  406. a, b, c, x = self.a, self.b, self.c, self.variable
  407. return a/b / sqrt(c/b) * atan(x/sqrt(c/b))
  408. @dataclass
  409. class OrthogonalPolyRule(AtomicRule, ABC):
  410. n: Expr
  411. @dataclass
  412. class JacobiRule(OrthogonalPolyRule):
  413. a: Expr
  414. b: Expr
  415. def eval(self) -> Expr:
  416. n, a, b, x = self.n, self.a, self.b, self.variable
  417. return Piecewise(
  418. (2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b), Ne(n + a + b, 0)),
  419. (x, Eq(n, 0)),
  420. ((a + b + 2)*x**2/4 + (a - b)*x/2, Eq(n, 1)))
  421. @dataclass
  422. class GegenbauerRule(OrthogonalPolyRule):
  423. a: Expr
  424. def eval(self) -> Expr:
  425. n, a, x = self.n, self.a, self.variable
  426. return Piecewise(
  427. (gegenbauer(n + 1, a - 1, x)/(2*(a - 1)), Ne(a, 1)),
  428. (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
  429. (S.Zero, True))
  430. @dataclass
  431. class ChebyshevTRule(OrthogonalPolyRule):
  432. def eval(self) -> Expr:
  433. n, x = self.n, self.variable
  434. return Piecewise(
  435. ((chebyshevt(n + 1, x)/(n + 1) -
  436. chebyshevt(n - 1, x)/(n - 1))/2, Ne(Abs(n), 1)),
  437. (x**2/2, True))
  438. @dataclass
  439. class ChebyshevURule(OrthogonalPolyRule):
  440. def eval(self) -> Expr:
  441. n, x = self.n, self.variable
  442. return Piecewise(
  443. (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
  444. (S.Zero, True))
  445. @dataclass
  446. class LegendreRule(OrthogonalPolyRule):
  447. def eval(self) -> Expr:
  448. n, x = self.n, self.variable
  449. return(legendre(n + 1, x) - legendre(n - 1, x))/(2*n + 1)
  450. @dataclass
  451. class HermiteRule(OrthogonalPolyRule):
  452. def eval(self) -> Expr:
  453. n, x = self.n, self.variable
  454. return hermite(n + 1, x)/(2*(n + 1))
  455. @dataclass
  456. class LaguerreRule(OrthogonalPolyRule):
  457. def eval(self) -> Expr:
  458. n, x = self.n, self.variable
  459. return laguerre(n, x) - laguerre(n + 1, x)
  460. @dataclass
  461. class AssocLaguerreRule(OrthogonalPolyRule):
  462. a: Expr
  463. def eval(self) -> Expr:
  464. return -assoc_laguerre(self.n + 1, self.a - 1, self.variable)
  465. @dataclass
  466. class IRule(AtomicRule, ABC):
  467. a: Expr
  468. b: Expr
  469. @dataclass
  470. class CiRule(IRule):
  471. def eval(self) -> Expr:
  472. a, b, x = self.a, self.b, self.variable
  473. return cos(b)*Ci(a*x) - sin(b)*Si(a*x)
  474. @dataclass
  475. class ChiRule(IRule):
  476. def eval(self) -> Expr:
  477. a, b, x = self.a, self.b, self.variable
  478. return cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x)
  479. @dataclass
  480. class EiRule(IRule):
  481. def eval(self) -> Expr:
  482. a, b, x = self.a, self.b, self.variable
  483. return exp(b)*Ei(a*x)
  484. @dataclass
  485. class SiRule(IRule):
  486. def eval(self) -> Expr:
  487. a, b, x = self.a, self.b, self.variable
  488. return sin(b)*Ci(a*x) + cos(b)*Si(a*x)
  489. @dataclass
  490. class ShiRule(IRule):
  491. def eval(self) -> Expr:
  492. a, b, x = self.a, self.b, self.variable
  493. return sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x)
  494. @dataclass
  495. class LiRule(IRule):
  496. def eval(self) -> Expr:
  497. a, b, x = self.a, self.b, self.variable
  498. return li(a*x + b)/a
  499. @dataclass
  500. class ErfRule(AtomicRule):
  501. a: Expr
  502. b: Expr
  503. c: Expr
  504. def eval(self) -> Expr:
  505. a, b, c, x = self.a, self.b, self.c, self.variable
  506. if a.is_extended_real:
  507. return Piecewise(
  508. (sqrt(S.Pi/(-a))/2 * exp(c - b**2/(4*a)) *
  509. erf((-2*a*x - b)/(2*sqrt(-a))), a < 0),
  510. (sqrt(S.Pi/a)/2 * exp(c - b**2/(4*a)) *
  511. erfi((2*a*x + b)/(2*sqrt(a))), True))
  512. return sqrt(S.Pi/a)/2 * exp(c - b**2/(4*a)) * \
  513. erfi((2*a*x + b)/(2*sqrt(a)))
  514. @dataclass
  515. class FresnelCRule(AtomicRule):
  516. a: Expr
  517. b: Expr
  518. c: Expr
  519. def eval(self) -> Expr:
  520. a, b, c, x = self.a, self.b, self.c, self.variable
  521. return sqrt(S.Pi/(2*a)) * (
  522. cos(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)) +
  523. sin(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)))
  524. @dataclass
  525. class FresnelSRule(AtomicRule):
  526. a: Expr
  527. b: Expr
  528. c: Expr
  529. def eval(self) -> Expr:
  530. a, b, c, x = self.a, self.b, self.c, self.variable
  531. return sqrt(S.Pi/(2*a)) * (
  532. cos(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)) -
  533. sin(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)))
  534. @dataclass
  535. class PolylogRule(AtomicRule):
  536. a: Expr
  537. b: Expr
  538. def eval(self) -> Expr:
  539. return polylog(self.b + 1, self.a * self.variable)
  540. @dataclass
  541. class UpperGammaRule(AtomicRule):
  542. a: Expr
  543. e: Expr
  544. def eval(self) -> Expr:
  545. a, e, x = self.a, self.e, self.variable
  546. return x**e * (-a*x)**(-e) * uppergamma(e + 1, -a*x)/a
  547. @dataclass
  548. class EllipticFRule(AtomicRule):
  549. a: Expr
  550. d: Expr
  551. def eval(self) -> Expr:
  552. return elliptic_f(self.variable, self.d/self.a)/sqrt(self.a)
  553. @dataclass
  554. class EllipticERule(AtomicRule):
  555. a: Expr
  556. d: Expr
  557. def eval(self) -> Expr:
  558. return elliptic_e(self.variable, self.d/self.a)*sqrt(self.a)
class IntegralInfo(NamedTuple):
    """An integrand paired with its variable of integration.

    Rule functions in this module receive it and unpack it as
    ``integrand, symbol = integral``.
    """
    integrand: Expr  # expression to integrate
    symbol: Symbol  # variable of integration
  562. def manual_diff(f, symbol):
  563. """Derivative of f in form expected by find_substitutions
  564. SymPy's derivatives for some trig functions (like cot) are not in a form
  565. that works well with finding substitutions; this replaces the
  566. derivatives for those particular forms with something that works better.
  567. """
  568. if f.args:
  569. arg = f.args[0]
  570. if isinstance(f, tan):
  571. return arg.diff(symbol) * sec(arg)**2
  572. elif isinstance(f, cot):
  573. return -arg.diff(symbol) * csc(arg)**2
  574. elif isinstance(f, sec):
  575. return arg.diff(symbol) * sec(arg) * tan(arg)
  576. elif isinstance(f, csc):
  577. return -arg.diff(symbol) * csc(arg) * cot(arg)
  578. elif isinstance(f, Add):
  579. return sum([manual_diff(arg, symbol) for arg in f.args])
  580. elif isinstance(f, Mul):
  581. if len(f.args) == 2 and isinstance(f.args[0], Number):
  582. return f.args[0] * manual_diff(f.args[1], symbol)
  583. return f.diff(symbol)
  584. def manual_subs(expr, *args):
  585. """
  586. A wrapper for `expr.subs(*args)` with additional logic for substitution
  587. of invertible functions.
  588. """
  589. if len(args) == 1:
  590. sequence = args[0]
  591. if isinstance(sequence, (Dict, Mapping)):
  592. sequence = sequence.items()
  593. elif not iterable(sequence):
  594. raise ValueError("Expected an iterable of (old, new) pairs")
  595. elif len(args) == 2:
  596. sequence = [args]
  597. else:
  598. raise ValueError("subs accepts either 1 or 2 arguments")
  599. new_subs = []
  600. for old, new in sequence:
  601. if isinstance(old, log):
  602. # If log(x) = y, then exp(a*log(x)) = exp(a*y)
  603. # that is, x**a = exp(a*y). Replace nontrivial powers of x
  604. # before subs turns them into `exp(y)**a`, but
  605. # do not replace x itself yet, to avoid `log(exp(y))`.
  606. x0 = old.args[0]
  607. expr = expr.replace(lambda x: x.is_Pow and x.base == x0,
  608. lambda x: exp(x.exp*new))
  609. new_subs.append((x0, exp(new)))
  610. return expr.subs(list(sequence) + new_subs)
# Method based on that on SIN, described in "Symbolic Integration: The
# Stormy Decade"

# Inverse trigonometric classes whose argument possible_subterms (inside
# find_substitutions) offers as a substitution candidate.
inverse_trig_functions = (atan, asin, acos, acot, acsc, asec)


def find_substitutions(integrand, symbol, u_var):
    """Find u-substitutions that turn ``integrand`` into a function of ``u_var``.

    Candidate subexpressions ``u`` of the integrand are tried in turn. For
    each one, the integrand is divided by ``du/d(symbol)`` (from
    ``manual_diff``) and ``u`` is replaced by ``u_var``. The candidate is
    kept when ``symbol`` no longer appears. For rational integrands, it is
    also kept only if the degree does not grow.

    Returns:
        A list of distinct ``(u, constant, new_integrand)`` triples, where
        ``constant`` is the factor of the substituted integrand free of
        ``u_var`` and ``new_integrand`` is the factor depending on it.
        Substitutions that merely rename ``symbol`` to ``u_var`` are skipped.
    """
    results = []

    def test_subterm(u, u_diff):
        # Returns False when u is unusable, else the pair
        # (factor free of u_var, factor depending on u_var).
        if u_diff == 0:
            return False
        substituted = integrand / u_diff
        debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var))
        substituted = manual_subs(substituted, u, u_var).cancel()

        if substituted.has_free(symbol):
            return False
        # avoid increasing the degree of a rational function
        if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var):
            deg_before = max([degree(t, symbol) for t in integrand.as_numer_denom()])
            deg_after = max([degree(t, u_var) for t in substituted.as_numer_denom()])
            if deg_after > deg_before:
                return False
        return substituted.as_independent(u_var, as_Add=False)

    def exp_subterms(term: Expr):
        # Exponentials of n*symbol (n an integer) are merged into a single
        # candidate exp(gcd(n, ...)*symbol); every other exponential whose
        # argument involves symbol is offered unchanged.
        linear_coeffs = []
        terms = []
        n = Wild('n', properties=[lambda n: n.is_Integer])
        for exp_ in term.find(exp):
            arg = exp_.args[0]
            if symbol not in arg.free_symbols:
                continue
            match = arg.match(n*symbol)
            if match:
                linear_coeffs.append(match[n])
            else:
                terms.append(exp_)
        if linear_coeffs:
            terms.append(exp(gcd_list(linear_coeffs)*symbol))
        return terms

    def possible_subterms(term):
        # Recursively collect candidate u's: function arguments, factors of
        # products, terms of sums, and the base/exponent of powers.
        if isinstance(term, (TrigonometricFunction, HyperbolicFunction,
                             *inverse_trig_functions,
                             exp, log, Heaviside)):
            return [term.args[0]]
        elif isinstance(term, (chebyshevt, chebyshevu,
                        legendre, hermite, laguerre)):
            # The polynomial's variable is args[1] (args[0] is the degree).
            return [term.args[1]]
        elif isinstance(term, (gegenbauer, assoc_laguerre)):
            return [term.args[2]]
        elif isinstance(term, jacobi):
            return [term.args[3]]
        elif isinstance(term, Mul):
            r = []
            for u in term.args:
                r.append(u)
                r.extend(possible_subterms(u))
            return r
        elif isinstance(term, Pow):
            r = [arg for arg in term.args if arg.has(symbol)]
            if term.exp.is_Integer:
                # Also offer base**d for each prime d properly dividing the
                # integer exponent.
                r.extend([term.base**d for d in primefactors(term.exp)
                    if 1 < d < abs(term.args[1])])
                if term.base.is_Add:
                    r.extend([t for t in possible_subterms(term.base)
                        if t.is_Pow])
            return r
        elif isinstance(term, Add):
            r = []
            for arg in term.args:
                r.append(arg)
                r.extend(possible_subterms(arg))
            return r
        return []

    # dict.fromkeys removes duplicate candidates while keeping their order.
    for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))):
        if u == symbol:
            continue
        u_diff = manual_diff(u, symbol)
        new_integrand = test_subterm(u, u_diff)
        if new_integrand is not False:
            constant, new_integrand = new_integrand
            # Skip substitutions that only rename symbol to u_var.
            if new_integrand == integrand.subs(symbol, u_var):
                continue
            substitution = (u, constant, new_integrand)
            if substitution not in results:
                results.append(substitution)

    return results
  694. def rewriter(condition, rewrite):
  695. """Strategy that rewrites an integrand."""
  696. def _rewriter(integral):
  697. integrand, symbol = integral
  698. debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol))
  699. if condition(*integral):
  700. rewritten = rewrite(*integral)
  701. if rewritten != integrand:
  702. substep = integral_steps(rewritten, symbol)
  703. if not isinstance(substep, DontKnowRule) and substep:
  704. return RewriteRule(integrand, symbol, rewritten, substep)
  705. return _rewriter
  706. def proxy_rewriter(condition, rewrite):
  707. """Strategy that rewrites an integrand based on some other criteria."""
  708. def _proxy_rewriter(criteria):
  709. criteria, integral = criteria
  710. integrand, symbol = integral
  711. debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, criteria))
  712. args = criteria + list(integral)
  713. if condition(*args):
  714. rewritten = rewrite(*args)
  715. if rewritten != integrand:
  716. return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
  717. return _proxy_rewriter
  718. def multiplexer(conditions):
  719. """Apply the rule that matches the condition, else None"""
  720. def multiplexer_rl(expr):
  721. for key, rule in conditions.items():
  722. if key(expr):
  723. return rule(expr)
  724. return multiplexer_rl
  725. def alternatives(*rules):
  726. """Strategy that makes an AlternativeRule out of multiple possible results."""
  727. def _alternatives(integral):
  728. alts = []
  729. count = 0
  730. debug("List of Alternative Rules")
  731. for rule in rules:
  732. count = count + 1
  733. debug("Rule {}: {}".format(count, rule))
  734. result = rule(integral)
  735. if (result and not isinstance(result, DontKnowRule) and
  736. result != integral and result not in alts):
  737. alts.append(result)
  738. if len(alts) == 1:
  739. return alts[0]
  740. elif alts:
  741. doable = [rule for rule in alts if not rule.contains_dont_know()]
  742. if doable:
  743. return AlternativeRule(*integral, doable)
  744. else:
  745. return AlternativeRule(*integral, alts)
  746. return _alternatives
  747. def constant_rule(integral):
  748. return ConstantRule(*integral)
  749. def power_rule(integral):
  750. integrand, symbol = integral
  751. base, expt = integrand.as_base_exp()
  752. if symbol not in expt.free_symbols and isinstance(base, Symbol):
  753. if simplify(expt + 1) == 0:
  754. return ReciprocalRule(integrand, symbol, base)
  755. return PowerRule(integrand, symbol, base, expt)
  756. elif symbol not in base.free_symbols and isinstance(expt, Symbol):
  757. rule = ExpRule(integrand, symbol, base, expt)
  758. if fuzzy_not(log(base).is_zero):
  759. return rule
  760. elif log(base).is_zero:
  761. return ConstantRule(1, symbol)
  762. return PiecewiseRule(integrand, symbol, [
  763. (rule, Ne(log(base), 0)),
  764. (ConstantRule(1, symbol), True)
  765. ])
  766. def exp_rule(integral):
  767. integrand, symbol = integral
  768. if isinstance(integrand.args[0], Symbol):
  769. return ExpRule(integrand, symbol, E, integrand.args[0])
  770. def orthogonal_poly_rule(integral):
  771. orthogonal_poly_classes = {
  772. jacobi: JacobiRule,
  773. gegenbauer: GegenbauerRule,
  774. chebyshevt: ChebyshevTRule,
  775. chebyshevu: ChebyshevURule,
  776. legendre: LegendreRule,
  777. hermite: HermiteRule,
  778. laguerre: LaguerreRule,
  779. assoc_laguerre: AssocLaguerreRule
  780. }
  781. orthogonal_poly_var_index = {
  782. jacobi: 3,
  783. gegenbauer: 2,
  784. assoc_laguerre: 2
  785. }
  786. integrand, symbol = integral
  787. for klass in orthogonal_poly_classes:
  788. if isinstance(integrand, klass):
  789. var_index = orthogonal_poly_var_index.get(klass, 1)
  790. if (integrand.args[var_index] is symbol and not
  791. any(v.has(symbol) for v in integrand.args[:var_index])):
  792. return orthogonal_poly_classes[klass](integrand, symbol, *integrand.args[:var_index])
# Lazily built table of (expected integrand class, pattern, optional
# constraint on the matched wild values, rule class). It is filled on the
# first call to special_function_rule.
_special_function_patterns: list[tuple[Type, Expr, Callable | None, Type[Rule]]] = []
# The Wild objects (a, b, c, d, e) shared by all patterns. Matched values
# are passed to the rule constructors in this order.
_wilds = []
# Placeholder variable the patterns are written in; integrands are
# substituted onto it before matching.
_symbol = Dummy('x')


def special_function_rule(integral):
    """Match integrands whose antiderivatives need special functions.

    Patterns are tried in table order: Ei, Ci, Chi, Si, Shi, li, erf/erfi,
    Fresnel S/C, upper incomplete gamma, polylog and incomplete elliptic
    F/E. The first match satisfying its constraint gives the rule that is
    returned. If nothing matches, the function returns None.
    """
    integrand, symbol = integral
    if not _special_function_patterns:
        a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        b = Wild('b', exclude=[_symbol])
        c = Wild('c', exclude=[_symbol])
        d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        e = Wild('e', exclude=[_symbol], properties=[
            lambda x: not (x.is_nonnegative and x.is_integer)])
        _wilds.extend((a, b, c, d, e))
        # patterns consist of a SymPy class, a wildcard expr, an optional
        # condition coded as a lambda (when Wild properties are not enough),
        # followed by an applicable rule
        linear_pattern = a*_symbol + b
        quadratic_pattern = a*_symbol**2 + b*_symbol + c
        _special_function_patterns.extend((
            (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule),
            (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule),
            (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule),
            (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule),
            (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule),
            (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule),
            (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule),
            (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule),
            (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule),
            (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule),
            (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule),
            (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticFRule),
            (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticERule),
        ))
    _integrand = integrand.subs(symbol, _symbol)
    for type_, pattern, constraint, rule in _special_function_patterns:
        # Class check first, so match() only runs on plausible patterns.
        if isinstance(_integrand, type_):
            match = _integrand.match(pattern)
            if match:
                # Matched values in the fixed order a, b, c, d, e, skipping
                # wilds absent from this pattern. These line up with the
                # rule's fields after integrand and variable.
                wild_vals = tuple(match.get(w) for w in _wilds
                                  if match.get(w) is not None)
                if constraint is None or constraint(*wild_vals):
                    return rule(integrand, symbol, *wild_vals)
  837. def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule:
  838. if degenerate_step is None:
  839. return generic_step
  840. if isinstance(generic_step, PiecewiseRule):
  841. subfunctions = [(substep, (cond & generic_cond).simplify())
  842. for substep, cond in generic_step.subfunctions]
  843. else:
  844. subfunctions = [(generic_step, generic_cond)]
  845. if isinstance(degenerate_step, PiecewiseRule):
  846. subfunctions += degenerate_step.subfunctions
  847. else:
  848. subfunctions.append((degenerate_step, S.true))
  849. return PiecewiseRule(generic_step.integrand, generic_step.variable, subfunctions)
def nested_pow_rule(integral: IntegralInfo):
    """Integrate nested powers of a single linear base.

    Accepts integrands composed only of factors free of ``x``, products, and
    powers with ``x``-free exponents of one linear expression ``a + b*x``,
    e.g. ``(c*(a+b*x)**d)**e``. The linear base is normalised to
    ``x + a/b`` and the accumulated exponent is passed to NestedPowRule.

    When ``b`` is not known to be nonzero, the result is combined (via
    ``_add_degenerate_step``) with a constant-integrand fallback. Returns
    None if the integrand does not have this shape.
    """
    # nested (c*(a+b*x)**d)**e
    integrand, x = integral

    a_ = Wild('a', exclude=[x])
    b_ = Wild('b', exclude=[x, 0])
    pattern = a_+b_*x
    # Set to Ne(b, 0) by _get_base_exp when a linear base is matched.
    generic_cond = S.true

    class NoMatch(Exception):
        # Raised internally when the integrand does not have the expected shape.
        pass

    def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]:
        # Return (base, exponent) with expr equal to base**exponent up to a
        # factor free of x; (1, 0) for x-free expressions. Raise NoMatch otherwise.
        if not expr.has_free(x):
            return S.One, S.Zero
        if expr.is_Mul:
            _, terms = expr.as_coeff_mul()
            if not terms:
                return S.One, S.Zero
            results = [_get_base_exp(term) for term in terms]
            bases = {b for b, _ in results}
            bases.discard(S.One)
            # All x-dependent factors must share one base; exponents add up.
            if len(bases) == 1:
                return bases.pop(), Add(*(e for _, e in results))
            raise NoMatch
        if expr.is_Pow:
            b, e = expr.base, expr.exp  # type: ignore
            if e.has_free(x):
                raise NoMatch
            base_, sub_exp = _get_base_exp(b)
            return base_, sub_exp * e
        match = expr.match(pattern)
        if match:
            a, b = match[a_], match[b_]
            # a + b*x == b*(x + a/b); the factor b is absorbed as a constant.
            base_ = x + a/b
            nonlocal generic_cond
            generic_cond = Ne(b, 0)
            return base_, S.One
        raise NoMatch

    try:
        base, exp_ = _get_base_exp(integrand)
    except NoMatch:
        return
    if generic_cond is S.true:
        degenerate_step = None
    else:
        # equivalent with subs(b, 0) but no need to find b
        degenerate_step = ConstantRule(integrand.subs(x, 0), x)
    generic_step = NestedPowRule(integrand, x, base, exp_)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def inverse_trig_rule(integral: IntegralInfo, degenerate=True):
    """
    Integrate ``(a + b*x + c*x**2)**exp`` for ``exp == -1/2`` or ``exp == 1/2``.

    For ``exp == -1/2`` the base is rewritten as ``k + c*(x - h)**2``.
    Where the real signs of ``k`` and ``c`` allow it, an ArcsinRule or
    ArcsinhRule is built (under a Piecewise if the signs are not decided).
    Otherwise a ReciprocalSqrtQuadraticRule is used. For ``exp == 1/2`` a
    SqrtQuadraticRule is returned.

    Returns None if the base does not match the quadratic pattern or if the
    exponent is neither value.

    Set degenerate=False on recursive call where coefficient of quadratic term
    is assumed non-zero.
    """
    integrand, symbol = integral
    base, exp = integrand.as_base_exp()
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol, 0])
    match = base.match(a + b*symbol + c*symbol**2)

    if not match:
        return

    def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule:
        # Build RuleClass applied to 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2).
        # Wrap it in a constant factor, a u-substitution, and a
        # completing-the-square step as needed.
        u_var = Dummy("u")
        rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2)  # a>0, c>0
        quadratic_base = sqrt(c/a)*(symbol-h)
        constant = 1/sqrt(c)
        u_func = None
        if quadratic_base is not symbol:
            u_func = quadratic_base
            quadratic_base = u_var
        standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2)
        substep = RuleClass(standard_form, quadratic_base)
        if constant != 1:
            substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep)
        if u_func is not None:
            substep = URule(rewritten, symbol, u_var, u_func, substep)
        if h != 0:
            substep = CompleteSquareRule(integrand, symbol, rewritten, substep)
        return substep

    a, b, c = [match.get(i, S.Zero) for i in (a, b, c)]
    generic_cond = Ne(c, 0)
    # Fallback used when c may be zero: the base degenerates to a + b*x.
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = ConstantRule(a ** exp, symbol)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol))

    if simplify(2*exp + 1) == 0:
        h, k = -b/(2*c), a - b**2/(4*c)  # rewrite base to k + c*(symbol-h)**2
        non_square_cond = Ne(k, 0)
        square_step = None
        # k == 0: the base is a perfect square c*(symbol-h)**2.
        if non_square_cond is not S.true:
            square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne)
        if non_square_cond is S.false:
            return square_step
        generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c)
        step = _add_degenerate_step(non_square_cond, generic_step, square_step)
        if k.is_real and c.is_real:
            # list of (rule, condition); each args tuple below is
            # (RuleClass, a, sign_a, c, sign_c, h) for make_inverse_trig
            rules = []
            for args, cond in (  # don't apply ArccoshRule to x**2-1
                ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)),  # 1-x**2
                ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)),  # 1+x**2
            ):
                if cond is S.true:
                    return make_inverse_trig(*args)
                if cond is not S.false:
                    rules.append((make_inverse_trig(*args), cond))
            # NOTE(review): `step` is reassigned in both branches below, so the
            # square_step fallback combined above is not part of the result
            # when k may be zero. Confirm this is intended.
            if rules:
                if not k.is_positive:  # conditions are not thorough, need fall back rule
                    rules.append((generic_step, S.true))
                step = PiecewiseRule(integrand, symbol, rules)
            else:
                step = generic_step
        return _add_degenerate_step(generic_cond, step, degenerate_step)
    if exp == S.Half:
        step = SqrtQuadraticRule(integrand, symbol, a, b, c)
        return _add_degenerate_step(generic_cond, step, degenerate_step)
  967. def add_rule(integral):
  968. integrand, symbol = integral
  969. results = [integral_steps(g, symbol)
  970. for g in integrand.as_ordered_terms()]
  971. return None if None in results else AddRule(integrand, symbol, results)
  972. def mul_rule(integral: IntegralInfo):
  973. integrand, symbol = integral
  974. # Constant times function case
  975. coeff, f = integrand.as_independent(symbol)
  976. if coeff != 1:
  977. next_step = integral_steps(f, symbol)
  978. if next_step is not None:
  979. return ConstantTimesRule(integrand, symbol, coeff, f, next_step)
def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None:
    """Choose ``u`` and ``dv`` for a single integration-by-parts step.

    Candidates for ``u`` are tried in LIATE order (log, inverse trig,
    algebraic, trigonometric, exponential).

    Returns ``(u, dv, v, du, v_step)``, where ``v_step`` is the rule
    integrating ``dv`` and ``v`` is its evaluated result. Returns None when
    no acceptable split is found.
    """
    # LIATE rule:
    # log, inverse trig, algebraic, trigonometric, exponential
    def pull_out_algebraic(integrand):
        # Take u as the product of the algebraic factors of a Mul; returns
        # None (implicitly) if there are none.
        integrand = integrand.cancel().together()
        # iterating over Piecewise args would not work here
        algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul
            else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)])
        if algebraic:
            u = Mul(*algebraic)
            dv = (integrand / u).cancel()
            return u, dv

    def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]:
        # Build a splitter taking u as the product of the integrand's args
        # that are instances of one of ``functions``.
        def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None:
            if any(integrand.has(f) for f in functions):
                args = [arg for arg in integrand.args
                        if any(isinstance(arg, cls) for cls in functions)]
                if args:
                    u = Mul(*args)
                    dv = integrand / u
                    return u, dv
            return None

        return pull_out_u_rl

    liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions),
                   pull_out_algebraic, pull_out_u(sin, cos),
                   pull_out_u(exp)]

    dummy = Dummy("temporary")
    # we can integrate log(x) and atan(x) by setting dv = 1
    if isinstance(integrand, (log, *inverse_trig_functions)):
        integrand = dummy * integrand

    for index, rule in enumerate(liate_rules):
        result = rule(integrand)

        if result:
            u, dv = result

            # Don't pick u to be a constant if possible
            # NOTE(review): this returns None rather than trying the next
            # LIATE rule; confirm that `continue` was not intended.
            if symbol not in u.free_symbols and not u.has(dummy):
                return None

            u = u.subs(dummy, 1)
            dv = dv.subs(dummy, 1)

            # Don't pick a non-polynomial algebraic to be differentiated
            if rule == pull_out_algebraic and not u.is_polynomial(symbol):
                return None
            # Don't trade one logarithm for another
            if isinstance(u, log):
                rec_dv = 1/dv
                if (rec_dv.is_polynomial(symbol) and
                    degree(rec_dv, symbol) == 1):
                        return None

            # Can integrate a polynomial times OrthogonalPolynomial
            if rule == pull_out_algebraic:
                if dv.is_Derivative or dv.has(TrigonometricFunction) or \
                        isinstance(dv, OrthogonalPolynomial):
                    v_step = integral_steps(dv, symbol)
                    if v_step.contains_dont_know():
                        return None
                    else:
                        du = u.diff(symbol)
                        v = v_step.eval()
                        return u, dv, v, du, v_step

            # make sure dv is amenable to integration
            accept = False
            if index < 2:  # log and inverse trig are usually worth trying
                accept = True
            elif (rule == pull_out_algebraic and dv.args and
                all(isinstance(a, (sin, cos, exp))
                for a in dv.args)):
                    accept = True
            else:
                # Accept if a later LIATE rule would pick exactly dv as its u.
                for lrule in liate_rules[index + 1:]:
                    r = lrule(integrand)
                    if r and r[0].subs(dummy, 1).equals(dv):
                        accept = True
                        break

            if accept:
                du = u.diff(symbol)
                v_step = integral_steps(simplify(dv), symbol)
                if not v_step.contains_dont_know():
                    v = v_step.eval()
                    return u, dv, v, du, v_step
    return None
def parts_rule(integral):
    """Integrate by parts, repeating the step and detecting cyclic integrals.

    A numeric coefficient is split off first and re-applied at the end with
    ConstantTimesRule. One ``_parts_rule`` step is taken, then up to four
    more are attempted on the successive remainders ``v*du``.

    If a remainder is a constant multiple of the original integrand (with a
    coefficient other than 1), a CyclicPartsRule is returned. Otherwise the
    collected steps are chained as nested PartsRules, and the last remainder
    is handed to ``integral_steps``.

    Returns None in three cases:

    * no parts step applies;
    * ``v`` is an unevaluated Integral;
    * a sin/cos/exp/sinh/cosh ``u`` has already been used more than twice,
      as counted in ``_parts_u_cache``.
    """
    integrand, symbol = integral
    constant, integrand = integrand.as_coeff_Mul()

    result = _parts_rule(integrand, symbol)

    steps = []
    if result:
        u, dv, v, du, v_step = result
        debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step))
        steps.append(result)

        if isinstance(v, Integral):
            return

        # Set a limit on the number of times u can be used
        if isinstance(u, (sin, cos, exp, sinh, cosh)):
            cachekey = u.xreplace({symbol: _cache_dummy})
            if _parts_u_cache[cachekey] > 2:
                return
            _parts_u_cache[cachekey] += 1

        # Try cyclic integration by parts a few times
        for _ in range(4):
            debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand))
            # Ratio of the current remainder v*du to the original integrand.
            coefficient = ((v * du) / integrand).cancel()
            if coefficient == 1:
                break
            if symbol not in coefficient.free_symbols:
                # The integral reappeared times a constant: solve for it.
                rule = CyclicPartsRule(integrand, symbol,
                    [PartsRule(None, None, u, dv, v_step, None)
                     for (u, dv, v, du, v_step) in steps],
                    (-1) ** len(steps) * coefficient)
                if (constant != 1) and rule:
                    rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
                return rule

            # _parts_rule is sensitive to constants, factor it out
            next_constant, next_integrand = (v * du).as_coeff_Mul()
            result = _parts_rule(next_integrand, symbol)

            if result:
                u, dv, v, du, v_step = result
                u *= next_constant
                du *= next_constant
                steps.append((u, dv, v, du, v_step))
            else:
                break

    def make_second_step(steps, integrand):
        # Chain the remaining parts steps; the final remainder is integrated directly.
        if steps:
            u, dv, v, du, v_step = steps[0]
            return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
        return integral_steps(integrand, symbol)

    if steps:
        u, dv, v, du, v_step = steps[0]
        rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
        if (constant != 1) and rule:
            rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
        return rule
  1112. def trig_rule(integral):
  1113. integrand, symbol = integral
  1114. if integrand == sin(symbol):
  1115. return SinRule(integrand, symbol)
  1116. if integrand == cos(symbol):
  1117. return CosRule(integrand, symbol)
  1118. if integrand == sec(symbol)**2:
  1119. return Sec2Rule(integrand, symbol)
  1120. if integrand == csc(symbol)**2:
  1121. return Csc2Rule(integrand, symbol)
  1122. if isinstance(integrand, tan):
  1123. rewritten = sin(*integrand.args) / cos(*integrand.args)
  1124. elif isinstance(integrand, cot):
  1125. rewritten = cos(*integrand.args) / sin(*integrand.args)
  1126. elif isinstance(integrand, sec):
  1127. arg = integrand.args[0]
  1128. rewritten = ((sec(arg)**2 + tan(arg) * sec(arg)) /
  1129. (sec(arg) + tan(arg)))
  1130. elif isinstance(integrand, csc):
  1131. arg = integrand.args[0]
  1132. rewritten = ((csc(arg)**2 + cot(arg) * csc(arg)) /
  1133. (csc(arg) + cot(arg)))
  1134. else:
  1135. return
  1136. return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
  1137. def trig_product_rule(integral: IntegralInfo):
  1138. integrand, symbol = integral
  1139. if integrand == sec(symbol) * tan(symbol):
  1140. return SecTanRule(integrand, symbol)
  1141. if integrand == csc(symbol) * cot(symbol):
  1142. return CscCotRule(integrand, symbol)
  1143. def quadratic_denom_rule(integral):
  1144. integrand, symbol = integral
  1145. a = Wild('a', exclude=[symbol])
  1146. b = Wild('b', exclude=[symbol])
  1147. c = Wild('c', exclude=[symbol])
  1148. match = integrand.match(a / (b * symbol ** 2 + c))
  1149. if match:
  1150. a, b, c = match[a], match[b], match[c]
  1151. general_rule = ArctanRule(integrand, symbol, a, b, c)
  1152. if b.is_extended_real and c.is_extended_real:
  1153. positive_cond = c/b > 0
  1154. if positive_cond is S.true:
  1155. return general_rule
  1156. coeff = a/(2*sqrt(-c)*sqrt(b))
  1157. constant = sqrt(-c/b)
  1158. r1 = 1/(symbol-constant)
  1159. r2 = 1/(symbol+constant)
  1160. log_steps = [ReciprocalRule(r1, symbol, symbol-constant),
  1161. ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))]
  1162. rewritten = sub = r1 - r2
  1163. negative_step = AddRule(sub, symbol, log_steps)
  1164. if coeff != 1:
  1165. rewritten = Mul(coeff, sub, evaluate=False)
  1166. negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step)
  1167. negative_step = RewriteRule(integrand, symbol, rewritten, negative_step)
  1168. if positive_cond is S.false:
  1169. return negative_step
  1170. return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)])
  1171. return general_rule
  1172. d = Wild('d', exclude=[symbol])
  1173. match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d))
  1174. if match2:
  1175. b, c = match2[b], match2[c]
  1176. if b.is_zero:
  1177. return
  1178. u = Dummy('u')
  1179. u_func = symbol + c/(2*b)
  1180. integrand2 = integrand.subs(symbol, u - c / (2*b))
  1181. next_step = integral_steps(integrand2, u)
  1182. if next_step:
  1183. return URule(integrand2, symbol, u, u_func, next_step)
  1184. else:
  1185. return
  1186. e = Wild('e', exclude=[symbol])
  1187. match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e))
  1188. if match3:
  1189. a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e]
  1190. if c.is_zero:
  1191. return
  1192. denominator = c * symbol**2 + d * symbol + e
  1193. const = a/(2*c)
  1194. numer1 = (2*c*symbol+d)
  1195. numer2 = - const*d + b
  1196. u = Dummy('u')
  1197. step1 = URule(integrand, symbol,
  1198. u, denominator, integral_steps(u**(-1), u))
  1199. if const != 1:
  1200. step1 = ConstantTimesRule(const*numer1/denominator, symbol,
  1201. const, numer1/denominator, step1)
  1202. if numer2.is_zero:
  1203. return step1
  1204. step2 = integral_steps(numer2/denominator, symbol)
  1205. substeps = AddRule(integrand, symbol, [step1, step2])
  1206. rewriten = const*numer1/denominator+numer2/denominator
  1207. return RewriteRule(integrand, symbol, rewriten, substeps)
  1208. return
def sqrt_linear_rule(integral: IntegralInfo):
    """
    Substitute common (a+b*x)**(1/n)

    Every non-integer rational power ``(a + b*x)**(p/q)`` of a linear base in
    the integrand must be compatible with one representative
    ``a0 + b0*x``: the bases must be proportional with a nonnegative ratio.
    The substitution ``u = (a0 + b0*x)**(1/q0)``, where ``q0`` is the lcm of
    all denominators ``q``, turns them into integer powers of ``u``.

    Returns a URule, wrapped in a PiecewiseRule with a degenerate branch if
    ``b0`` may be zero. Returns None in any of these cases:

    * no such power exists;
    * a power has an irrational exponent;
    * the bases are incompatible;
    * the substituted integral contains unknown steps.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x, 0])
    # (a0, b0): coefficients of the representative base; b0 == 0 means none yet.
    a0 = b0 = 0
    bases, qs, bs = [], [], []
    for pow_ in integrand.find(Pow):  # collect all (a+b*x)**(p/q)
        base, exp_ = pow_.base, pow_.exp
        if exp_.is_Integer or x not in base.free_symbols:  # skip 1/x and sqrt(2)
            continue
        if not exp_.is_Rational:  # exclude x**pi
            return
        match = base.match(a+b*x)
        if not match:  # skip non-linear
            continue  # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x)
        a1, b1 = match[a], match[b]
        if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative:  # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x)
            return
        if b0 == 0 or (b0/b1 > 1) is S.true:  # choose the latter of sqrt(2*x) and sqrt(x) as representative
            a0, b0 = a1, b1
        bases.append(base)
        bs.append(b1)
        qs.append(exp_.q)
    if b0 == 0:  # no such pattern found
        return
    q0: Integer = lcm_list(qs)
    u_x = (a0 + b0*x)**(1/q0)
    u = Dummy("u")
    # base**(1/q) == (b/b0)**(1/q) * u**(q0/q); afterwards x == (u**q0 - a0)/b0.
    substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q)
                                  for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0)
    # dx == q0*u**(q0-1)/b0 du
    substep = integral_steps(substituted*u**(q0-1)*q0/b0, u)
    if not substep.contains_dont_know():
        step: Rule = URule(integrand, x, u, u_x, substep)
        generic_cond = Ne(b0, 0)
        if generic_cond is not S.true:  # possible degenerate case
            simplified = integrand.subs({b: 0 for b in bs})
            degenerate_step = integral_steps(simplified, x)
            step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)])
        return step
def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True):
    """Integrate ``f(x)*sqrt(a + b*x + c*x**2)**n`` for polynomial ``f`` and odd ``n``.

    For ``n > 0`` the integrand is rewritten as
    ``poly/sqrt(a + b*x + c*x**2)``; for ``n == -1`` it already has that
    form. The numerator is then handled by degree:

    * degree <= 1: split into a multiple of the quadratic's derivative
      (integrated by substitution) plus a constant (integrated by
      ``inverse_trig_rule``);
    * higher degree: SqrtQuadraticDenomRule.

    ``n < -1`` is not handled (returns None). With ``degenerate=True``, a
    fallback for ``c == 0`` is added when ``c`` is not known to be nonzero.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    c = Wild('c', exclude=[x, 0])
    f = Wild('f')
    n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
    match = integrand.match(f*sqrt(a+b*x+c*x**2)**n)
    if not match:
        return
    a, b, c, f, n = match[a], match[b], match[c], match[f], match[n]
    f_poly = f.as_poly(x)
    if f_poly is None:
        return

    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = integral_steps(f*sqrt(a)**n, x)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x))

    def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr):
        # Build the step for numer_poly/sqrt(a+b*x+c*x**2); ``integrand`` is
        # the expression the returned rule is recorded against.
        denom = sqrt(a+b*x+c*x**2)
        deg = numer_poly.degree()
        if deg <= 1:
            # integrand == (d+e*x)/sqrt(a+b*x+c*x**2)
            e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
            # rewrite numerator to A*(2*c*x+b) + B
            A = e/(2*c)
            B = d-A*b
            pre_substitute = (2*c*x+b)/denom
            constant_step: Rule | None = None
            linear_step: Rule | None = None
            if A != 0:
                # (2*c*x+b)/sqrt(quadratic) with u = quadratic becomes u**(-1/2)
                u = Dummy("u")
                pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half)
                linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule)
                if A != 1:
                    linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step)
            if B != 0:
                constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False)
                if B != 1:
                    constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step)  # type: ignore
            if linear_step and constant_step:
                add = Add(A*pre_substitute, B/denom, evaluate=False)
                step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step]))
            else:
                step = linear_step or constant_step
        else:
            coeffs = numer_poly.all_coeffs()
            step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs)
        return step

    if n > 0:  # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s)
        numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2)
        rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2)
        substep = sqrt_quadratic_denom_rule(numer_poly, rewritten)
        generic_step = RewriteRule(integrand, x, rewritten, substep)
    elif n == -1:
        generic_step = sqrt_quadratic_denom_rule(f_poly, integrand)
    else:
        return  # todo: handle n < -1 case
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
  1313. def hyperbolic_rule(integral: tuple[Expr, Symbol]):
  1314. integrand, symbol = integral
  1315. if isinstance(integrand, HyperbolicFunction) and integrand.args[0] == symbol:
  1316. if integrand.func == sinh:
  1317. return SinhRule(integrand, symbol)
  1318. if integrand.func == cosh:
  1319. return CoshRule(integrand, symbol)
  1320. u = Dummy('u')
  1321. if integrand.func == tanh:
  1322. rewritten = sinh(symbol)/cosh(symbol)
  1323. return RewriteRule(integrand, symbol, rewritten,
  1324. URule(rewritten, symbol, u, cosh(symbol), ReciprocalRule(1/u, u, u)))
  1325. if integrand.func == coth:
  1326. rewritten = cosh(symbol)/sinh(symbol)
  1327. return RewriteRule(integrand, symbol, rewritten,
  1328. URule(rewritten, symbol, u, sinh(symbol), ReciprocalRule(1/u, u, u)))
  1329. else:
  1330. rewritten = integrand.rewrite(tanh)
  1331. if integrand.func == sech:
  1332. return RewriteRule(integrand, symbol, rewritten,
  1333. URule(rewritten, symbol, u, tanh(symbol/2),
  1334. ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One)))
  1335. if integrand.func == csch:
  1336. return RewriteRule(integrand, symbol, rewritten,
  1337. URule(rewritten, symbol, u, tanh(symbol/2),
  1338. ReciprocalRule(1/u, u, u)))
  1339. @cacheit
  1340. def make_wilds(symbol):
  1341. a = Wild('a', exclude=[symbol])
  1342. b = Wild('b', exclude=[symbol])
  1343. m = Wild('m', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)])
  1344. n = Wild('n', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)])
  1345. return a, b, m, n
  1346. @cacheit
  1347. def sincos_pattern(symbol):
  1348. a, b, m, n = make_wilds(symbol)
  1349. pattern = sin(a*symbol)**m * cos(b*symbol)**n
  1350. return pattern, a, b, m, n
  1351. @cacheit
  1352. def tansec_pattern(symbol):
  1353. a, b, m, n = make_wilds(symbol)
  1354. pattern = tan(a*symbol)**m * sec(b*symbol)**n
  1355. return pattern, a, b, m, n
  1356. @cacheit
  1357. def cotcsc_pattern(symbol):
  1358. a, b, m, n = make_wilds(symbol)
  1359. pattern = cot(a*symbol)**m * csc(b*symbol)**n
  1360. return pattern, a, b, m, n
  1361. @cacheit
  1362. def heaviside_pattern(symbol):
  1363. m = Wild('m', exclude=[symbol])
  1364. b = Wild('b', exclude=[symbol])
  1365. g = Wild('g')
  1366. pattern = Heaviside(m*symbol + b) * g
  1367. return pattern, m, b, g
  1368. def uncurry(func):
  1369. def uncurry_rl(args):
  1370. return func(*args)
  1371. return uncurry_rl
  1372. def trig_rewriter(rewrite):
  1373. def trig_rewriter_rl(args):
  1374. a, b, m, n, integrand, symbol = args
  1375. rewritten = rewrite(a, b, m, n, integrand, symbol)
  1376. if rewritten != integrand:
  1377. return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
  1378. return trig_rewriter_rl
# Conditions and rewriters for the products sin(a*x)**m * cos(b*x)**n,
# tan(a*x)**m * sec(b*x)**n and cot(a*x)**m * csc(b*x)**n. Each *_condition
# receives the tuple (a, b, m, n, integrand, symbol) through ``uncurry``.
# Each rewriter is built with ``trig_rewriter`` from a lambda returning the
# rewritten integrand.

# sin**m * cos**n with m, n even and nonnegative: half-angle identities
# sin(t)**2 = (1 - cos(2*t))/2 and cos(t)**2 = (1 + cos(2*t))/2.
sincos_botheven_condition = uncurry(
    lambda a, b, m, n, i, s: m.is_even and n.is_even and
    m.is_nonnegative and n.is_nonnegative)

sincos_botheven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) *
                                    (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) ))

# odd m >= 3: keep one sin factor and convert the rest with sin**2 = 1 - cos**2.
sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3)

sincos_sinodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) *
                                    sin(a*symbol) *
                                    cos(b*symbol) ** n))

# odd n >= 3: keep one cos factor and convert the rest with cos**2 = 1 - sin**2.
sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3)

sincos_cosodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) *
                                    cos(b*symbol) *
                                    sin(a*symbol) ** m))

# tan**m * sec**n with even n >= 4: keep sec**2 and convert the remaining
# sec powers with sec**2 = 1 + tan**2.
tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
tansec_seceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) *
                                    sec(b*symbol)**2 *
                                    tan(a*symbol) ** m ))

# odd m: keep one tan factor and convert the rest with tan**2 = sec**2 - 1.
tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
tansec_tanodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                     tan(a*symbol) *
                                     sec(b*symbol) ** n ))

# tan(a*x)**2 on its own: tan**2 = sec**2 - 1.
tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0)
tan_tansquared = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1))

# cot**m * csc**n with even n >= 4: keep csc**2 and convert the remaining
# csc powers with csc**2 = 1 + cot**2.
cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
cotcsc_csceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) *
                                    csc(b*symbol)**2 *
                                    cot(a*symbol) ** m ))

# odd m: keep one cot factor and convert the rest with cot**2 = csc**2 - 1.
cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
cotcsc_cotodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                    cot(a*symbol) *
                                    csc(b*symbol) ** n ))
  1418. def trig_sincos_rule(integral):
  1419. integrand, symbol = integral
  1420. if any(integrand.has(f) for f in (sin, cos)):
  1421. pattern, a, b, m, n = sincos_pattern(symbol)
  1422. match = integrand.match(pattern)
  1423. if not match:
  1424. return
  1425. return multiplexer({
  1426. sincos_botheven_condition: sincos_botheven,
  1427. sincos_sinodd_condition: sincos_sinodd,
  1428. sincos_cosodd_condition: sincos_cosodd
  1429. })(tuple(
  1430. [match.get(i, S.Zero) for i in (a, b, m, n)] +
  1431. [integrand, symbol]))
  1432. def trig_tansec_rule(integral):
  1433. integrand, symbol = integral
  1434. integrand = integrand.subs({
  1435. 1 / cos(symbol): sec(symbol)
  1436. })
  1437. if any(integrand.has(f) for f in (tan, sec)):
  1438. pattern, a, b, m, n = tansec_pattern(symbol)
  1439. match = integrand.match(pattern)
  1440. if not match:
  1441. return
  1442. return multiplexer({
  1443. tansec_tanodd_condition: tansec_tanodd,
  1444. tansec_seceven_condition: tansec_seceven,
  1445. tan_tansquared_condition: tan_tansquared
  1446. })(tuple(
  1447. [match.get(i, S.Zero) for i in (a, b, m, n)] +
  1448. [integrand, symbol]))
  1449. def trig_cotcsc_rule(integral):
  1450. integrand, symbol = integral
  1451. integrand = integrand.subs({
  1452. 1 / sin(symbol): csc(symbol),
  1453. 1 / tan(symbol): cot(symbol),
  1454. cos(symbol) / tan(symbol): cot(symbol)
  1455. })
  1456. if any(integrand.has(f) for f in (cot, csc)):
  1457. pattern, a, b, m, n = cotcsc_pattern(symbol)
  1458. match = integrand.match(pattern)
  1459. if not match:
  1460. return
  1461. return multiplexer({
  1462. cotcsc_cotodd_condition: cotcsc_cotodd,
  1463. cotcsc_csceven_condition: cotcsc_csceven
  1464. })(tuple(
  1465. [match.get(i, S.Zero) for i in (a, b, m, n)] +
  1466. [integrand, symbol]))
  1467. def trig_sindouble_rule(integral):
  1468. integrand, symbol = integral
  1469. a = Wild('a', exclude=[sin(2*symbol)])
  1470. match = integrand.match(sin(2*symbol)*a)
  1471. if match:
  1472. sin_double = 2*sin(symbol)*cos(symbol)/sin(2*symbol)
  1473. return integral_steps(integrand * sin_double, symbol)
  1474. def trig_powers_products_rule(integral):
  1475. return do_one(null_safe(trig_sincos_rule),
  1476. null_safe(trig_tansec_rule),
  1477. null_safe(trig_cotcsc_rule),
  1478. null_safe(trig_sindouble_rule))(integral)
  1479. def trig_substitution_rule(integral):
  1480. integrand, symbol = integral
  1481. A = Wild('a', exclude=[0, symbol])
  1482. B = Wild('b', exclude=[0, symbol])
  1483. theta = Dummy("theta")
  1484. target_pattern = A + B*symbol**2
  1485. matches = integrand.find(target_pattern)
  1486. for expr in matches:
  1487. match = expr.match(target_pattern)
  1488. a = match.get(A, S.Zero)
  1489. b = match.get(B, S.Zero)
  1490. a_positive = ((a.is_number and a > 0) or a.is_positive)
  1491. b_positive = ((b.is_number and b > 0) or b.is_positive)
  1492. a_negative = ((a.is_number and a < 0) or a.is_negative)
  1493. b_negative = ((b.is_number and b < 0) or b.is_negative)
  1494. x_func = None
  1495. if a_positive and b_positive:
  1496. # a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2
  1497. x_func = (sqrt(a)/sqrt(b)) * tan(theta)
  1498. # Do not restrict the domain: tan(theta) takes on any real
  1499. # value on the interval -pi/2 < theta < pi/2 so x takes on
  1500. # any value
  1501. restriction = True
  1502. elif a_positive and b_negative:
  1503. # a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2
  1504. constant = sqrt(a)/sqrt(-b)
  1505. x_func = constant * sin(theta)
  1506. restriction = And(symbol > -constant, symbol < constant)
  1507. elif a_negative and b_positive:
  1508. # b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi
  1509. constant = sqrt(-a)/sqrt(b)
  1510. x_func = constant * sec(theta)
  1511. restriction = And(symbol > -constant, symbol < constant)
  1512. if x_func:
  1513. # Manually simplify sqrt(trig(theta)**2) to trig(theta)
  1514. # Valid due to assumed domain restriction
  1515. substitutions = {}
  1516. for f in [sin, cos, tan,
  1517. sec, csc, cot]:
  1518. substitutions[sqrt(f(theta)**2)] = f(theta)
  1519. substitutions[sqrt(f(theta)**(-2))] = 1/f(theta)
  1520. replaced = integrand.subs(symbol, x_func).trigsimp()
  1521. replaced = manual_subs(replaced, substitutions)
  1522. if not replaced.has(symbol):
  1523. replaced *= manual_diff(x_func, theta)
  1524. replaced = replaced.trigsimp()
  1525. secants = replaced.find(1/cos(theta))
  1526. if secants:
  1527. replaced = replaced.xreplace({
  1528. 1/cos(theta): sec(theta)
  1529. })
  1530. substep = integral_steps(replaced, theta)
  1531. if not substep.contains_dont_know():
  1532. return TrigSubstitutionRule(integrand, symbol,
  1533. theta, x_func, replaced, substep, restriction)
  1534. def heaviside_rule(integral):
  1535. integrand, symbol = integral
  1536. pattern, m, b, g = heaviside_pattern(symbol)
  1537. match = integrand.match(pattern)
  1538. if match and 0 != match[g]:
  1539. # f = Heaviside(m*x + b)*g
  1540. substep = integral_steps(match[g], symbol)
  1541. m, b = match[m], match[b]
  1542. return HeavisideRule(integrand, symbol, m*symbol + b, -b/m, substep)
  1543. def dirac_delta_rule(integral: IntegralInfo):
  1544. integrand, x = integral
  1545. if len(integrand.args) == 1:
  1546. n = S.Zero
  1547. else:
  1548. n = integrand.args[1]
  1549. if not n.is_Integer or n < 0:
  1550. return
  1551. a, b = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0])
  1552. match = integrand.args[0].match(a+b*x)
  1553. if not match:
  1554. return
  1555. a, b = match[a], match[b]
  1556. generic_cond = Ne(b, 0)
  1557. if generic_cond is S.true:
  1558. degenerate_step = None
  1559. else:
  1560. degenerate_step = ConstantRule(DiracDelta(a, n), x)
  1561. generic_step = DiracDeltaRule(integrand, x, n, a, b)
  1562. return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
  1563. def substitution_rule(integral):
  1564. integrand, symbol = integral
  1565. u_var = Dummy("u")
  1566. substitutions = find_substitutions(integrand, symbol, u_var)
  1567. count = 0
  1568. if substitutions:
  1569. debug("List of Substitution Rules")
  1570. ways = []
  1571. for u_func, c, substituted in substitutions:
  1572. subrule = integral_steps(substituted, u_var)
  1573. count = count + 1
  1574. debug("Rule {}: {}".format(count, subrule))
  1575. if subrule.contains_dont_know():
  1576. continue
  1577. if simplify(c - 1) != 0:
  1578. _, denom = c.as_numer_denom()
  1579. if subrule:
  1580. subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule)
  1581. if denom.free_symbols:
  1582. piecewise = []
  1583. could_be_zero = []
  1584. if isinstance(denom, Mul):
  1585. could_be_zero = denom.args
  1586. else:
  1587. could_be_zero.append(denom)
  1588. for expr in could_be_zero:
  1589. if not fuzzy_not(expr.is_zero):
  1590. substep = integral_steps(manual_subs(integrand, expr, 0), symbol)
  1591. if substep:
  1592. piecewise.append((
  1593. substep,
  1594. Eq(expr, 0)
  1595. ))
  1596. piecewise.append((subrule, True))
  1597. subrule = PiecewiseRule(substituted, symbol, piecewise)
  1598. ways.append(URule(integrand, symbol, u_var, u_func, subrule))
  1599. if len(ways) > 1:
  1600. return AlternativeRule(integrand, symbol, ways)
  1601. elif ways:
  1602. return ways[0]
  1603. partial_fractions_rule = rewriter(
  1604. lambda integrand, symbol: integrand.is_rational_function(),
  1605. lambda integrand, symbol: integrand.apart(symbol))
  1606. cancel_rule = rewriter(
  1607. # lambda integrand, symbol: integrand.is_algebraic_expr(),
  1608. # lambda integrand, symbol: isinstance(integrand, Mul),
  1609. lambda integrand, symbol: True,
  1610. lambda integrand, symbol: integrand.cancel())
  1611. distribute_expand_rule = rewriter(
  1612. lambda integrand, symbol: (
  1613. all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)
  1614. or isinstance(integrand, Pow)
  1615. or isinstance(integrand, Mul)),
  1616. lambda integrand, symbol: integrand.expand())
  1617. trig_expand_rule = rewriter(
  1618. # If there are trig functions with different arguments, expand them
  1619. lambda integrand, symbol: (
  1620. len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1),
  1621. lambda integrand, symbol: integrand.expand(trig=True))
  1622. def derivative_rule(integral):
  1623. integrand = integral[0]
  1624. diff_variables = integrand.variables
  1625. undifferentiated_function = integrand.expr
  1626. integrand_variables = undifferentiated_function.free_symbols
  1627. if integral.symbol in integrand_variables:
  1628. if integral.symbol in diff_variables:
  1629. return DerivativeRule(*integral)
  1630. else:
  1631. return DontKnowRule(integrand, integral.symbol)
  1632. else:
  1633. return ConstantRule(*integral)
  1634. def rewrites_rule(integral):
  1635. integrand, symbol = integral
  1636. if integrand.match(1/cos(symbol)):
  1637. rewritten = integrand.subs(1/cos(symbol), sec(symbol))
  1638. return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
  1639. def fallback_rule(integral):
  1640. return DontKnowRule(*integral)
  1641. # Cache is used to break cyclic integrals.
  1642. # Need to use the same dummy variable in cached expressions for them to match.
  1643. # Also record "u" of integration by parts, to avoid infinite repetition.
  1644. _integral_cache: dict[Expr, Expr | None] = {}
  1645. _parts_u_cache: dict[Expr, int] = defaultdict(int)
  1646. _cache_dummy = Dummy("z")
  1647. def integral_steps(integrand, symbol, **options):
  1648. """Returns the steps needed to compute an integral.
  1649. Explanation
  1650. ===========
  1651. This function attempts to mirror what a student would do by hand as
  1652. closely as possible.
  1653. SymPy Gamma uses this to provide a step-by-step explanation of an
  1654. integral. The code it uses to format the results of this function can be
  1655. found at
  1656. https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py.
  1657. Examples
  1658. ========
  1659. >>> from sympy import exp, sin
  1660. >>> from sympy.integrals.manualintegrate import integral_steps
  1661. >>> from sympy.abc import x
  1662. >>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \
  1663. # doctest: +NORMALIZE_WHITESPACE
  1664. URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x),
  1665. substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1))
  1666. >>> print(repr(integral_steps(sin(x), x))) \
  1667. # doctest: +NORMALIZE_WHITESPACE
  1668. SinRule(integrand=sin(x), variable=x)
  1669. >>> print(repr(integral_steps((x**2 + 3)**2, x))) \
  1670. # doctest: +NORMALIZE_WHITESPACE
  1671. RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9,
  1672. substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x,
  1673. substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4),
  1674. ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2,
  1675. substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)),
  1676. ConstantRule(integrand=9, variable=x)]))
  1677. Returns
  1678. =======
  1679. rule : Rule
  1680. The first step; most rules have substeps that must also be
  1681. considered. These substeps can be evaluated using ``manualintegrate``
  1682. to obtain a result.
  1683. """
  1684. cachekey = integrand.xreplace({symbol: _cache_dummy})
  1685. if cachekey in _integral_cache:
  1686. if _integral_cache[cachekey] is None:
  1687. # Stop this attempt, because it leads around in a loop
  1688. return DontKnowRule(integrand, symbol)
  1689. else:
  1690. # TODO: This is for future development, as currently
  1691. # _integral_cache gets no values other than None
  1692. return (_integral_cache[cachekey].xreplace(_cache_dummy, symbol),
  1693. symbol)
  1694. else:
  1695. _integral_cache[cachekey] = None
  1696. integral = IntegralInfo(integrand, symbol)
  1697. def key(integral):
  1698. integrand = integral.integrand
  1699. if symbol not in integrand.free_symbols:
  1700. return Number
  1701. for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial):
  1702. if isinstance(integrand, cls):
  1703. return cls
  1704. return type(integrand)
  1705. def integral_is_subclass(*klasses):
  1706. def _integral_is_subclass(integral):
  1707. k = key(integral)
  1708. return k and issubclass(k, klasses)
  1709. return _integral_is_subclass
  1710. result = do_one(
  1711. null_safe(special_function_rule),
  1712. null_safe(switch(key, {
  1713. Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule),
  1714. null_safe(sqrt_linear_rule),
  1715. null_safe(quadratic_denom_rule)),
  1716. Symbol: power_rule,
  1717. exp: exp_rule,
  1718. Add: add_rule,
  1719. Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
  1720. null_safe(heaviside_rule), null_safe(quadratic_denom_rule),
  1721. null_safe(sqrt_linear_rule),
  1722. null_safe(sqrt_quadratic_rule)),
  1723. Derivative: derivative_rule,
  1724. TrigonometricFunction: trig_rule,
  1725. Heaviside: heaviside_rule,
  1726. DiracDelta: dirac_delta_rule,
  1727. OrthogonalPolynomial: orthogonal_poly_rule,
  1728. Number: constant_rule
  1729. })),
  1730. do_one(
  1731. null_safe(trig_rule),
  1732. null_safe(hyperbolic_rule),
  1733. null_safe(alternatives(
  1734. rewrites_rule,
  1735. substitution_rule,
  1736. condition(
  1737. integral_is_subclass(Mul, Pow),
  1738. partial_fractions_rule),
  1739. condition(
  1740. integral_is_subclass(Mul, Pow),
  1741. cancel_rule),
  1742. condition(
  1743. integral_is_subclass(Mul, log,
  1744. *inverse_trig_functions),
  1745. parts_rule),
  1746. condition(
  1747. integral_is_subclass(Mul, Pow),
  1748. distribute_expand_rule),
  1749. trig_powers_products_rule,
  1750. trig_expand_rule
  1751. )),
  1752. null_safe(condition(integral_is_subclass(Mul, Pow), nested_pow_rule)),
  1753. null_safe(trig_substitution_rule)
  1754. ),
  1755. fallback_rule)(integral)
  1756. del _integral_cache[cachekey]
  1757. return result
  1758. def manualintegrate(f, var):
  1759. """manualintegrate(f, var)
  1760. Explanation
  1761. ===========
  1762. Compute indefinite integral of a single variable using an algorithm that
  1763. resembles what a student would do by hand.
  1764. Unlike :func:`~.integrate`, var can only be a single symbol.
  1765. Examples
  1766. ========
  1767. >>> from sympy import sin, cos, tan, exp, log, integrate
  1768. >>> from sympy.integrals.manualintegrate import manualintegrate
  1769. >>> from sympy.abc import x
  1770. >>> manualintegrate(1 / x, x)
  1771. log(x)
  1772. >>> integrate(1/x)
  1773. log(x)
  1774. >>> manualintegrate(log(x), x)
  1775. x*log(x) - x
  1776. >>> integrate(log(x))
  1777. x*log(x) - x
  1778. >>> manualintegrate(exp(x) / (1 + exp(2 * x)), x)
  1779. atan(exp(x))
  1780. >>> integrate(exp(x) / (1 + exp(2 * x)))
  1781. RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x))))
  1782. >>> manualintegrate(cos(x)**4 * sin(x), x)
  1783. -cos(x)**5/5
  1784. >>> integrate(cos(x)**4 * sin(x), x)
  1785. -cos(x)**5/5
  1786. >>> manualintegrate(cos(x)**4 * sin(x)**3, x)
  1787. cos(x)**7/7 - cos(x)**5/5
  1788. >>> integrate(cos(x)**4 * sin(x)**3, x)
  1789. cos(x)**7/7 - cos(x)**5/5
  1790. >>> manualintegrate(tan(x), x)
  1791. -log(cos(x))
  1792. >>> integrate(tan(x), x)
  1793. -log(cos(x))
  1794. See Also
  1795. ========
  1796. sympy.integrals.integrals.integrate
  1797. sympy.integrals.integrals.Integral.doit
  1798. sympy.integrals.integrals.Integral
  1799. """
  1800. result = integral_steps(f, var).eval()
  1801. # Clear the cache of u-parts
  1802. _parts_u_cache.clear()
  1803. # If we got Piecewise with two parts, put generic first
  1804. if isinstance(result, Piecewise) and len(result.args) == 2:
  1805. cond = result.args[0][1]
  1806. if isinstance(cond, Eq) and result.args[1][1] == True:
  1807. result = result.func(
  1808. (result.args[1][0], Ne(*cond.args)),
  1809. (result.args[0][0], True))
  1810. return result