12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171 |
- """Integration method that emulates by-hand techniques.
- This module also provides functionality to get the steps used to evaluate a
- particular integral, in the ``integral_steps`` function. This will return
- nested ``Rule`` s representing the integration rules used.
- Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g.
- ``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule``
- for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the
- integration result.
- The ``manualintegrate`` function computes the integral by calling ``eval``
- on the rule returned by ``integral_steps``.
- The integrator can be extended with new heuristics and evaluation
- techniques. To do so, extend the ``Rule`` class, implement ``eval`` method,
- then write a function that accepts an ``IntegralInfo`` object and returns
- either a ``Rule`` instance or ``None``. If the new technique requires a new
- match, add the key and call to the antiderivative function to integral_steps.
- To enable simple substitutions, add the match to find_substitutions.
- """
- from __future__ import annotations
- from typing import NamedTuple, Type, Callable, Sequence
- from abc import ABC, abstractmethod
- from dataclasses import dataclass
- from collections import defaultdict
- from collections.abc import Mapping
- from sympy.core.add import Add
- from sympy.core.cache import cacheit
- from sympy.core.containers import Dict
- from sympy.core.expr import Expr
- from sympy.core.function import Derivative
- from sympy.core.logic import fuzzy_not
- from sympy.core.mul import Mul
- from sympy.core.numbers import Integer, Number, E
- from sympy.core.power import Pow
- from sympy.core.relational import Eq, Ne, Boolean
- from sympy.core.singleton import S
- from sympy.core.symbol import Dummy, Symbol, Wild
- from sympy.functions.elementary.complexes import Abs
- from sympy.functions.elementary.exponential import exp, log
- from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch,
- cosh, coth, sech, sinh, tanh, asinh)
- from sympy.functions.elementary.miscellaneous import sqrt
- from sympy.functions.elementary.piecewise import Piecewise
- from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
- cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec)
- from sympy.functions.special.delta_functions import Heaviside, DiracDelta
- from sympy.functions.special.error_functions import (erf, erfi, fresnelc,
- fresnels, Ci, Chi, Si, Shi, Ei, li)
- from sympy.functions.special.gamma_functions import uppergamma
- from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f
- from sympy.functions.special.polynomials import (chebyshevt, chebyshevu,
- legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi,
- OrthogonalPolynomial)
- from sympy.functions.special.zeta_functions import polylog
- from .integrals import Integral
- from sympy.logic.boolalg import And
- from sympy.ntheory.factor_ import primefactors
- from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly
- from sympy.simplify.radsimp import fraction
- from sympy.simplify.simplify import simplify
- from sympy.solvers.solvers import solve
- from sympy.strategies.core import switch, do_one, null_safe, condition
- from sympy.utilities.iterables import iterable
- from sympy.utilities.misc import debug
@dataclass
class Rule(ABC):
    """Base class for one integration step.

    Dataclass fields, in constructor order:

    integrand -- the expression this step integrates
    variable -- the integration variable

    Concrete rules implement ``eval`` (build the antiderivative) and
    ``contains_dont_know`` (report whether this step or any nested step
    gave up on the integral).
    """

    integrand: Expr
    variable: Symbol

    @abstractmethod
    def eval(self) -> Expr:
        """Return the antiderivative described by this step."""

    @abstractmethod
    def contains_dont_know(self) -> bool:
        """Return True if this step or any nested step could not integrate."""
@dataclass
class AtomicRule(Rule, ABC):
    """A simple rule that does not depend on other rules.

    Because an atomic rule has no substeps, it can never contain a
    ``DontKnowRule``.
    """

    def contains_dont_know(self) -> bool:
        # Nothing nested, so nothing that could have failed.
        has_unknown = False
        return has_unknown
@dataclass
class ConstantRule(AtomicRule):
    """integrate(a, x) -> a*x

    The integrand is treated as constant with respect to ``variable``.
    """

    def eval(self) -> Expr:
        constant, x = self.integrand, self.variable
        return constant * x
@dataclass
class ConstantTimesRule(Rule):
    """integrate(a*f(x), x) -> a*integrate(f(x), x)

    ``substep`` integrates the remaining factor; ``eval`` scales its
    result by ``constant``.
    """
    constant: Expr
    other: Expr  # presumably the factor f(x); not read by eval
    substep: Rule

    def eval(self) -> Expr:
        inner = self.substep.eval()
        return self.constant * inner

    def contains_dont_know(self) -> bool:
        # Only the single nested step can have failed.
        nested = self.substep
        return nested.contains_dont_know()
@dataclass
class PowerRule(AtomicRule):
    """integrate(x**a, x) -> x**(a+1)/(a+1), or log(x) when a == -1

    The result is a Piecewise so that a symbolic exponent is handled in
    both cases.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        raised = self.exp + 1
        general = self.base**raised/raised
        return Piecewise(
            (general, Ne(self.exp, -1)),
            (log(self.base), True),
        )
@dataclass
class NestedPowRule(AtomicRule):
    """integrate((x**a)**b, x)

    The antiderivative is built from ``base * integrand``: divided by
    ``exp + 1`` in general, or multiplied by ``log(base)`` when
    ``exp == -1``.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        scaled = self.base * self.integrand
        regular = scaled / (self.exp + 1)
        singular = scaled * log(self.base)
        return Piecewise((regular, Ne(self.exp, -1)),
                         (singular, True))
@dataclass
class AddRule(Rule):
    """integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)

    ``substeps`` holds one rule per summand.
    """
    substeps: list[Rule]

    def eval(self) -> Expr:
        pieces = [step.eval() for step in self.substeps]
        return Add(*pieces)

    def contains_dont_know(self) -> bool:
        # The sum is unknown as soon as any summand is unknown.
        for step in self.substeps:
            if step.contains_dont_know():
                return True
        return False
@dataclass
class URule(Rule):
    """integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)

    ``substep`` integrates in ``u_var``; ``eval`` substitutes ``u_func``
    back in for ``u_var``.
    """
    u_var: Symbol
    u_func: Expr
    substep: Rule

    def eval(self) -> Expr:
        antiderivative = self.substep.eval()
        if self.u_func.is_Pow:
            u_base, u_exp = self.u_func.as_base_exp()
            # For u = 1/base, write log(u) as -log(base) up front so the
            # result does not end up containing log(1/base).
            if u_exp == -1:
                antiderivative = antiderivative.subs(log(self.u_var), -log(u_base))
        return antiderivative.subs(self.u_var, self.u_func)

    def contains_dont_know(self) -> bool:
        nested = self.substep
        return nested.contains_dont_know()
@dataclass
class PartsRule(Rule):
    """integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)

    ``v_step`` integrates ``dv`` to obtain v. ``second_step`` integrates
    the remaining u'*v. It is None when this rule is only a substep of
    ``CyclicPartsRule``, which then evaluates the parts itself.
    """
    u: Symbol
    dv: Expr
    v_step: Rule
    second_step: Rule | None  # None when is a substep of CyclicPartsRule

    def eval(self) -> Expr:
        # A standalone parts rule always needs its second integral.
        assert self.second_step is not None
        return self.u * self.v_step.eval() - self.second_step.eval()

    def contains_dont_know(self) -> bool:
        if self.v_step.contains_dont_know():
            return True
        if self.second_step is None:
            return False
        return self.second_step.contains_dont_know()
@dataclass
class CyclicPartsRule(Rule):
    """Apply PartsRule multiple times to integrate exp(x)*sin(x)

    The u*v terms of successive parts steps are summed with alternating
    signs (+, -, +, ...). The sum is then divided by
    ``1 - coefficient`` to solve for the integral that reappeared.
    """
    parts_rules: list[PartsRule]
    coefficient: Expr

    def eval(self) -> Expr:
        terms = [(-1)**k * rule.u * rule.v_step.eval()
                 for k, rule in enumerate(self.parts_rules)]
        return Add(*terms) / (1 - self.coefficient)

    def contains_dont_know(self) -> bool:
        return any(rule.contains_dont_know() for rule in self.parts_rules)
@dataclass
class TrigRule(AtomicRule, ABC):
    """Common base of the atomic trigonometric rules below."""


@dataclass
class SinRule(TrigRule):
    """integrate(sin(x), x) -> -cos(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return -cos(x)


@dataclass
class CosRule(TrigRule):
    """integrate(cos(x), x) -> sin(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return sin(x)


@dataclass
class SecTanRule(TrigRule):
    """integrate(sec(x)*tan(x), x) -> sec(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return sec(x)


@dataclass
class CscCotRule(TrigRule):
    """integrate(csc(x)*cot(x), x) -> -csc(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return -csc(x)


@dataclass
class Sec2Rule(TrigRule):
    """integrate(sec(x)**2, x) -> tan(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return tan(x)


@dataclass
class Csc2Rule(TrigRule):
    """integrate(csc(x)**2, x) -> -cot(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return -cot(x)
@dataclass
class HyperbolicRule(AtomicRule, ABC):
    """Common base of the atomic hyperbolic rules below."""


@dataclass
class SinhRule(HyperbolicRule):
    """integrate(sinh(x), x) -> cosh(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return cosh(x)


@dataclass
class CoshRule(HyperbolicRule):
    """integrate(cosh(x), x) -> sinh(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return sinh(x)
@dataclass
class ExpRule(AtomicRule):
    """integrate(a**x, x) -> a**x/ln(a)

    ``base`` is a and ``exp`` the exponent. ``eval`` divides the whole
    integrand by ``log(base)``.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        denominator = log(self.base)
        return self.integrand / denominator
@dataclass
class ReciprocalRule(AtomicRule):
    """integrate(1/x, x) -> ln(x)

    ``base`` is the x in 1/x; the result is ``log(base)``.
    """
    base: Expr

    def eval(self) -> Expr:
        argument = self.base
        return log(argument)
@dataclass
class ArcsinRule(AtomicRule):
    """integrate(1/sqrt(1-x**2), x) -> asin(x)"""

    def eval(self) -> Expr:
        x = self.variable
        return asin(x)
@dataclass
class ArcsinhRule(AtomicRule):
    """integrate(1/sqrt(1+x**2), x) -> asinh(x)

    The antiderivative is the inverse hyperbolic sine of the integration
    variable, not the inverse sine.
    """

    def eval(self) -> Expr:
        return asinh(self.variable)
@dataclass
class ReciprocalSqrtQuadraticRule(AtomicRule):
    """integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)"""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        quadratic = a+b*x+c*x**2
        root_c = sqrt(c)
        return log(2*root_c*sqrt(quadratic)+b+2*c*x)/root_c
@dataclass
class SqrtQuadraticDenomRule(AtomicRule):
    """integrate(poly(x)/sqrt(a+b*x+c*x**2), x)

    ``coeffs`` lists the coefficients of poly(x), highest degree first
    (``coeffs[i]`` multiplies ``x**(len(coeffs)-1-i)``).
    """
    a: Expr
    b: Expr
    c: Expr
    coeffs: list[Expr]

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        # Integrate poly/sqrt(a+b*x+c*x**2) using recursion.
        # Let I_n = x**n/sqrt(a+b*x+c*x**2), then
        # I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2}
        # where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c)
        # See https://github.com/sympy/sympy/pull/23608 for proof.
        #
        # The recurrence updates entries in place, so work on one private
        # copy; the rule's stored coefficients must stay untouched.
        coeffs = self.coeffs.copy()
        size = len(coeffs)
        result_coeffs = []
        for i in range(size-2):
            n = size-1-i
            coeff = coeffs[i]/(c*n)
            result_coeffs.append(coeff)
            coeffs[i+1] -= (2*n-1)*b/2*coeff
            coeffs[i+2] -= (n-1)*a*coeff
        # What remains is the linear part e*x + d, whose integral is
        # (e/c)*s + (d - b*e/(2*c)) * integrate(1/s, x).
        d, e = coeffs[-1], coeffs[-2]
        s = sqrt(a+b*x+c*x**2)
        constant = d-b*e/(2*c)
        if constant == 0:
            I0 = 0
        else:
            step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False)
            I0 = constant*step.eval()
        # result_coeffs[i] multiplies x**(size-2-i) in the polynomial
        # factor in front of s.
        return Add(*(coeff*x**(size-2-i)
                     for i, coeff in enumerate(result_coeffs)), e/c)*s + I0
@dataclass
class SqrtQuadraticRule(AtomicRule):
    """integrate(sqrt(a+b*x+c*x**2), x)

    Evaluation delegates to ``sqrt_quadratic_rule`` with
    ``degenerate=False`` and evaluates the rule it returns.
    """
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        info = IntegralInfo(self.integrand, self.variable)
        return sqrt_quadratic_rule(info, degenerate=False).eval()
@dataclass
class AlternativeRule(Rule):
    """Multiple ways to do integration.

    Only the first alternative is evaluated. The integral counts as
    unknown if any alternative contains an unknown step.
    """
    alternatives: list[Rule]

    def eval(self) -> Expr:
        chosen = self.alternatives[0]
        return chosen.eval()

    def contains_dont_know(self) -> bool:
        for alt in self.alternatives:
            if alt.contains_dont_know():
                return True
        return False
@dataclass
class DontKnowRule(Rule):
    """Leave the integral as is.

    ``eval`` returns an unevaluated ``Integral``, and the rule always
    reports itself as unknown.
    """

    def eval(self) -> Expr:
        unevaluated = Integral(self.integrand, self.variable)
        return unevaluated

    def contains_dont_know(self) -> bool:
        return True
@dataclass
class DerivativeRule(AtomicRule):
    """integrate(f'(x), x) -> f(x)

    The integrand must be a ``Derivative``. The differentiation count of
    the first matching integration variable is lowered by one.
    """

    def eval(self) -> Expr:
        assert isinstance(self.integrand, Derivative)
        pairs = list(self.integrand.variable_count)
        hit = next((pos for pos, (var, _) in enumerate(pairs)
                    if var == self.variable), None)
        if hit is not None:
            var, count = pairs[hit]
            pairs[hit] = (var, count - 1)
        return Derivative(self.integrand.expr, *pairs)
@dataclass
class RewriteRule(Rule):
    """Rewrite integrand to another form that is easier to handle.

    ``rewritten`` is the new form of the integrand. ``substep``
    integrates it, and its result is returned unchanged.
    """
    rewritten: Expr
    substep: Rule

    def eval(self) -> Expr:
        inner = self.substep
        return inner.eval()

    def contains_dont_know(self) -> bool:
        inner = self.substep
        return inner.contains_dont_know()
@dataclass
class CompleteSquareRule(RewriteRule):
    """Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2

    Behaves exactly like ``RewriteRule``; the subclass only records which
    rewrite was applied.
    """
@dataclass
class PiecewiseRule(Rule):
    """Integrate piece by piece: one ``(rule, condition)`` pair per piece.

    ``eval`` builds a Piecewise from each rule's result and its condition.
    """
    subfunctions: Sequence[tuple[Rule, bool | Boolean]]

    def eval(self) -> Expr:
        pieces = [(step.eval(), cond) for step, cond in self.subfunctions]
        return Piecewise(*pieces)

    def contains_dont_know(self) -> bool:
        for step, _cond in self.subfunctions:
            if step.contains_dont_know():
                return True
        return False
@dataclass
class HeavisideRule(Rule):
    """Integrate Heaviside(harg)*g(x) using the antiderivative of g.

    If we are integrating over x and the integrand has the form
    Heaviside(m*x+b)*g(x) == Heaviside(harg)*g(symbol), then there needs
    to be continuity at -b/m == ibnd. The antiderivative is therefore
    shifted so that it vanishes at ``ibnd``.
    """
    harg: Expr
    ibnd: Expr
    substep: Rule

    def eval(self) -> Expr:
        antiderivative = self.substep.eval()
        at_boundary = antiderivative.subs(self.variable, self.ibnd)
        return Heaviside(self.harg) * (antiderivative - at_boundary)

    def contains_dont_know(self) -> bool:
        inner = self.substep
        return inner.contains_dont_know()
@dataclass
class DiracDeltaRule(AtomicRule):
    """integrate(DiracDelta(a+b*x, n), x)

    For n == 0 the result is Heaviside(a+b*x)/b. Otherwise it is
    DiracDelta(a+b*x, n-1)/b.
    """
    n: Expr
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        n, a, b, x = self.n, self.a, self.b, self.variable
        arg = a+b*x
        if n != 0:
            return DiracDelta(arg, n-1)/b
        return Heaviside(arg)/b
@dataclass
class TrigSubstitutionRule(Rule):
    """Trigonometric substitution x = func, where func depends on theta.

    ``substep`` is evaluated, presumably producing an antiderivative in
    ``theta``. ``eval`` converts that result back to the integration
    variable using right-triangle relations. It returns the result inside
    a Piecewise whose only condition is ``restriction``.
    """
    theta: Expr  # the substitution variable
    func: Expr  # expression equal to the integration variable, in terms of theta
    rewritten: Expr  # presumably the integrand rewritten in terms of theta
    substep: Rule  # evaluated in eval; its result is expressed in theta
    restriction: bool | Boolean  # condition attached to the returned Piecewise

    def eval(self) -> Expr:
        theta, func, x = self.theta, self.func, self.variable
        # Express sec/csc/cot through cos/sin/tan, so that exactly one of
        # sin, cos or tan of theta remains to be solved for.
        func = func.subs(sec(theta), 1/cos(theta))
        func = func.subs(csc(theta), 1/sin(theta))
        func = func.subs(cot(theta), 1/tan(theta))

        trig_function = list(func.find(TrigonometricFunction))
        assert len(trig_function) == 1
        trig_function = trig_function[0]
        # Solve x = func for that trig function (e.g. sin(theta) = x/a).
        relation = solve(x - func, trig_function)
        assert len(relation) == 1
        # Read numerator and denominator of the solution as two sides of a
        # right triangle with angle theta; the third side follows from
        # Pythagoras.
        numer, denom = fraction(relation[0])

        if isinstance(trig_function, sin):
            opposite = numer
            hypotenuse = denom
            adjacent = sqrt(denom**2 - numer**2)
            inverse = asin(relation[0])
        elif isinstance(trig_function, cos):
            adjacent = numer
            hypotenuse = denom
            opposite = sqrt(denom**2 - numer**2)
            inverse = acos(relation[0])
        else:  # tan
            opposite = numer
            adjacent = denom
            hypotenuse = sqrt(denom**2 + numer**2)
            inverse = atan(relation[0])

        # Replace sin/cos/tan of theta by side ratios and theta itself by
        # the matching inverse function, giving an expression in x.
        substitution = [
            (sin(theta), opposite/hypotenuse),
            (cos(theta), adjacent/hypotenuse),
            (tan(theta), opposite/adjacent),
            (theta, inverse)
        ]
        return Piecewise(
            (self.substep.eval().subs(substitution).trigsimp(), self.restriction)
        )

    def contains_dont_know(self) -> bool:
        return self.substep.contains_dont_know()
@dataclass
class ArctanRule(AtomicRule):
    """integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))"""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        scale = sqrt(c/b)
        return a/b / scale * atan(x/scale)
@dataclass
class OrthogonalPolyRule(AtomicRule, ABC):
    """Common base for integrating a classical orthogonal polynomial.

    ``n`` is the polynomial's index; subclasses add extra parameters.
    """
    n: Expr


@dataclass
class JacobiRule(OrthogonalPolyRule):
    """integrate(jacobi(n, a, b, x), x)

    Generic case: 2*jacobi(n+1, a-1, b-1, x)/(n+a+b). When n+a+b == 0,
    the cases n == 0 and n == 1 are spelled out explicitly.
    """
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        n, a, b, x = self.n, self.a, self.b, self.variable
        generic = 2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b)
        return Piecewise(
            (generic, Ne(n + a + b, 0)),
            (x, Eq(n, 0)),
            ((a + b + 2)*x**2/4 + (a - b)*x/2, Eq(n, 1)))


@dataclass
class GegenbauerRule(OrthogonalPolyRule):
    """integrate(gegenbauer(n, a, x), x)

    a != 1: gegenbauer(n+1, a-1, x)/(2*(a-1)); otherwise the Chebyshev-T
    form chebyshevt(n+1, x)/(n+1), or 0 when n == -1.
    """
    a: Expr

    def eval(self) -> Expr:
        n, a, x = self.n, self.a, self.variable
        return Piecewise(
            (gegenbauer(n + 1, a - 1, x)/(2*(a - 1)), Ne(a, 1)),
            (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
            (S.Zero, True))


@dataclass
class ChebyshevTRule(OrthogonalPolyRule):
    """integrate(chebyshevt(n, x), x)

    (T_{n+1}/(n+1) - T_{n-1}/(n-1))/2 for |n| != 1, else x**2/2.
    """

    def eval(self) -> Expr:
        n, x = self.n, self.variable
        upper = chebyshevt(n + 1, x)/(n + 1)
        lower = chebyshevt(n - 1, x)/(n - 1)
        return Piecewise(
            ((upper - lower)/2, Ne(Abs(n), 1)),
            (x**2/2, True))


@dataclass
class ChebyshevURule(OrthogonalPolyRule):
    """integrate(chebyshevu(n, x), x) -> chebyshevt(n + 1, x)/(n + 1)

    The result is 0 when n == -1.
    """

    def eval(self) -> Expr:
        n, x = self.n, self.variable
        return Piecewise(
            (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)),
            (S.Zero, True))


@dataclass
class LegendreRule(OrthogonalPolyRule):
    """integrate(legendre(n, x), x) -> (P_{n+1}(x) - P_{n-1}(x))/(2*n + 1)"""

    def eval(self) -> Expr:
        n, x = self.n, self.variable
        difference = legendre(n + 1, x) - legendre(n - 1, x)
        return difference/(2*n + 1)


@dataclass
class HermiteRule(OrthogonalPolyRule):
    """integrate(hermite(n, x), x) -> hermite(n + 1, x)/(2*(n + 1))"""

    def eval(self) -> Expr:
        n, x = self.n, self.variable
        return hermite(n + 1, x)/(2*(n + 1))


@dataclass
class LaguerreRule(OrthogonalPolyRule):
    """integrate(laguerre(n, x), x) -> laguerre(n, x) - laguerre(n + 1, x)"""

    def eval(self) -> Expr:
        n, x = self.n, self.variable
        return laguerre(n, x) - laguerre(n + 1, x)


@dataclass
class AssocLaguerreRule(OrthogonalPolyRule):
    """integrate(assoc_laguerre(n, a, x), x) -> -assoc_laguerre(n + 1, a - 1, x)"""
    a: Expr

    def eval(self) -> Expr:
        n, a, x = self.n, self.a, self.variable
        return -assoc_laguerre(n + 1, a - 1, x)
@dataclass
class IRule(AtomicRule, ABC):
    """Common base for rules whose results use Ci/Chi/Ei/Si/Shi/li.

    ``a`` and ``b`` are the coefficients of the linear argument a*x + b.
    """
    a: Expr
    b: Expr


@dataclass
class CiRule(IRule):
    """integrate(cos(a*x + b)/x, x) -> cos(b)*Ci(a*x) - sin(b)*Si(a*x)"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return cos(b)*Ci(ax) - sin(b)*Si(ax)


@dataclass
class ChiRule(IRule):
    """integrate(cosh(a*x + b)/x, x) -> cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x)"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return cosh(b)*Chi(ax) + sinh(b)*Shi(ax)


@dataclass
class EiRule(IRule):
    """integrate(exp(a*x + b)/x, x) -> exp(b)*Ei(a*x)"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return exp(b)*Ei(ax)


@dataclass
class SiRule(IRule):
    """integrate(sin(a*x + b)/x, x) -> sin(b)*Ci(a*x) + cos(b)*Si(a*x)"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return sin(b)*Ci(ax) + cos(b)*Si(ax)


@dataclass
class ShiRule(IRule):
    """integrate(sinh(a*x + b)/x, x) -> sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x)"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return sinh(b)*Chi(ax) + cosh(b)*Shi(ax)


@dataclass
class LiRule(IRule):
    """integrate(1/log(a*x + b), x) -> li(a*x + b)/a"""

    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        return li(a*x + b)/a
@dataclass
class ErfRule(AtomicRule):
    """integrate(exp(a*x**2 + b*x + c), x) via erf / erfi.

    For extended-real ``a``, the result is a Piecewise: the erf form when
    a < 0 and the erfi form otherwise. If ``a`` is not known to be
    extended real, only the erfi form is returned.
    """
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        shift = exp(c - b**2/(4*a))
        erfi_form = sqrt(S.Pi/a)/2 * shift * erfi((2*a*x + b)/(2*sqrt(a)))
        if not a.is_extended_real:
            return erfi_form
        erf_form = sqrt(S.Pi/(-a))/2 * shift * erf((-2*a*x - b)/(2*sqrt(-a)))
        return Piecewise((erf_form, a < 0), (erfi_form, True))
@dataclass
class FresnelCRule(AtomicRule):
    """integrate(cos(a*x**2 + b*x + c), x) in terms of fresnelc/fresnels."""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        phase = b**2/(4*a) - c
        arg = (2*a*x + b)/sqrt(2*a*S.Pi)
        return sqrt(S.Pi/(2*a)) * (
            cos(phase)*fresnelc(arg) +
            sin(phase)*fresnels(arg))


@dataclass
class FresnelSRule(AtomicRule):
    """integrate(sin(a*x**2 + b*x + c), x) in terms of fresnels/fresnelc."""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        phase = b**2/(4*a) - c
        arg = (2*a*x + b)/sqrt(2*a*S.Pi)
        return sqrt(S.Pi/(2*a)) * (
            cos(phase)*fresnels(arg) -
            sin(phase)*fresnelc(arg))
@dataclass
class PolylogRule(AtomicRule):
    """integrate(polylog(b, a*x)/x, x) -> polylog(b + 1, a*x)"""
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        order = self.b + 1
        return polylog(order, self.a * self.variable)
@dataclass
class UpperGammaRule(AtomicRule):
    """integrate(x**e*exp(a*x), x) -> x**e*(-a*x)**(-e)*uppergamma(e + 1, -a*x)/a"""
    a: Expr
    e: Expr

    def eval(self) -> Expr:
        a, e, x = self.a, self.e, self.variable
        minus_ax = -a*x
        return x**e * minus_ax**(-e) * uppergamma(e + 1, minus_ax)/a
@dataclass
class EllipticFRule(AtomicRule):
    """integrate(1/sqrt(a - d*sin(x)**2), x) -> elliptic_f(x, d/a)/sqrt(a)"""
    a: Expr
    d: Expr

    def eval(self) -> Expr:
        modulus = self.d/self.a
        return elliptic_f(self.variable, modulus)/sqrt(self.a)


@dataclass
class EllipticERule(AtomicRule):
    """integrate(sqrt(a - d*sin(x)**2), x) -> elliptic_e(x, d/a)*sqrt(a)"""
    a: Expr
    d: Expr

    def eval(self) -> Expr:
        modulus = self.d/self.a
        return elliptic_e(self.variable, modulus)*sqrt(self.a)
class IntegralInfo(NamedTuple):
    """An (integrand, symbol) pair describing one integration problem.

    Being a tuple, it unpacks as ``integrand, symbol = info`` and can be
    splatted into rule constructors (``SomeRule(*info)``).
    """
    integrand: Expr  # the expression to integrate
    symbol: Symbol  # the integration variable
def manual_diff(f, symbol):
    """Derivative of f in form expected by find_substitutions

    SymPy's derivatives for some trig functions (like cot) are not in a form
    that works well with finding substitutions. This helper writes the
    derivatives of tan, cot, sec and csc in terms of sec/csc/tan/cot and
    recurses through sums and ``Number * expr`` products. Everything else
    falls back to ``f.diff(symbol)``.
    """
    if not f.args:
        return f.diff(symbol)
    inner = f.args[0]
    if isinstance(f, tan):
        return inner.diff(symbol) * sec(inner)**2
    if isinstance(f, cot):
        return -inner.diff(symbol) * csc(inner)**2
    if isinstance(f, sec):
        return inner.diff(symbol) * sec(inner) * tan(inner)
    if isinstance(f, csc):
        return -inner.diff(symbol) * csc(inner) * cot(inner)
    if isinstance(f, Add):
        return sum(manual_diff(term, symbol) for term in f.args)
    if isinstance(f, Mul) and len(f.args) == 2 and isinstance(f.args[0], Number):
        # Pull a numeric coefficient out and differentiate the rest.
        return f.args[0] * manual_diff(f.args[1], symbol)
    return f.diff(symbol)
def manual_subs(expr, *args):
    """
    A wrapper for `expr.subs(*args)` with additional logic for substitution
    of invertible functions.

    Accepts the same call forms as ``subs``: a single mapping or iterable
    of ``(old, new)`` pairs, or ``old, new`` as two arguments.

    When ``old`` is ``log(x0)``, every power ``x0**a`` in *expr* is first
    rewritten to ``exp(a*new)``. The pair ``(x0, exp(new))`` is then
    appended to the substitutions passed to ``subs``.

    Raises ValueError if a single non-iterable argument is given, or if the
    number of arguments is not 1 or 2.
    """
    if len(args) == 1:
        sequence = args[0]
        if isinstance(sequence, (Dict, Mapping)):
            sequence = sequence.items()
        elif not iterable(sequence):
            raise ValueError("Expected an iterable of (old, new) pairs")
    elif len(args) == 2:
        sequence = [args]
    else:
        raise ValueError("subs accepts either 1 or 2 arguments")

    # Materialize once: the pairs are walked below AND passed to ``subs``.
    # A one-shot iterable (e.g. a generator) would otherwise be exhausted
    # by the loop, so none of the caller's substitutions would be applied.
    sequence = list(sequence)

    new_subs = []
    for old, new in sequence:
        if isinstance(old, log):
            # If log(x) = y, then exp(a*log(x)) = exp(a*y)
            # that is, x**a = exp(a*y). Replace nontrivial powers of x
            # before subs turns them into `exp(y)**a`, but
            # do not replace x itself yet, to avoid `log(exp(y))`.
            x0 = old.args[0]
            expr = expr.replace(lambda x: x.is_Pow and x.base == x0,
                                lambda x: exp(x.exp*new))
            new_subs.append((x0, exp(new)))

    return expr.subs(sequence + new_subs)
# Inverse trigonometric classes; find_substitutions treats the first
# argument of any of these as a candidate substitution.
inverse_trig_functions = (
    atan, asin, acos, acot, acsc, asec,
)

# Method based on that on SIN, described in "Symbolic Integration: The
# Stormy Decade"
def find_substitutions(integrand, symbol, u_var):
    """Find candidate u-substitutions for ``integrand`` w.r.t. ``symbol``.

    Candidate subexpressions ``u`` come from the integrand's structure
    (``possible_subterms``) and from its ``exp`` factors
    (``exp_subterms``). For each candidate, ``integrand / u'`` is rewritten
    in terms of ``u_var``. The candidate is kept only if the result no
    longer contains ``symbol``.

    Returns a duplicate-free list of ``(u, constant, new_integrand)``
    triples. ``constant`` and ``new_integrand`` are the ``u_var``-independent
    and ``u_var``-dependent factors returned by ``as_independent``.
    """
    results = []

    def test_subterm(u, u_diff):
        # Returns False when u is unusable; otherwise the
        # (independent, dependent) factor pair of the substituted integrand.
        if u_diff == 0:
            return False
        substituted = integrand / u_diff
        debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var))
        substituted = manual_subs(substituted, u, u_var).cancel()

        # The substitution must eliminate the original variable entirely.
        if substituted.has_free(symbol):
            return False
        # avoid increasing the degree of a rational function
        if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var):
            deg_before = max([degree(t, symbol) for t in integrand.as_numer_denom()])
            deg_after = max([degree(t, u_var) for t in substituted.as_numer_denom()])
            if deg_after > deg_before:
                return False
        return substituted.as_independent(u_var, as_Add=False)

    def exp_subterms(term: Expr):
        # Collect the exp(...) subterms of ``term`` that depend on ``symbol``.
        # exp(n*symbol) with integer n are not returned one by one; a single
        # exp(g*symbol) is added instead, g being the gcd of all such n
        # (e.g. exp(2*x) and exp(4*x) contribute exp(2*x)).
        linear_coeffs = []
        terms = []
        n = Wild('n', properties=[lambda n: n.is_Integer])
        for exp_ in term.find(exp):
            arg = exp_.args[0]
            if symbol not in arg.free_symbols:
                continue
            match = arg.match(n*symbol)
            if match:
                linear_coeffs.append(match[n])
            else:
                terms.append(exp_)
        if linear_coeffs:
            terms.append(exp(gcd_list(linear_coeffs)*symbol))
        return terms

    def possible_subterms(term):
        # Enumerate subexpressions of ``term`` worth trying as ``u``.
        if isinstance(term, (TrigonometricFunction, HyperbolicFunction,
                             *inverse_trig_functions,
                             exp, log, Heaviside)):
            return [term.args[0]]
        # Orthogonal polynomials: propose the argument holding the variable
        # (its position depends on how many parameters precede it).
        elif isinstance(term, (chebyshevt, chebyshevu,
                               legendre, hermite, laguerre)):
            return [term.args[1]]
        elif isinstance(term, (gegenbauer, assoc_laguerre)):
            return [term.args[2]]
        elif isinstance(term, jacobi):
            return [term.args[3]]
        elif isinstance(term, Mul):
            # Every factor, plus everything found inside each factor.
            r = []
            for u in term.args:
                r.append(u)
                r.extend(possible_subterms(u))
            return r
        elif isinstance(term, Pow):
            # Base and/or exponent, whichever depends on ``symbol``.
            r = [arg for arg in term.args if arg.has(symbol)]
            if term.exp.is_Integer:
                # Also base**d for each prime factor d of the exponent with
                # 1 < d < |exponent|.
                r.extend([term.base**d for d in primefactors(term.exp)
                          if 1 < d < abs(term.args[1])])
                if term.base.is_Add:
                    # ...and powers found inside a sum-valued base.
                    r.extend([t for t in possible_subterms(term.base)
                              if t.is_Pow])
            return r
        elif isinstance(term, Add):
            # Every summand, plus everything found inside each summand.
            r = []
            for arg in term.args:
                r.append(arg)
                r.extend(possible_subterms(arg))
            return r
        return []

    # dict.fromkeys removes duplicate candidates while keeping first-seen order.
    for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))):
        # u = symbol is the identity substitution.
        if u == symbol:
            continue
        u_diff = manual_diff(u, symbol)
        new_integrand = test_subterm(u, u_diff)
        if new_integrand is not False:
            constant, new_integrand = new_integrand
            # Skip substitutions that only rename symbol to u_var.
            if new_integrand == integrand.subs(symbol, u_var):
                continue
            substitution = (u, constant, new_integrand)
            if substitution not in results:
                results.append(substitution)

    return results
def rewriter(condition, rewrite):
    """Strategy that rewrites an integrand.

    The returned strategy takes an ``(integrand, symbol)`` pair. When
    ``condition`` holds and ``rewrite`` changes the integrand, it integrates
    the new form. It returns a ``RewriteRule`` only if that integration
    produced a rule other than ``DontKnowRule``; otherwise it returns None.
    """
    def _rewriter(integral):
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol))
        if not condition(*integral):
            return None
        rewritten = rewrite(*integral)
        if rewritten == integrand:
            return None
        substep = integral_steps(rewritten, symbol)
        if isinstance(substep, DontKnowRule) or not substep:
            return None
        return RewriteRule(integrand, symbol, rewritten, substep)
    return _rewriter
def proxy_rewriter(condition, rewrite):
    """Strategy that rewrites an integrand based on some other criteria.

    The returned strategy takes a ``(criteria, integral)`` pair. Both
    ``condition`` and ``rewrite`` receive the criteria followed by the
    integrand and symbol. Unlike ``rewriter``, the resulting
    ``RewriteRule`` is returned whatever ``integral_steps`` yields.
    """
    def _proxy_rewriter(criteria):
        criteria, integral = criteria
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, criteria))
        args = criteria + list(integral)
        if not condition(*args):
            return None
        rewritten = rewrite(*args)
        if rewritten == integrand:
            return None
        substep = integral_steps(rewritten, symbol)
        return RewriteRule(integrand, symbol, rewritten, substep)
    return _proxy_rewriter
def multiplexer(conditions):
    """Apply the rule that matches the condition, else None

    ``conditions`` maps predicates to rules. Predicates are tried in the
    mapping's iteration order, and the first one that holds picks the rule.
    """
    def multiplexer_rl(expr):
        return next((rule(expr) for key, rule in conditions.items() if key(expr)),
                    None)
    return multiplexer_rl
def alternatives(*rules):
    """Strategy that makes an AlternativeRule out of multiple possible results.

    Every rule is applied to the integral. Results that are falsy,
    ``DontKnowRule``, equal to the input, or duplicates are dropped. A single
    survivor is returned as is. With several, the ones free of unknown steps
    are preferred; if none qualifies, all survivors are kept. Returns None
    when nothing survives.
    """
    def _alternatives(integral):
        alts = []
        debug("List of Alternative Rules")
        for count, rule in enumerate(rules, start=1):
            debug("Rule {}: {}".format(count, rule))

            result = rule(integral)
            if (result and not isinstance(result, DontKnowRule) and
                    result != integral and result not in alts):
                alts.append(result)
        if not alts:
            return None
        if len(alts) == 1:
            return alts[0]
        doable = [alt for alt in alts if not alt.contains_dont_know()]
        return AlternativeRule(*integral, doable or alts)
    return _alternatives
def constant_rule(integral):
    """Wrap an ``(integrand, symbol)`` pair in a ``ConstantRule``."""
    integrand, symbol = integral
    return ConstantRule(integrand, symbol)
def power_rule(integral):
    """Integrate ``base**expt`` where exactly one side depends on the variable.

    * ``base`` a Symbol and ``expt`` free of the variable: ``ReciprocalRule``
      when ``expt + 1`` simplifies to 0, otherwise ``PowerRule``.
    * ``expt`` a Symbol and ``base`` free of the variable: ``ExpRule``. If
      ``log(base)`` is known nonzero, that rule alone is returned. If it is
      known zero, the integrand is 1 and a ``ConstantRule`` is returned.
      Otherwise a ``PiecewiseRule`` covers both cases.

    Returns None when neither shape applies.
    """
    integrand, symbol = integral
    base, expt = integrand.as_base_exp()

    if symbol not in expt.free_symbols and isinstance(base, Symbol):
        if simplify(expt + 1) == 0:
            return ReciprocalRule(integrand, symbol, base)
        return PowerRule(integrand, symbol, base, expt)

    if symbol in base.free_symbols or not isinstance(expt, Symbol):
        return None

    exp_step = ExpRule(integrand, symbol, base, expt)
    log_is_zero = log(base).is_zero
    if fuzzy_not(log_is_zero):
        return exp_step
    if log_is_zero:
        return ConstantRule(1, symbol)
    return PiecewiseRule(integrand, symbol, [
        (exp_step, Ne(log(base), 0)),
        (ConstantRule(1, symbol), True),
    ])
def exp_rule(integral):
    """Return ``ExpRule`` with base ``E`` when the exponent is a bare Symbol.

    Returns None for any other exponent.
    """
    integrand, symbol = integral
    exponent = integrand.args[0]
    if not isinstance(exponent, Symbol):
        return None
    return ExpRule(integrand, symbol, E, exponent)
def orthogonal_poly_rule(integral):
    """Integrate a classical orthogonal polynomial evaluated at the variable.

    Applies when the integrand is one of the supported polynomial classes,
    its variable argument is exactly ``symbol``, and none of its leading
    parameter arguments depend on ``symbol``. The matching rule class is then
    built with those parameters. Returns None otherwise.
    """
    rule_for_class = {
        jacobi: JacobiRule,
        gegenbauer: GegenbauerRule,
        chebyshevt: ChebyshevTRule,
        chebyshevu: ChebyshevURule,
        legendre: LegendreRule,
        hermite: HermiteRule,
        laguerre: LaguerreRule,
        assoc_laguerre: AssocLaguerreRule
    }
    # Position of the variable argument; classes not listed here have it at 1.
    variable_position = {
        jacobi: 3,
        gegenbauer: 2,
        assoc_laguerre: 2
    }
    integrand, symbol = integral
    for poly_class, rule_class in rule_for_class.items():
        if not isinstance(integrand, poly_class):
            continue
        pos = variable_position.get(poly_class, 1)
        params = integrand.args[:pos]
        if integrand.args[pos] is symbol and not any(p.has(symbol) for p in params):
            return rule_class(integrand, symbol, *params)
    return None
# Filled lazily by special_function_rule on its first call. Each entry is
# (SymPy class the integrand must be an instance of, wildcard pattern written
# in terms of _symbol, optional constraint on the matched wild values, rule
# class to build).
# NOTE(review): the last element is a rule class (e.g. EiRule), not a tuple;
# the annotation's fourth item looks inaccurate -- confirm before relying on it.
_special_function_patterns: list[tuple[Type, Expr, Callable | None, tuple]] = []
# Wild symbols shared by all patterns, in the order their bound values are
# passed to the constraint and rule constructors.
_wilds = []
# Placeholder variable the patterns are expressed in; the integrand is
# rewritten in terms of it before matching.
_symbol = Dummy('x')
def special_function_rule(integral):
    """Match the integrand against a table of special-function forms.

    On the first call the module-level ``_special_function_patterns`` and
    ``_wilds`` lists are populated. The integrand is then rewritten in terms of
    ``_symbol`` and tried against each pattern whose SymPy class it is an
    instance of. For the first successful match whose optional constraint
    passes, the function returns ``rule(integrand, symbol, *wild_values)``.
    ``wild_values`` holds the bound wilds in ``_wilds`` order. Returns None
    when nothing matches.
    """
    integrand, symbol = integral
    if not _special_function_patterns:
        # a and d are required to be nonzero; e may not be a nonnegative integer.
        a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        b = Wild('b', exclude=[_symbol])
        c = Wild('c', exclude=[_symbol])
        d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        e = Wild('e', exclude=[_symbol], properties=[
            lambda x: not (x.is_nonnegative and x.is_integer)])
        _wilds.extend((a, b, c, d, e))
        # patterns consist of a SymPy class, a wildcard expr, an optional
        # condition coded as a lambda (when Wild properties are not enough),
        # followed by an applicable rule
        linear_pattern = a*_symbol + b
        quadratic_pattern = a*_symbol**2 + b*_symbol + c
        _special_function_patterns.extend((
            (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule),
            (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule),
            (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule),
            (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule),
            (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule),
            (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule),
            (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule),
            (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule),
            (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule),
            (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule),
            (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule),
            (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticFRule),
            (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticERule),
        ))
    # Patterns are written in _symbol, so move the integrand onto it too.
    _integrand = integrand.subs(symbol, _symbol)
    for type_, pattern, constraint, rule in _special_function_patterns:
        # The cheap isinstance test avoids calling match() on every pattern.
        if isinstance(_integrand, type_):
            match = _integrand.match(pattern)
            if match:
                # Only wilds that the match actually bound, in _wilds order.
                wild_vals = tuple(match.get(w) for w in _wilds
                                  if match.get(w) is not None)
                if constraint is None or constraint(*wild_vals):
                    return rule(integrand, symbol, *wild_vals)
def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule:
    """Combine a generic step with a fallback for the degenerate case.

    When ``degenerate_step`` is None, ``generic_step`` is returned unchanged.
    Otherwise the result is a ``PiecewiseRule`` over the generic step's
    integrand and variable. Its leading pieces are:

    * the generic step under ``generic_cond``, or
    * if the generic step is already piecewise, each of its pieces with the
      condition conjoined with ``generic_cond`` and simplified.

    These are followed by the degenerate step's own pieces, or by the
    degenerate step under ``S.true``.
    """
    if degenerate_step is None:
        return generic_step

    if isinstance(generic_step, PiecewiseRule):
        pieces = [
            (piece_rule, (piece_cond & generic_cond).simplify())
            for piece_rule, piece_cond in generic_step.subfunctions
        ]
    else:
        pieces = [(generic_step, generic_cond)]

    if isinstance(degenerate_step, PiecewiseRule):
        pieces.extend(degenerate_step.subfunctions)
    else:
        pieces.append((degenerate_step, S.true))

    return PiecewiseRule(generic_step.integrand, generic_step.variable, pieces)
def nested_pow_rule(integral: IntegralInfo):
    """Integrate nested powers/products of one linear expression in ``x``.

    Recognises integrands built from constants, products, and powers with
    ``x``-free exponents whose only ``x``-dependent building block is a single
    linear form ``a + b*x``. That form is normalised to the base ``x + a/b``.
    On success returns a ``NestedPowRule`` on that base and the accumulated
    exponent. If ``b`` might be zero, the rule is wrapped with a constant
    fallback via ``_add_degenerate_step``. Returns None when the integrand
    does not fit this shape.
    """
    # nested (c*(a+b*x)**d)**e
    integrand, x = integral

    a_ = Wild('a', exclude=[x])
    b_ = Wild('b', exclude=[x, 0])
    pattern = a_+b_*x
    # Tightened to Ne(b, 0) by _get_base_exp when a linear base is matched.
    generic_cond = S.true

    class NoMatch(Exception):
        pass

    def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]:
        # Return (base, exponent) with expr equal to base**exponent up to an
        # x-free factor; raise NoMatch when that is not possible.
        if not expr.has_free(x):
            return S.One, S.Zero
        if expr.is_Mul:
            _, terms = expr.as_coeff_mul()
            if not terms:
                return S.One, S.Zero
            results = [_get_base_exp(term) for term in terms]
            bases = {b for b, _ in results}
            bases.discard(S.One)
            # All x-dependent factors must share one base; exponents add.
            if len(bases) == 1:
                return bases.pop(), Add(*(e for _, e in results))
            raise NoMatch
        if expr.is_Pow:
            b, e = expr.base, expr.exp  # type: ignore
            if e.has_free(x):
                raise NoMatch
            base_, sub_exp = _get_base_exp(b)
            return base_, sub_exp * e
        match = expr.match(pattern)
        if match:
            a, b = match[a_], match[b_]
            base_ = x + a/b
            nonlocal generic_cond
            generic_cond = Ne(b, 0)
            return base_, S.One
        raise NoMatch

    try:
        base, exp_ = _get_base_exp(integrand)
    except NoMatch:
        return
    if generic_cond is S.true:
        degenerate_step = None
    else:
        # equivalent with subs(b, 0) but no need to find b
        degenerate_step = ConstantRule(integrand.subs(x, 0), x)
    generic_step = NestedPowRule(integrand, x, base, exp_)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def inverse_trig_rule(integral: IntegralInfo, degenerate=True):
    """
    Integrate ``(a + b*x + c*x**2)**exp`` for ``exp == -1/2`` or ``exp == 1/2``.

    For ``exp == -1/2`` the base is rewritten as ``k + c*(x - h)**2``, giving:

    * an arcsin/arcsinh step (via ``make_inverse_trig``) when the signs of
      ``k`` and ``c`` allow it;
    * ``ReciprocalSqrtQuadraticRule`` otherwise;
    * a ``NestedPowRule`` piece for the perfect-square case ``k == 0``.

    For ``exp == 1/2`` a ``SqrtQuadraticRule`` is used. Returns None for other
    exponents or when the base is not quadratic in ``x``.

    Set degenerate=False on recursive call where coefficient of quadratic term
    is assumed non-zero.
    """
    integrand, symbol = integral
    base, exp = integrand.as_base_exp()
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol, 0])
    match = base.match(a + b*symbol + c*symbol**2)

    if not match:
        return

    def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule:
        # Build RuleClass for 1/sqrt(sign_a*a + sign_c*c*(x-h)**2), reducing it
        # to the standard form 1/sqrt(sign_a + sign_c*u**2) via
        # u = sqrt(c/a)*(x-h), a constant factor, and a complete-square rewrite.
        u_var = Dummy("u")
        rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2)  # a>0, c>0
        quadratic_base = sqrt(c/a)*(symbol-h)
        constant = 1/sqrt(c)
        u_func = None
        if quadratic_base is not symbol:
            u_func = quadratic_base
            quadratic_base = u_var
        standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2)
        substep = RuleClass(standard_form, quadratic_base)
        if constant != 1:
            substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep)
        if u_func is not None:
            substep = URule(rewritten, symbol, u_var, u_func, substep)
        if h != 0:
            substep = CompleteSquareRule(integrand, symbol, rewritten, substep)
        return substep

    a, b, c = [match.get(i, S.Zero) for i in (a, b, c)]
    generic_cond = Ne(c, 0)
    # Fallback for c == 0: the base is constant (b == 0) or linear in x.
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = ConstantRule(a ** exp, symbol)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol))

    if simplify(2*exp + 1) == 0:
        h, k = -b/(2*c), a - b**2/(4*c)  # rewrite base to k + c*(symbol-h)**2
        non_square_cond = Ne(k, 0)
        square_step = None
        if non_square_cond is not S.true:
            # k == 0: integrand is 1/sqrt(c*(x-h)**2).
            square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne)
        if non_square_cond is S.false:
            return square_step
        generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c)
        step = _add_degenerate_step(non_square_cond, generic_step, square_step)
        # NOTE(review): the two branches below replace `step` with
        # `generic_step`-based rules, so the square_step (k == 0) piece built
        # just above is not carried into them; confirm whether intended.
        if k.is_real and c.is_real:
            # list of ((RuleClass, a, sign_a, c, sign_c, h), condition), i.e. the
            # arguments of make_inverse_trig paired with the condition under
            # which that inverse-trig form applies
            rules = []
            for args, cond in (  # don't apply ArccoshRule to x**2-1
                ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)),  # 1-x**2
                ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)),  # 1+x**2
            ):
                if cond is S.true:
                    return make_inverse_trig(*args)
                if cond is not S.false:
                    rules.append((make_inverse_trig(*args), cond))
            if rules:
                if not k.is_positive:  # conditions are not thorough, need fall back rule
                    rules.append((generic_step, S.true))
                step = PiecewiseRule(integrand, symbol, rules)
            else:
                step = generic_step
        return _add_degenerate_step(generic_cond, step, degenerate_step)
    if exp == S.Half:
        step = SqrtQuadraticRule(integrand, symbol, a, b, c)
        return _add_degenerate_step(generic_cond, step, degenerate_step)
def add_rule(integral):
    """Integrate a sum term by term.

    Every term from ``as_ordered_terms`` is passed to ``integral_steps``.
    Returns None if any term yields None; otherwise returns an ``AddRule``
    over the per-term steps.
    """
    integrand, symbol = integral
    term_steps = []
    for term in integrand.as_ordered_terms():
        term_steps.append(integral_steps(term, symbol))
    if None in term_steps:
        return None
    return AddRule(integrand, symbol, term_steps)
def mul_rule(integral: IntegralInfo):
    """Pull a variable-independent factor out of the integrand.

    When ``as_independent`` yields a coefficient other than 1 and the
    remaining factor has a step, returns a ``ConstantTimesRule``. Returns
    None otherwise.
    """
    integrand, symbol = integral
    coeff, rest = integrand.as_independent(symbol)
    if coeff == 1:
        return None
    inner_step = integral_steps(rest, symbol)
    if inner_step is None:
        return None
    return ConstantTimesRule(integrand, symbol, coeff, rest, inner_step)
def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None:
    """Pick ``u`` and ``dv`` for one integration-by-parts step using LIATE.

    Returns ``(u, dv, v, du, v_step)``, where ``v_step`` is the step that
    integrates ``dv`` and ``v`` is its evaluation. Returns None when no
    acceptable split is found or ``dv`` cannot be integrated without a
    ``DontKnowRule``.
    """
    # LIATE rule:
    # log, inverse trig, algebraic, trigonometric, exponential
    def pull_out_algebraic(integrand):
        # Take the product of the algebraic factors of a Mul as u.
        integrand = integrand.cancel().together()
        # iterating over Piecewise args would not work here
        algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul
            else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)])
        if algebraic:
            u = Mul(*algebraic)
            dv = (integrand / u).cancel()
            return u, dv

    def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]:
        # Factory: take the product of top-level args that are instances of
        # any of ``functions`` as u.
        def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None:
            if any(integrand.has(f) for f in functions):
                args = [arg for arg in integrand.args
                        if any(isinstance(arg, cls) for cls in functions)]
                if args:
                    u = Mul(*args)
                    dv = integrand / u
                    return u, dv
            return None

        return pull_out_u_rl

    liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions),
                   pull_out_algebraic, pull_out_u(sin, cos),
                   pull_out_u(exp)]


    dummy = Dummy("temporary")
    # we can integrate log(x) and atan(x) by setting dv = 1
    if isinstance(integrand, (log, *inverse_trig_functions)):
        integrand = dummy * integrand

    for index, rule in enumerate(liate_rules):
        result = rule(integrand)

        if result:
            u, dv = result

            # Don't pick u to be a constant if possible
            if symbol not in u.free_symbols and not u.has(dummy):
                return None

            u = u.subs(dummy, 1)
            dv = dv.subs(dummy, 1)

            # Don't pick a non-polynomial algebraic to be differentiated
            if rule == pull_out_algebraic and not u.is_polynomial(symbol):
                return None
            # Don't trade one logarithm for another
            if isinstance(u, log):
                rec_dv = 1/dv
                if (rec_dv.is_polynomial(symbol) and
                    degree(rec_dv, symbol) == 1):
                        return None

            # Can integrate a polynomial times OrthogonalPolynomial
            if rule == pull_out_algebraic:
                if dv.is_Derivative or dv.has(TrigonometricFunction) or \
                        isinstance(dv, OrthogonalPolynomial):
                    v_step = integral_steps(dv, symbol)
                    if v_step.contains_dont_know():
                        return None
                    else:
                        du = u.diff(symbol)
                        v = v_step.eval()
                        return u, dv, v, du, v_step

            # make sure dv is amenable to integration
            accept = False
            if index < 2:  # log and inverse trig are usually worth trying
                accept = True
            elif (rule == pull_out_algebraic and dv.args and
                all(isinstance(a, (sin, cos, exp))
                for a in dv.args)):
                    accept = True
            else:
                # Accept if a later LIATE class would have chosen exactly dv
                # as its u, i.e. dv is itself one of the recognised shapes.
                for lrule in liate_rules[index + 1:]:
                    r = lrule(integrand)
                    if r and r[0].subs(dummy, 1).equals(dv):
                        accept = True
                        break

            if accept:
                du = u.diff(symbol)
                v_step = integral_steps(simplify(dv), symbol)
                if not v_step.contains_dont_know():
                    v = v_step.eval()
                    return u, dv, v, du, v_step
    return None
def parts_rule(integral):
    """Integrate by parts, including cyclic and repeated applications.

    A leading numeric coefficient is factored out first. The function then:

    * takes one ``_parts_rule`` step;
    * repeatedly tries further steps on ``v*du`` (at most 4 iterations);
    * returns a ``CyclicPartsRule`` as soon as ``v*du`` is an ``x``-free
      multiple of the original integrand.

    Otherwise it chains the collected steps into nested ``PartsRule`` objects,
    finishing with ``integral_steps`` on the last remainder. Returns None if
    no parts step applies, if ``v`` is an unevaluated ``Integral``, or if the
    per-``u`` usage limit is exceeded.
    """
    integrand, symbol = integral
    constant, integrand = integrand.as_coeff_Mul()

    result = _parts_rule(integrand, symbol)

    steps = []
    if result:
        u, dv, v, du, v_step = result
        debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step))
        steps.append(result)

        if isinstance(v, Integral):
            return

        # Set a limit on the number of times u can be used
        # (_parts_u_cache / _cache_dummy are defined elsewhere; presumably a
        # counter that defaults to 0 for unseen keys -- TODO confirm)
        if isinstance(u, (sin, cos, exp, sinh, cosh)):
            cachekey = u.xreplace({symbol: _cache_dummy})
            if _parts_u_cache[cachekey] > 2:
                return
            _parts_u_cache[cachekey] += 1

        # Try cyclic integration by parts a few times
        for _ in range(4):
            debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand))
            coefficient = ((v * du) / integrand).cancel()
            if coefficient == 1:
                break
            # v*du is a constant multiple of the integrand: solve for it.
            if symbol not in coefficient.free_symbols:
                rule = CyclicPartsRule(integrand, symbol,
                    [PartsRule(None, None, u, dv, v_step, None)
                     for (u, dv, v, du, v_step) in steps],
                    (-1) ** len(steps) * coefficient)
                if (constant != 1) and rule:
                    rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
                return rule

            # _parts_rule is sensitive to constants, factor it out
            next_constant, next_integrand = (v * du).as_coeff_Mul()
            result = _parts_rule(next_integrand, symbol)

            if result:
                u, dv, v, du, v_step = result
                u *= next_constant
                du *= next_constant
                steps.append((u, dv, v, du, v_step))
            else:
                break

    def make_second_step(steps, integrand):
        # Nest the remaining parts steps; the last remainder is integrated
        # directly.
        if steps:
            u, dv, v, du, v_step = steps[0]
            return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
        return integral_steps(integrand, symbol)

    if steps:
        u, dv, v, du, v_step = steps[0]
        rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
        if (constant != 1) and rule:
            rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
        return rule
def trig_rule(integral):
    """Integrate a single trigonometric function.

    ``sin(x)``, ``cos(x)``, ``sec(x)**2`` and ``csc(x)**2`` map directly to
    their rules. ``tan``, ``cot``, ``sec`` and ``csc`` of any argument are
    rewritten into forms amenable to substitution and wrapped in a
    ``RewriteRule``. Returns None for anything else.
    """
    integrand, symbol = integral
    direct_forms = (
        (sin(symbol), SinRule),
        (cos(symbol), CosRule),
        (sec(symbol)**2, Sec2Rule),
        (csc(symbol)**2, Csc2Rule),
    )
    for target, rule_class in direct_forms:
        if integrand == target:
            return rule_class(integrand, symbol)

    if not isinstance(integrand, (tan, cot, sec, csc)):
        return None

    arg = integrand.args[0]
    if isinstance(integrand, tan):
        rewritten = sin(arg) / cos(arg)
    elif isinstance(integrand, cot):
        rewritten = cos(arg) / sin(arg)
    elif isinstance(integrand, sec):
        # Multiply by (sec + tan)/(sec + tan); the numerator is the derivative
        # of the denominator.
        rewritten = ((sec(arg)**2 + tan(arg) * sec(arg)) /
                     (sec(arg) + tan(arg)))
    else:
        rewritten = ((csc(arg)**2 + cot(arg) * csc(arg)) /
                     (csc(arg) + cot(arg)))
    return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
def trig_product_rule(integral: IntegralInfo):
    """Recognise ``sec(x)*tan(x)`` and ``csc(x)*cot(x)``; otherwise None."""
    integrand, symbol = integral
    known_products = (
        (sec(symbol) * tan(symbol), SecTanRule),
        (csc(symbol) * cot(symbol), CscCotRule),
    )
    for product, rule_class in known_products:
        if integrand == product:
            return rule_class(integrand, symbol)
    return None
def quadratic_denom_rule(integral):
    """Integrate rational functions whose denominator is quadratic in ``x``.

    Three shapes are tried in order:

    1. ``a/(b*x**2 + c)``: an ``ArctanRule``. When ``b`` and ``c`` are
       extended-real and ``c/b > 0`` is not known true, a logarithmic
       partial-fraction step for ``c/b < 0`` is built. It is returned alone if
       ``c/b > 0`` is known false, otherwise inside a ``PiecewiseRule``.
    2. ``a/(b*x**2 + c*x + d)``: completes the square with the substitution
       ``u = x + c/(2*b)``.
    3. ``(a*x + b)/(c*x**2 + d*x + e)``: splits the integrand into
       ``const*(2*c*x + d)/denominator`` (integrated as a log via ``u =
       denominator``) plus ``numer2/denominator``.

    Returns None when no shape matches, a leading coefficient is zero, or the
    completed-square integral has no step.
    """
    integrand, symbol = integral
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol])

    match = integrand.match(a / (b * symbol ** 2 + c))

    if match:
        a, b, c = match[a], match[b], match[c]
        general_rule = ArctanRule(integrand, symbol, a, b, c)
        if b.is_extended_real and c.is_extended_real:
            positive_cond = c/b > 0
            if positive_cond is S.true:
                return general_rule
            # c/b < 0: 1/(b*x**2 + c) splits into 1/(x - r) - 1/(x + r).
            coeff = a/(2*sqrt(-c)*sqrt(b))
            constant = sqrt(-c/b)
            r1 = 1/(symbol-constant)
            r2 = 1/(symbol+constant)
            log_steps = [ReciprocalRule(r1, symbol, symbol-constant),
                         ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))]
            rewritten = sub = r1 - r2
            negative_step = AddRule(sub, symbol, log_steps)
            if coeff != 1:
                rewritten = Mul(coeff, sub, evaluate=False)
                negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step)
            negative_step = RewriteRule(integrand, symbol, rewritten, negative_step)
            if positive_cond is S.false:
                return negative_step
            return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)])
        return general_rule

    d = Wild('d', exclude=[symbol])
    match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d))
    if match2:
        b, c = match2[b], match2[c]
        if b.is_zero:
            return
        u = Dummy('u')
        u_func = symbol + c/(2*b)
        integrand2 = integrand.subs(symbol, u - c / (2*b))
        next_step = integral_steps(integrand2, u)
        if next_step:
            # The URule describes the integral over ``symbol``, so it carries
            # the original integrand; integrand2 (in u) lives in next_step.
            return URule(integrand, symbol, u, u_func, next_step)
        else:
            return
    e = Wild('e', exclude=[symbol])
    match3 = integrand.match((a * symbol + b) / (c * symbol ** 2 + d * symbol + e))
    if match3:
        a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e]
        if c.is_zero:
            return
        denominator = c * symbol**2 + d * symbol + e
        # a*x + b == const*(2*c*x + d) + numer2
        const = a/(2*c)
        numer1 = (2*c*symbol+d)
        numer2 = - const*d + b
        u = Dummy('u')
        log_part = numer1/denominator
        # u = denominator turns (2*c*x + d)/denominator into 1/u; the URule
        # integrates exactly that piece, not the whole integrand.
        step1 = URule(log_part, symbol,
                      u, denominator, integral_steps(u**(-1), u))
        if const != 1:
            step1 = ConstantTimesRule(const*log_part, symbol,
                                      const, log_part, step1)
        if numer2.is_zero:
            return step1
        step2 = integral_steps(numer2/denominator, symbol)
        rewritten = const*numer1/denominator+numer2/denominator
        # The AddRule integrates the rewritten sum, as in sqrt_quadratic_rule.
        substeps = AddRule(rewritten, symbol, [step1, step2])
        return RewriteRule(integrand, symbol, rewritten, substeps)

    return
def sqrt_linear_rule(integral: IntegralInfo):
    """
    Substitute common (a+b*x)**(1/n)

    Collects every power ``(a_i + b_i*x)**(p/q)`` of a linear base in the
    integrand. These are only compatible when all bases are nonnegative
    multiples of one representative ``a0 + b0*x``. The function then
    substitutes ``u = (a0 + b0*x)**(1/q0)`` with ``q0 = lcm`` of the
    denominators. If ``b0`` might be zero, a ``PiecewiseRule`` adds the
    integral with all ``b_i`` set to zero.

    Returns None when no such power exists, when an exponent is irrational,
    when the bases are incompatible, or when the substituted integral
    contains a ``DontKnowRule``.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x, 0])
    a0 = b0 = 0  # representative linear base a0 + b0*x (none yet while b0 == 0)
    bases, qs, bs = [], [], []
    for pow_ in integrand.find(Pow):  # collect all (a+b*x)**(p/q)
        base, exp_ = pow_.base, pow_.exp
        if exp_.is_Integer or x not in base.free_symbols:  # skip 1/x and sqrt(2)
            continue
        if not exp_.is_Rational:  # exclude x**pi
            return
        match = base.match(a+b*x)
        if not match:  # skip non-linear
            continue  # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x)
        a1, b1 = match[a], match[b]
        if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative:  # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x)
            return
        if b0 == 0 or (b0/b1 > 1) is S.true:  # choose the latter of sqrt(2*x) and sqrt(x) as representative
            a0, b0 = a1, b1
        bases.append(base)
        bs.append(b1)
        qs.append(exp_.q)
    if b0 == 0:  # no such pattern found
        return
    q0: Integer = lcm_list(qs)
    u_x = (a0 + b0*x)**(1/q0)
    u = Dummy("u")
    # Replace each base**(1/q) by its expression in u, then x itself by
    # (u**q0 - a0)/b0.
    substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q)
                                  for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0)
    # dx = q0*u**(q0-1)/b0 du
    substep = integral_steps(substituted*u**(q0-1)*q0/b0, u)
    if not substep.contains_dont_know():
        step: Rule = URule(integrand, x, u, u_x, substep)
        generic_cond = Ne(b0, 0)
        if generic_cond is not S.true:  # possible degenerate case
            simplified = integrand.subs({b: 0 for b in bs})
            degenerate_step = integral_steps(simplified, x)
            step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)])
        return step
def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True):
    """Integrate ``poly(x) * sqrt(a + b*x + c*x**2)**n`` for odd ``n >= -1``.

    Positive ``n`` is rewritten as ``poly*s**((n+1)/2)/sqrt(s)``, with
    ``s = a + b*x + c*x**2``. For numerators of degree at most 1 the
    ``1/sqrt(s)`` form splits into:

    * a multiple of ``(2*c*x + b)/sqrt(s)``, integrated via ``u = s``;
    * a constant over ``sqrt(s)``, delegated to ``inverse_trig_rule``.

    Higher-degree numerators use ``SqrtQuadraticDenomRule``. When
    ``degenerate`` is true and ``c`` might be zero, a fallback for ``c == 0``
    is attached. Returns None for non-matching integrands, non-polynomial
    cofactors, or ``n < -1``.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    c = Wild('c', exclude=[x, 0])
    f = Wild('f')
    n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
    match = integrand.match(f*sqrt(a+b*x+c*x**2)**n)
    if not match:
        return
    a, b, c, f, n = match[a], match[b], match[c], match[f], match[n]
    f_poly = f.as_poly(x)
    if f_poly is None:
        return

    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = integral_steps(f*sqrt(a)**n, x)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x))

    def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr):
        # Step for numer_poly(x)/sqrt(a + b*x + c*x**2); may be None if both
        # A and B below are zero.
        denom = sqrt(a+b*x+c*x**2)
        deg = numer_poly.degree()
        if deg <= 1:
            # integrand == (d+e*x)/sqrt(a+b*x+c*x**2)
            e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
            # rewrite numerator to A*(2*c*x+b) + B
            A = e/(2*c)
            B = d-A*b
            pre_substitute = (2*c*x+b)/denom
            constant_step: Rule | None = None
            linear_step: Rule | None = None
            if A != 0:
                u = Dummy("u")
                pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half)
                linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule)
                if A != 1:
                    linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step)
            if B != 0:
                constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False)
                if B != 1:
                    constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step)  # type: ignore
            if linear_step and constant_step:
                add = Add(A*pre_substitute, B/denom, evaluate=False)
                step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step]))
            else:
                step = linear_step or constant_step
        else:
            coeffs = numer_poly.all_coeffs()
            step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs)
        return step

    if n > 0:  # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s)
        numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2)
        rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2)
        substep = sqrt_quadratic_denom_rule(numer_poly, rewritten)
        generic_step = RewriteRule(integrand, x, rewritten, substep)
    elif n == -1:
        generic_step = sqrt_quadratic_denom_rule(f_poly, integrand)
    else:
        return  # todo: handle n < -1 case
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def hyperbolic_rule(integral: tuple[Expr, Symbol]):
    """Integrate a hyperbolic function applied directly to the variable.

    * ``sinh`` and ``cosh`` map to dedicated rules.
    * ``tanh`` and ``coth`` are rewritten as a quotient and integrated with
      ``u`` equal to the denominator.
    * ``sech`` and ``csch`` are rewritten in terms of ``tanh`` and integrated
      with ``u = tanh(x/2)``.

    Returns None otherwise.
    """
    integrand, symbol = integral
    if not isinstance(integrand, HyperbolicFunction) or integrand.args[0] != symbol:
        return None

    kind = integrand.func
    if kind == sinh:
        return SinhRule(integrand, symbol)
    if kind == cosh:
        return CoshRule(integrand, symbol)

    u = Dummy('u')
    # (numerator, denominator) of the quotient form; u = denominator.
    quotient_forms = {tanh: (sinh, cosh), coth: (cosh, sinh)}
    if kind in quotient_forms:
        numer_func, denom_func = quotient_forms[kind]
        rewritten = numer_func(symbol)/denom_func(symbol)
        return RewriteRule(integrand, symbol, rewritten,
                           URule(rewritten, symbol, u, denom_func(symbol), ReciprocalRule(1/u, u, u)))

    rewritten = integrand.rewrite(tanh)
    half_tanh = tanh(symbol/2)
    if kind == sech:
        return RewriteRule(integrand, symbol, rewritten,
                           URule(rewritten, symbol, u, half_tanh,
                                 ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One)))
    if kind == csch:
        return RewriteRule(integrand, symbol, rewritten,
                           URule(rewritten, symbol, u, half_tanh,
                                 ReciprocalRule(1/u, u, u)))
    return None
@cacheit
def make_wilds(symbol):
    """Return the cached wilds ``(a, b, m, n)`` used by the trig power patterns.

    ``a`` and ``b`` match ``symbol``-free coefficients. ``m`` and ``n`` match
    ``symbol``-free values that are SymPy ``Integer`` instances.
    """
    coeff_a = Wild('a', exclude=[symbol])
    coeff_b = Wild('b', exclude=[symbol])
    power_m = Wild('m', exclude=[symbol], properties=[lambda k: isinstance(k, Integer)])
    power_n = Wild('n', exclude=[symbol], properties=[lambda k: isinstance(k, Integer)])
    return coeff_a, coeff_b, power_m, power_n
@cacheit
def sincos_pattern(symbol):
    """Return ``(sin(a*x)**m * cos(b*x)**n, a, b, m, n)`` for ``symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return sin(wa*symbol)**wm * cos(wb*symbol)**wn, wa, wb, wm, wn
@cacheit
def tansec_pattern(symbol):
    """Return ``(tan(a*x)**m * sec(b*x)**n, a, b, m, n)`` for ``symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return tan(wa*symbol)**wm * sec(wb*symbol)**wn, wa, wb, wm, wn
@cacheit
def cotcsc_pattern(symbol):
    """Return ``(cot(a*x)**m * csc(b*x)**n, a, b, m, n)`` for ``symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return cot(wa*symbol)**wm * csc(wb*symbol)**wn, wa, wb, wm, wn
@cacheit
def heaviside_pattern(symbol):
    """Return ``(Heaviside(m*x + b) * g, m, b, g)`` for ``symbol``.

    ``m`` and ``b`` must be free of ``symbol``; ``g`` is unrestricted.
    """
    slope = Wild('m', exclude=[symbol])
    offset = Wild('b', exclude=[symbol])
    cofactor = Wild('g')
    return Heaviside(slope*symbol + offset) * cofactor, slope, offset, cofactor
def uncurry(func):
    """Adapt ``func(*args)`` into a callable taking one packed argument sequence."""
    def uncurry_rl(packed_args):
        # Spread the packed sequence into positional arguments.
        return func(*packed_args)

    return uncurry_rl
def trig_rewriter(rewrite):
    """Wrap a trig-identity rewrite into a strategy over packed match data.

    The returned callable takes ``(a, b, m, n, integrand, symbol)`` and calls
    ``rewrite`` with the same six values. If the result differs from the
    integrand, it returns a ``RewriteRule`` whose substep is
    ``integral_steps`` of the rewritten form; otherwise it returns None.
    """
    def trig_rewriter_rl(args):
        a, b, m, n, integrand, symbol = args
        new_integrand = rewrite(a, b, m, n, integrand, symbol)
        if new_integrand == integrand:
            return None
        return RewriteRule(integrand, symbol, new_integrand,
                           integral_steps(new_integrand, symbol))

    return trig_rewriter_rl
# Conditions and rewrites for sin(a*x)**m * cos(b*x)**n, tan(a*x)**m *
# sec(b*x)**n and cot(a*x)**m * csc(b*x)**n. Each callable takes the packed
# tuple (a, b, m, n, integrand, symbol) built by the trig_*_rule functions.

# sin**m * cos**n with m and n both even and nonnegative: half-angle
# identities sin**2 = (1 - cos(2*a*x))/2 and cos**2 = (1 + cos(2*b*x))/2.
sincos_botheven_condition = uncurry(
    lambda a, b, m, n, i, s: m.is_even and n.is_even and
    m.is_nonnegative and n.is_nonnegative)

sincos_botheven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) *
                                    (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) ))

# Odd m >= 3: keep one sin factor and write the rest as (1 - cos**2)**((m-1)/2).
sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3)

sincos_sinodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) *
                                    sin(a*symbol) *
                                    cos(b*symbol) ** n))

# Odd n >= 3: keep one cos factor and write the rest as (1 - sin**2)**((n-1)/2).
sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3)

sincos_cosodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) *
                                    cos(b*symbol) *
                                    sin(a*symbol) ** m))

# Even n >= 4: keep sec**2 and write the rest as (1 + tan**2)**(n/2 - 1).
tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
tansec_seceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) *
                                    sec(b*symbol)**2 *
                                    tan(a*symbol) ** m ))

# Odd m: keep one tan factor and write the rest as (sec**2 - 1)**((m-1)/2).
tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
tansec_tanodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                     tan(a*symbol) *
                                     sec(b*symbol) ** n ))

# Plain tan**2 (m == 2, n == 0): tan**2 = sec**2 - 1.
tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0)
tan_tansquared = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1))

# Even n >= 4: keep csc**2 and write the rest as (1 + cot**2)**(n/2 - 1).
cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
cotcsc_csceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) *
                                    csc(b*symbol)**2 *
                                    cot(a*symbol) ** m ))

# Odd m: keep one cot factor and write the rest as (csc**2 - 1)**((m-1)/2).
cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
cotcsc_cotodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                    cot(a*symbol) *
                                    csc(b*symbol) ** n ))
def trig_sincos_rule(integral):
    """Integrate ``sin(a*x)**m * cos(b*x)**n`` via even/odd power rewrites.

    Returns None when the integrand contains no sin/cos, does not match the
    pattern, or no rewrite condition applies.
    """
    integrand, symbol = integral
    if not any(integrand.has(f) for f in (sin, cos)):
        return None
    pattern, a, b, m, n = sincos_pattern(symbol)
    match = integrand.match(pattern)
    if not match:
        return None
    dispatch = multiplexer({
        sincos_botheven_condition: sincos_botheven,
        sincos_sinodd_condition: sincos_sinodd,
        sincos_cosodd_condition: sincos_cosodd
    })
    wild_values = [match.get(w, S.Zero) for w in (a, b, m, n)]
    return dispatch(tuple(wild_values + [integrand, symbol]))
def trig_tansec_rule(integral):
    """Integrate ``tan(a*x)**m * sec(b*x)**n`` via power rewrites.

    ``1/cos(x)`` is first replaced by ``sec(x)``. Returns None when no
    tan/sec is present, the pattern does not match, or no rewrite condition
    applies.
    """
    integrand, symbol = integral
    integrand = integrand.subs({
        1 / cos(symbol): sec(symbol)
    })
    if not any(integrand.has(f) for f in (tan, sec)):
        return None
    pattern, a, b, m, n = tansec_pattern(symbol)
    match = integrand.match(pattern)
    if not match:
        return None
    dispatch = multiplexer({
        tansec_tanodd_condition: tansec_tanodd,
        tansec_seceven_condition: tansec_seceven,
        tan_tansquared_condition: tan_tansquared
    })
    wild_values = [match.get(w, S.Zero) for w in (a, b, m, n)]
    return dispatch(tuple(wild_values + [integrand, symbol]))
def trig_cotcsc_rule(integral):
    """Integrate ``cot(a*x)**m * csc(b*x)**n`` via power rewrites.

    ``1/sin(x)``, ``1/tan(x)`` and ``cos(x)/tan(x)`` are first substituted by
    ``csc(x)``, ``cot(x)`` and ``cot(x)`` respectively. Returns None when no
    cot/csc is present, the pattern does not match, or no rewrite condition
    applies.
    """
    integrand, symbol = integral
    integrand = integrand.subs({
        1 / sin(symbol): csc(symbol),
        1 / tan(symbol): cot(symbol),
        cos(symbol) / tan(symbol): cot(symbol)
    })
    if not any(integrand.has(f) for f in (cot, csc)):
        return None
    pattern, a, b, m, n = cotcsc_pattern(symbol)
    match = integrand.match(pattern)
    if not match:
        return None
    dispatch = multiplexer({
        cotcsc_cotodd_condition: cotcsc_cotodd,
        cotcsc_csceven_condition: cotcsc_csceven
    })
    wild_values = [match.get(w, S.Zero) for w in (a, b, m, n)]
    return dispatch(tuple(wild_values + [integrand, symbol]))
def trig_sindouble_rule(integral):
    """Expand ``sin(2*x)`` as ``2*sin(x)*cos(x)`` and integrate the result.

    Applies when the integrand matches ``sin(2*x) * a`` with ``a`` free of
    ``sin(2*x)``; otherwise returns None.
    """
    integrand, symbol = integral
    double_angle = sin(2*symbol)
    a = Wild('a', exclude=[double_angle])
    if not integrand.match(double_angle*a):
        return None
    return integral_steps(integrand * (2*sin(symbol)*cos(symbol)/double_angle), symbol)
def trig_powers_products_rule(integral):
    """Return the result of the first trig power/product strategy that applies.

    Tries sin/cos, tan/sec, cot/csc and double-angle sine in that order.
    """
    strategies = (trig_sincos_rule, trig_tansec_rule,
                  trig_cotcsc_rule, trig_sindouble_rule)
    return do_one(*[null_safe(strategy) for strategy in strategies])(integral)
def trig_substitution_rule(integral):
    """Trigonometric substitution for subexpressions ``a + b*x**2``.

    For each subexpression matching ``a + b*x**2`` (both coefficients
    nonzero), the substitution is chosen from the signs of ``a`` and ``b``:

    * ``tan`` when both are positive;
    * ``sin`` when ``a`` is positive and ``b`` negative;
    * ``sec`` when ``a`` is negative and ``b`` positive.

    ``sqrt(trig**2)`` is simplified assuming the corresponding domain. If the
    result is free of ``x`` and integrable without a ``DontKnowRule``, a
    ``TrigSubstitutionRule`` with the domain restriction is returned.
    Returns None otherwise.
    """
    integrand, symbol = integral
    A = Wild('a', exclude=[0, symbol])
    B = Wild('b', exclude=[0, symbol])
    theta = Dummy("theta")
    target_pattern = A + B*symbol**2

    matches = integrand.find(target_pattern)
    for expr in matches:
        match = expr.match(target_pattern)
        a = match.get(A, S.Zero)
        b = match.get(B, S.Zero)

        # Numeric coefficients are compared directly; symbolic ones rely on
        # their assumptions.
        a_positive = ((a.is_number and a > 0) or a.is_positive)
        b_positive = ((b.is_number and b > 0) or b.is_positive)
        a_negative = ((a.is_number and a < 0) or a.is_negative)
        b_negative = ((b.is_number and b < 0) or b.is_negative)
        x_func = None
        if a_positive and b_positive:
            # a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2
            x_func = (sqrt(a)/sqrt(b)) * tan(theta)
            # Do not restrict the domain: tan(theta) takes on any real
            # value on the interval -pi/2 < theta < pi/2 so x takes on
            # any value
            restriction = True
        elif a_positive and b_negative:
            # a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2
            constant = sqrt(a)/sqrt(-b)
            x_func = constant * sin(theta)
            restriction = And(symbol > -constant, symbol < constant)
        elif a_negative and b_positive:
            # b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi
            # NOTE(review): x = constant*sec(theta) only takes values with
            # |x| >= constant, yet this restriction is -constant < x < constant
            # (copied from the sin case); confirm whether that is intended.
            constant = sqrt(-a)/sqrt(b)
            x_func = constant * sec(theta)
            restriction = And(symbol > -constant, symbol < constant)
        if x_func:
            # Manually simplify sqrt(trig(theta)**2) to trig(theta)
            # Valid due to assumed domain restriction
            substitutions = {}
            for f in [sin, cos, tan,
                      sec, csc, cot]:
                substitutions[sqrt(f(theta)**2)] = f(theta)
                substitutions[sqrt(f(theta)**(-2))] = 1/f(theta)

            replaced = integrand.subs(symbol, x_func).trigsimp()
            replaced = manual_subs(replaced, substitutions)
            # Only proceed if the substitution eliminated x entirely.
            if not replaced.has(symbol):
                # dx = x_func'(theta) dtheta
                replaced *= manual_diff(x_func, theta)
                replaced = replaced.trigsimp()
                secants = replaced.find(1/cos(theta))
                if secants:
                    replaced = replaced.xreplace({
                        1/cos(theta): sec(theta)
                    })

                substep = integral_steps(replaced, theta)
                if not substep.contains_dont_know():
                    return TrigSubstitutionRule(integrand, symbol,
                        theta, x_func, replaced, substep, restriction)
def heaviside_rule(integral):
    """Integrate ``Heaviside(m*x + b) * g``.

    ``g`` is integrated with ``integral_steps`` and the result is wrapped in a
    ``HeavisideRule`` with step point ``-b/m``. Returns None when the pattern
    does not match or ``g`` is zero.
    """
    integrand, symbol = integral
    pattern, m, b, g = heaviside_pattern(symbol)
    match = integrand.match(pattern)
    if not match or match[g] == 0:
        return None
    # f = Heaviside(m*x + b)*g
    substep = integral_steps(match[g], symbol)
    slope, offset = match[m], match[b]
    return HeavisideRule(integrand, symbol, slope*symbol + offset, -offset/slope, substep)
def dirac_delta_rule(integral: IntegralInfo):
    """Integrate ``DiracDelta(a + b*x, n)`` for a nonnegative integer order ``n``.

    A missing second argument means order 0. If ``b`` might be zero, the
    generic ``DiracDeltaRule`` is paired with a constant fallback
    ``DiracDelta(a, n)``. Returns None for non-integer or negative orders, or
    a non-linear argument.
    """
    integrand, x = integral
    order = S.Zero if len(integrand.args) == 1 else integrand.args[1]
    if not order.is_Integer or order < 0:
        return None

    a, b = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0])
    match = integrand.args[0].match(a+b*x)
    if not match:
        return None

    offset, slope = match[a], match[b]
    generic_cond = Ne(slope, 0)
    degenerate_step = (None if generic_cond is S.true
                       else ConstantRule(DiracDelta(offset, order), x))
    generic_step = DiracDeltaRule(integrand, x, order, offset, slope)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
def substitution_rule(integral):
    """Try every u-substitution proposed by ``find_substitutions``.

    Each candidate is a ``(u_func, c, substituted)`` triple. ``substituted``
    is integrated in the dummy ``u``; candidates whose step contains a
    ``DontKnowRule`` are dropped. When ``c`` is not 1 the step is scaled by
    ``c``. If ``c``'s denominator has free symbols, a ``PiecewiseRule`` adds
    cases where each denominator factor is zero. Returns a single ``URule``,
    an ``AlternativeRule`` over several, or None if none succeeded.
    """
    integrand, symbol = integral

    u_var = Dummy("u")
    substitutions = find_substitutions(integrand, symbol, u_var)
    count = 0
    if substitutions:
        debug("List of Substitution Rules")
        ways = []
        for u_func, c, substituted in substitutions:
            subrule = integral_steps(substituted, u_var)
            count = count + 1
            debug("Rule {}: {}".format(count, subrule))

            if subrule.contains_dont_know():
                continue

            if simplify(c - 1) != 0:
                _, denom = c.as_numer_denom()
                if subrule:
                    subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule)

                # c may be undefined where a factor of its denominator
                # vanishes; integrate those cases separately.
                if denom.free_symbols:
                    piecewise = []
                    could_be_zero = []

                    if isinstance(denom, Mul):
                        could_be_zero = denom.args
                    else:
                        could_be_zero.append(denom)

                    for expr in could_be_zero:
                        if not fuzzy_not(expr.is_zero):
                            substep = integral_steps(manual_subs(integrand, expr, 0), symbol)

                            if substep:
                                piecewise.append((
                                    substep,
                                    Eq(expr, 0)
                                ))
                    piecewise.append((subrule, True))
                    subrule = PiecewiseRule(substituted, symbol, piecewise)

            ways.append(URule(integrand, symbol, u_var, u_func, subrule))

        if len(ways) > 1:
            return AlternativeRule(integrand, symbol, ways)
        elif ways:
            return ways[0]
# Decompose a rational function into partial fractions.
partial_fractions_rule = rewriter(
    lambda integrand, symbol: integrand.is_rational_function(),
    lambda integrand, symbol: integrand.apart(symbol))

# Put the integrand into canonical cancelled p/q form. The condition accepts
# every integrand; any narrowing is done where the strategy is used.
cancel_rule = rewriter(
    lambda integrand, symbol: True,
    lambda integrand, symbol: integrand.cancel())

# Multiply out the integrand. This applies when every argument is a power or
# polynomial in the symbol, or when the integrand itself is a power or product.
distribute_expand_rule = rewriter(
    lambda integrand, symbol: (
        all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)
        or isinstance(integrand, (Pow, Mul))),
    lambda integrand, symbol: integrand.expand())

# When trigonometric functions appear with more than one distinct argument,
# expand them using trigonometric identities.
trig_expand_rule = rewriter(
    lambda integrand, symbol: len(
        {trig.args[0] for trig in integrand.atoms(TrigonometricFunction)}) > 1,
    lambda integrand, symbol: integrand.expand(trig=True))
def derivative_rule(integral):
    """Integrate a ``Derivative`` integrand.

    * If the integration symbol does not occur in the differentiated
      expression, the integrand is treated as a constant (``ConstantRule``).
    * If the symbol is one of the differentiation variables, a
      ``DerivativeRule`` is returned.
    * Otherwise the integral is unknown (``DontKnowRule``).
    """
    derivative = integral[0]
    var = integral.symbol

    if var not in derivative.expr.free_symbols:
        return ConstantRule(*integral)
    if var in derivative.variables:
        return DerivativeRule(*integral)
    return DontKnowRule(derivative, var)
def rewrites_rule(integral):
    """Rewrite ``1/cos(x)`` as ``sec(x)`` and integrate the result.

    Returns a ``RewriteRule`` wrapping the steps for the rewritten integrand
    when the integrand matches ``1/cos(symbol)``; otherwise returns None.
    """
    integrand, symbol = integral

    # The pattern contains no Wild symbols, so a successful ``match`` returns
    # an empty dict ``{}``, which is falsy. A truthiness test would therefore
    # never fire. Compare against None, which is what ``match`` returns on
    # failure.
    if integrand.match(1/cos(symbol)) is not None:
        rewritten = integrand.subs(1/cos(symbol), sec(symbol))
        return RewriteRule(integrand, symbol, rewritten,
                           integral_steps(rewritten, symbol))
    return None
def fallback_rule(integral):
    """Last-resort strategy: report that the integral could not be done."""
    integrand, symbol = integral
    return DontKnowRule(integrand, symbol)
# Shared dummy that replaces the integration variable in cache keys, so that
# the same integrand in different variables maps to one key.
_cache_dummy = Dummy("z")
# Integrals currently being computed. An entry of None means "in progress";
# meeting it again signals a cycle that integral_steps must break.
_integral_cache: dict[Expr, Expr | None] = {}
# Per-expression counters for the "u" chosen in integration by parts, used to
# avoid repeating the same choice indefinitely.
_parts_u_cache: dict[Expr, int] = defaultdict(int)
def integral_steps(integrand, symbol, **options):
    """Returns the steps needed to compute an integral.

    Explanation
    ===========

    This function attempts to mirror what a student would do by hand as
    closely as possible.

    SymPy Gamma uses this to provide a step-by-step explanation of an
    integral. The code it uses to format the results of this function can be
    found at
    https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py.

    ``**options`` is accepted but not used by this function.

    Examples
    ========

    >>> from sympy import exp, sin
    >>> from sympy.integrals.manualintegrate import integral_steps
    >>> from sympy.abc import x
    >>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \
    # doctest: +NORMALIZE_WHITESPACE
    URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x),
    substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1))
    >>> print(repr(integral_steps(sin(x), x))) \
    # doctest: +NORMALIZE_WHITESPACE
    SinRule(integrand=sin(x), variable=x)
    >>> print(repr(integral_steps((x**2 + 3)**2, x))) \
    # doctest: +NORMALIZE_WHITESPACE
    RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9,
    substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x,
    substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4),
    ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2,
    substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)),
    ConstantRule(integrand=9, variable=x)]))

    Returns
    =======

    rule : Rule
        The first step; most rules have substeps that must also be
        considered. These substeps can be evaluated using ``manualintegrate``
        to obtain a result.

    """
    cachekey = integrand.xreplace({symbol: _cache_dummy})
    if cachekey in _integral_cache:
        cached = _integral_cache[cachekey]
        if cached is None:
            # Stop this attempt, because it leads around in a loop
            return DontKnowRule(integrand, symbol)
        # TODO: This is for future development, as currently
        # _integral_cache gets no values other than None.
        # NOTE(review): this branch returns a tuple rather than a Rule;
        # callers expecting a Rule would need it converted once the cache
        # actually stores results.
        return (cached.xreplace({_cache_dummy: symbol}), symbol)

    # Mark this integral as in progress so that a recursive request for the
    # same integrand (up to renaming the variable) is cut short above.
    _integral_cache[cachekey] = None
    try:
        integral = IntegralInfo(integrand, symbol)

        def key(integral):
            # Dispatch key: Number for integrands free of the symbol, a
            # matching base class for Symbol/trig/orthogonal polynomials,
            # otherwise the concrete type of the integrand.
            integrand = integral.integrand

            if symbol not in integrand.free_symbols:
                return Number
            for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial):
                if isinstance(integrand, cls):
                    return cls
            return type(integrand)

        def integral_is_subclass(*klasses):
            def _integral_is_subclass(integral):
                k = key(integral)
                return k and issubclass(k, klasses)
            return _integral_is_subclass

        result = do_one(
            null_safe(special_function_rule),
            null_safe(switch(key, {
                Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(quadratic_denom_rule)),
                Symbol: power_rule,
                exp: exp_rule,
                Add: add_rule,
                Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
                            null_safe(heaviside_rule),
                            null_safe(quadratic_denom_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(sqrt_quadratic_rule)),
                Derivative: derivative_rule,
                TrigonometricFunction: trig_rule,
                Heaviside: heaviside_rule,
                DiracDelta: dirac_delta_rule,
                OrthogonalPolynomial: orthogonal_poly_rule,
                Number: constant_rule
            })),
            do_one(
                null_safe(trig_rule),
                null_safe(hyperbolic_rule),
                null_safe(alternatives(
                    rewrites_rule,
                    substitution_rule,
                    condition(
                        integral_is_subclass(Mul, Pow),
                        partial_fractions_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        cancel_rule),
                    condition(
                        integral_is_subclass(Mul, log,
                                             *inverse_trig_functions),
                        parts_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        distribute_expand_rule),
                    trig_powers_products_rule,
                    trig_expand_rule
                )),
                null_safe(condition(integral_is_subclass(Mul, Pow),
                                    nested_pow_rule)),
                null_safe(trig_substitution_rule)
            ),
            fallback_rule)(integral)
    finally:
        # Always release the in-progress marker. Otherwise an exception in
        # any rule would leave a stale None behind, and every later request
        # for this integrand would silently yield DontKnowRule.
        _integral_cache.pop(cachekey, None)
    return result
def manualintegrate(f, var):
    """manualintegrate(f, var)

    Explanation
    ===========

    Compute indefinite integral of a single variable using an algorithm that
    resembles what a student would do by hand.

    Unlike :func:`~.integrate`, var can only be a single symbol.

    Examples
    ========

    >>> from sympy import sin, cos, tan, exp, log, integrate
    >>> from sympy.integrals.manualintegrate import manualintegrate
    >>> from sympy.abc import x
    >>> manualintegrate(1 / x, x)
    log(x)
    >>> integrate(1/x)
    log(x)
    >>> manualintegrate(log(x), x)
    x*log(x) - x
    >>> integrate(log(x))
    x*log(x) - x
    >>> manualintegrate(exp(x) / (1 + exp(2 * x)), x)
    atan(exp(x))
    >>> integrate(exp(x) / (1 + exp(2 * x)))
    RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x))))
    >>> manualintegrate(cos(x)**4 * sin(x), x)
    -cos(x)**5/5
    >>> integrate(cos(x)**4 * sin(x), x)
    -cos(x)**5/5
    >>> manualintegrate(cos(x)**4 * sin(x)**3, x)
    cos(x)**7/7 - cos(x)**5/5
    >>> integrate(cos(x)**4 * sin(x)**3, x)
    cos(x)**7/7 - cos(x)**5/5
    >>> manualintegrate(tan(x), x)
    -log(cos(x))
    >>> integrate(tan(x), x)
    -log(cos(x))

    See Also
    ========

    sympy.integrals.integrals.integrate
    sympy.integrals.integrals.Integral.doit
    sympy.integrals.integrals.Integral
    """
    try:
        result = integral_steps(f, var).eval()
    finally:
        # Clear the cache of u-parts even when step generation or evaluation
        # raises, so stale entries do not affect the next integration.
        _parts_u_cache.clear()

    # If we got Piecewise with two parts, put generic first: turn
    # Piecewise((special, Eq(a, b)), (generic, True)) into
    # Piecewise((generic, Ne(a, b)), (special, True)).
    if isinstance(result, Piecewise) and len(result.args) == 2:
        (special_expr, special_cond), (generic_expr, generic_cond) = result.args
        if isinstance(special_cond, Eq) and generic_cond == True:
            result = result.func(
                (generic_expr, Ne(*special_cond.args)),
                (special_expr, True))
    return result
|