from __future__ import annotations
from typing import TYPE_CHECKING

from sympy.simplify import simplify as simp, trigsimp as tsimp  # type: ignore
from sympy.core.decorators import call_highest_priority, _sympifyit
from sympy.core.assumptions import StdFactKB
from sympy.core.function import diff as df
from sympy.integrals.integrals import Integral
from sympy.polys.polytools import factor as fctr
from sympy.core import S, Add, Mul
from sympy.core.expr import Expr

if TYPE_CHECKING:
    from sympy.vector.vector import BaseVector


class BasisDependent(Expr):
    """
    Super class containing functionality common to vectors and
    dyadics.
    Named so because the representation of these quantities in
    sympy.vector is dependent on the basis they are expressed in.

    This class reads ``_add_func``, ``_mul_func``, ``_div_helper`` and
    ``components`` from ``self``; none of them are defined here, so they
    are presumably supplied by concrete subclasses.
    """

    zero: BasisDependentZero

    @call_highest_priority('__radd__')
    def __add__(self, other):
        """Return ``self + other`` built through ``_add_func``."""
        return self._add_func(self, other)

    @call_highest_priority('__add__')
    def __radd__(self, other):
        """Return ``other + self`` built through ``_add_func``."""
        return self._add_func(other, self)

    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        """Return ``self - other``, expressed as ``self + (-other)``."""
        return self._add_func(self, -other)

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        """Return ``other - self``, expressed as ``other + (-self)``."""
        return self._add_func(other, -self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        """Return ``self * other`` built through ``_mul_func``."""
        return self._mul_func(self, other)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        """Return ``other * self`` built through ``_mul_func``."""
        return self._mul_func(other, self)

    def __neg__(self):
        """Return ``-self`` as the product of ``S.NegativeOne`` and self."""
        return self._mul_func(S.NegativeOne, self)

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        """Return ``self / other`` by delegating to ``_div_helper``."""
        return self._div_helper(other)

    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        """
        Reject ``other / self``.

        Raises
        ======

        TypeError
            Always; a basis-dependent quantity is not a valid divisor.
        """
        # The exception must be raised, not returned: returning it made
        # ``1 / v`` silently evaluate to a TypeError instance.
        raise TypeError("Invalid divisor for division")

    def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
        """
        Implements the SymPy evalf routine for this quantity.

        evalf's documentation
        =====================

        """
        options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict,
                'quad':quad, 'verbose':verbose}
        vec = self.zero
        for k, v in self.components.items():
            vec += v.evalf(n, **options) * k
        return vec

    evalf.__doc__ += Expr.evalf.__doc__  # type: ignore

    n = evalf

    def simplify(self, **kwargs):
        """
        Implements the SymPy simplify routine for this quantity.

        simplify's documentation
        ========================

        """
        simp_components = [simp(v, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*simp_components)

    simplify.__doc__ += simp.__doc__  # type: ignore

    def trigsimp(self, **opts):
        """
        Implements the SymPy trigsimp routine, for this quantity.

        trigsimp's documentation
        ========================

        """
        trig_components = [tsimp(v, **opts) * k for
                           k, v in self.components.items()]
        return self._add_func(*trig_components)

    trigsimp.__doc__ += tsimp.__doc__  # type: ignore

    def _eval_simplify(self, **kwargs):
        """Forward to :meth:`simplify`."""
        return self.simplify(**kwargs)

    def _eval_trigsimp(self, **opts):
        """Forward to :meth:`trigsimp`."""
        return self.trigsimp(**opts)

    def _eval_derivative(self, wrt):
        """Forward to :meth:`diff` with ``wrt`` as the only argument."""
        return self.diff(wrt)

    def _eval_Integral(self, *symbols, **assumptions):
        """
        Wrap every scalar component in an ``Integral`` over ``symbols``
        and recombine the results with ``_add_func``.
        """
        integral_components = [Integral(v, *symbols, **assumptions) * k
                               for k, v in self.components.items()]
        return self._add_func(*integral_components)

    def as_numer_denom(self):
        """
        Returns the expression as a tuple wrt the following
        transformation -

        expression -> a/b -> a, b

        """
        return self, S.One

    def factor(self, *args, **kwargs):
        """
        Implements the SymPy factor routine, on the scalar parts
        of a basis-dependent expression.

        factor's documentation
        ========================

        """
        fctr_components = [fctr(v, *args, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*fctr_components)

    factor.__doc__ += fctr.__doc__  # type: ignore

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product."""
        return (S.One, self)

    def as_coeff_add(self, *deps):
        """Efficiently extract the coefficient of a summation."""
        l = [x * self.components[x] for x in self.components]
        return 0, tuple(l)

    def diff(self, *args, **kwargs):
        """
        Implements the SymPy diff routine, for vectors.

        diff's documentation
        ========================

        """
        for x in args:
            if isinstance(x, BasisDependent):
                raise TypeError("Invalid arg for differentiation")
        diff_components = [df(v, *args, **kwargs) * k for
                           k, v in self.components.items()]
        return self._add_func(*diff_components)

    diff.__doc__ += df.__doc__  # type: ignore

    def doit(self, **hints):
        """Calls .doit() on each scalar component of this quantity."""
        doit_components = [self.components[x].doit(**hints) * x
                           for x in self.components]
        return self._add_func(*doit_components)


class BasisDependentAdd(BasisDependent, Add):
    """
    Denotes sum of basis dependent quantities such that they cannot
    be expressed as base or Mul instances.
    """

    def __new__(cls, *args, **options):
        """
        Build the sum of ``args``, merging the scalar coefficients of
        equal base elements.

        Returns ``cls.zero`` when every coefficient cancels, and the
        result of ``cls._mul_func`` when the underlying ``Add``
        constructor collapses the sum to a ``Mul``.

        Raises
        ======

        TypeError
            If an argument is not an instance of ``cls._expr_type`` and
            is not a ``Mul`` or ``Add`` either.
        """
        components = {}

        # Check each arg and simultaneously learn the components
        for arg in args:
            if not isinstance(arg, cls._expr_type):
                if isinstance(arg, Mul):
                    arg = cls._mul_func(*(arg.args))
                elif isinstance(arg, Add):
                    arg = cls._add_func(*(arg.args))
                else:
                    raise TypeError(str(arg) +
                                    " cannot be interpreted correctly")
            # If argument is zero, ignore
            if arg == cls.zero:
                continue
            # Else, update components accordingly
            if hasattr(arg, "components"):
                for x in arg.components:
                    components[x] = components.get(x, 0) + arg.components[x]

        # Drop base elements whose coefficients cancelled out. Iterate
        # over a snapshot of the keys because the dict is mutated.
        for x in list(components):
            if components[x] == 0:
                del components[x]

        # Handle case of zero vector
        if not components:
            return cls.zero

        # Build object
        newargs = [x * components[x] for x in components]
        obj = super().__new__(cls, *newargs, **options)
        if isinstance(obj, Mul):
            return cls._mul_func(*obj.args)
        assumptions = {'commutative': True}
        obj._assumptions = StdFactKB(assumptions)
        obj._components = components
        # The coordinate system is taken from the first base element.
        obj._sys = (list(components.keys()))[0]._sys

        return obj


class BasisDependentMul(BasisDependent, Mul):
    """
    Denotes product of base- basis dependent quantity with a scalar.
    """

    def __new__(cls, *args, **options):
        """
        Build the product of at most one basis-dependent quantity with
        scalar factors.

        Returns a plain ``Mul`` when no basis-dependent argument is
        present, ``cls.zero`` when any argument is zero, and the result
        of ``cls._add_func`` when the basis-dependent argument is a sum
        (the scalar is distributed over its terms).

        Raises
        ======

        ValueError
            If more than one basis-dependent argument is supplied.
        """
        # Imported here rather than at module level, presumably to
        # avoid a circular import with sympy.vector -- TODO confirm.
        from sympy.vector import Cross, Dot, Curl, Gradient
        count = 0
        measure_number = S.One
        zeroflag = False
        extra_args = []

        # Determine the component and check arguments
        # Also keep a count to ensure two vectors aren't
        # being multiplied
        for arg in args:
            if isinstance(arg, cls._zero_func):
                count += 1
                zeroflag = True
            elif arg == S.Zero:
                zeroflag = True
            elif isinstance(arg, (cls._base_func, cls._mul_func)):
                count += 1
                expr = arg._base_instance
                measure_number *= arg._measure_number
            elif isinstance(arg, cls._add_func):
                count += 1
                expr = arg
            elif isinstance(arg, (Cross, Dot, Curl, Gradient)):
                extra_args.append(arg)
            else:
                measure_number *= arg
        # Make sure incompatible types weren't multiplied
        if count > 1:
            raise ValueError("Invalid multiplication")
        elif count == 0:
            return Mul(*args, **options)
        # Handle zero vector case
        if zeroflag:
            return cls.zero

        # If one of the args was a VectorAdd, return an
        # appropriate VectorAdd instance
        if isinstance(expr, cls._add_func):
            newargs = [cls._mul_func(measure_number, x) for
                       x in expr.args]
            return cls._add_func(*newargs)

        obj = super().__new__(cls, measure_number,
                              expr._base_instance,
                              *extra_args,
                              **options)
        if isinstance(obj, Add):
            return cls._add_func(*obj.args)
        obj._base_instance = expr._base_instance
        obj._measure_number = measure_number
        assumptions = {'commutative': True}
        obj._assumptions = StdFactKB(assumptions)
        obj._components = {expr._base_instance: measure_number}
        obj._sys = expr._base_instance._sys

        return obj

    def _sympystr(self, printer):
        """
        Print as ``<measure number>*<base instance>``; the measure
        number is parenthesised when its printed form contains ``(``,
        ``-`` or ``+``.
        """
        measure_str = printer._print(self._measure_number)
        if ('(' in measure_str or '-' in measure_str or
                '+' in measure_str):
            measure_str = '(' + measure_str + ')'
        return measure_str + '*' + printer._print(self._base_instance)


class BasisDependentZero(BasisDependent):
    """
    Class to denote a zero basis dependent instance.
    """
    # The zero quantity has no components. This dict is a class
    # attribute, so it is shared by all instances.
    components: dict['BaseVector', Expr] = {}
    _latex_form: str

    def __new__(cls):
        """Create the zero instance and fix its hash value."""
        obj = super().__new__(cls)
        # Pre-compute a specific hash value for the zero vector
        # Use the same one always
        obj._hash = (S.Zero, cls).__hash__()
        return obj

    def __hash__(self):
        """Return the hash pre-computed in ``__new__``."""
        return self._hash

    @call_highest_priority('__req__')
    def __eq__(self, other):
        """Return True when ``other`` is an instance of ``_zero_func``."""
        return isinstance(other, self._zero_func)

    __req__ = __eq__

    @call_highest_priority('__radd__')
    def __add__(self, other):
        """Return ``other`` unchanged; raise TypeError for foreign types."""
        if isinstance(other, self._expr_type):
            return other
        else:
            raise TypeError("Invalid argument types for addition")

    @call_highest_priority('__add__')
    def __radd__(self, other):
        """Return ``other`` unchanged; raise TypeError for foreign types."""
        if isinstance(other, self._expr_type):
            return other
        else:
            raise TypeError("Invalid argument types for addition")

    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        """Return ``-other``; raise TypeError for foreign types."""
        if isinstance(other, self._expr_type):
            return -other
        else:
            raise TypeError("Invalid argument types for subtraction")

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        """Return ``other`` unchanged; raise TypeError for foreign types."""
        if isinstance(other, self._expr_type):
            return other
        else:
            raise TypeError("Invalid argument types for subtraction")

    def __neg__(self):
        """Return self, since negating zero gives zero."""
        return self

    def normalize(self):
        """
        Returns the normalized version of this vector.
        """
        return self

    def _sympystr(self, printer):
        """Print the zero quantity as ``'0'``."""
        return '0'