basisdependent.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from sympy.simplify import simplify as simp, trigsimp as tsimp # type: ignore
  4. from sympy.core.decorators import call_highest_priority, _sympifyit
  5. from sympy.core.assumptions import StdFactKB
  6. from sympy.core.function import diff as df
  7. from sympy.integrals.integrals import Integral
  8. from sympy.polys.polytools import factor as fctr
  9. from sympy.core import S, Add, Mul
  10. from sympy.core.expr import Expr
  11. if TYPE_CHECKING:
  12. from sympy.vector.vector import BaseVector
  13. class BasisDependent(Expr):
  14. """
  15. Super class containing functionality common to vectors and
  16. dyadics.
  17. Named so because the representation of these quantities in
  18. sympy.vector is dependent on the basis they are expressed in.
  19. """
  20. zero: BasisDependentZero
  21. @call_highest_priority('__radd__')
  22. def __add__(self, other):
  23. return self._add_func(self, other)
  24. @call_highest_priority('__add__')
  25. def __radd__(self, other):
  26. return self._add_func(other, self)
  27. @call_highest_priority('__rsub__')
  28. def __sub__(self, other):
  29. return self._add_func(self, -other)
  30. @call_highest_priority('__sub__')
  31. def __rsub__(self, other):
  32. return self._add_func(other, -self)
  33. @_sympifyit('other', NotImplemented)
  34. @call_highest_priority('__rmul__')
  35. def __mul__(self, other):
  36. return self._mul_func(self, other)
  37. @_sympifyit('other', NotImplemented)
  38. @call_highest_priority('__mul__')
  39. def __rmul__(self, other):
  40. return self._mul_func(other, self)
  41. def __neg__(self):
  42. return self._mul_func(S.NegativeOne, self)
  43. @_sympifyit('other', NotImplemented)
  44. @call_highest_priority('__rtruediv__')
  45. def __truediv__(self, other):
  46. return self._div_helper(other)
  47. @call_highest_priority('__truediv__')
  48. def __rtruediv__(self, other):
  49. return TypeError("Invalid divisor for division")
  50. def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False):
  51. """
  52. Implements the SymPy evalf routine for this quantity.
  53. evalf's documentation
  54. =====================
  55. """
  56. options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict,
  57. 'quad':quad, 'verbose':verbose}
  58. vec = self.zero
  59. for k, v in self.components.items():
  60. vec += v.evalf(n, **options) * k
  61. return vec
  62. evalf.__doc__ += Expr.evalf.__doc__ # type: ignore
  63. n = evalf
  64. def simplify(self, **kwargs):
  65. """
  66. Implements the SymPy simplify routine for this quantity.
  67. simplify's documentation
  68. ========================
  69. """
  70. simp_components = [simp(v, **kwargs) * k for
  71. k, v in self.components.items()]
  72. return self._add_func(*simp_components)
  73. simplify.__doc__ += simp.__doc__ # type: ignore
  74. def trigsimp(self, **opts):
  75. """
  76. Implements the SymPy trigsimp routine, for this quantity.
  77. trigsimp's documentation
  78. ========================
  79. """
  80. trig_components = [tsimp(v, **opts) * k for
  81. k, v in self.components.items()]
  82. return self._add_func(*trig_components)
  83. trigsimp.__doc__ += tsimp.__doc__ # type: ignore
  84. def _eval_simplify(self, **kwargs):
  85. return self.simplify(**kwargs)
  86. def _eval_trigsimp(self, **opts):
  87. return self.trigsimp(**opts)
  88. def _eval_derivative(self, wrt):
  89. return self.diff(wrt)
  90. def _eval_Integral(self, *symbols, **assumptions):
  91. integral_components = [Integral(v, *symbols, **assumptions) * k
  92. for k, v in self.components.items()]
  93. return self._add_func(*integral_components)
  94. def as_numer_denom(self):
  95. """
  96. Returns the expression as a tuple wrt the following
  97. transformation -
  98. expression -> a/b -> a, b
  99. """
  100. return self, S.One
  101. def factor(self, *args, **kwargs):
  102. """
  103. Implements the SymPy factor routine, on the scalar parts
  104. of a basis-dependent expression.
  105. factor's documentation
  106. ========================
  107. """
  108. fctr_components = [fctr(v, *args, **kwargs) * k for
  109. k, v in self.components.items()]
  110. return self._add_func(*fctr_components)
  111. factor.__doc__ += fctr.__doc__ # type: ignore
  112. def as_coeff_Mul(self, rational=False):
  113. """Efficiently extract the coefficient of a product."""
  114. return (S.One, self)
  115. def as_coeff_add(self, *deps):
  116. """Efficiently extract the coefficient of a summation."""
  117. l = [x * self.components[x] for x in self.components]
  118. return 0, tuple(l)
  119. def diff(self, *args, **kwargs):
  120. """
  121. Implements the SymPy diff routine, for vectors.
  122. diff's documentation
  123. ========================
  124. """
  125. for x in args:
  126. if isinstance(x, BasisDependent):
  127. raise TypeError("Invalid arg for differentiation")
  128. diff_components = [df(v, *args, **kwargs) * k for
  129. k, v in self.components.items()]
  130. return self._add_func(*diff_components)
  131. diff.__doc__ += df.__doc__ # type: ignore
  132. def doit(self, **hints):
  133. """Calls .doit() on each term in the Dyadic"""
  134. doit_components = [self.components[x].doit(**hints) * x
  135. for x in self.components]
  136. return self._add_func(*doit_components)
  137. class BasisDependentAdd(BasisDependent, Add):
  138. """
  139. Denotes sum of basis dependent quantities such that they cannot
  140. be expressed as base or Mul instances.
  141. """
  142. def __new__(cls, *args, **options):
  143. components = {}
  144. # Check each arg and simultaneously learn the components
  145. for i, arg in enumerate(args):
  146. if not isinstance(arg, cls._expr_type):
  147. if isinstance(arg, Mul):
  148. arg = cls._mul_func(*(arg.args))
  149. elif isinstance(arg, Add):
  150. arg = cls._add_func(*(arg.args))
  151. else:
  152. raise TypeError(str(arg) +
  153. " cannot be interpreted correctly")
  154. # If argument is zero, ignore
  155. if arg == cls.zero:
  156. continue
  157. # Else, update components accordingly
  158. if hasattr(arg, "components"):
  159. for x in arg.components:
  160. components[x] = components.get(x, 0) + arg.components[x]
  161. temp = list(components.keys())
  162. for x in temp:
  163. if components[x] == 0:
  164. del components[x]
  165. # Handle case of zero vector
  166. if len(components) == 0:
  167. return cls.zero
  168. # Build object
  169. newargs = [x * components[x] for x in components]
  170. obj = super().__new__(cls, *newargs, **options)
  171. if isinstance(obj, Mul):
  172. return cls._mul_func(*obj.args)
  173. assumptions = {'commutative': True}
  174. obj._assumptions = StdFactKB(assumptions)
  175. obj._components = components
  176. obj._sys = (list(components.keys()))[0]._sys
  177. return obj
  178. class BasisDependentMul(BasisDependent, Mul):
  179. """
  180. Denotes product of base- basis dependent quantity with a scalar.
  181. """
  182. def __new__(cls, *args, **options):
  183. from sympy.vector import Cross, Dot, Curl, Gradient
  184. count = 0
  185. measure_number = S.One
  186. zeroflag = False
  187. extra_args = []
  188. # Determine the component and check arguments
  189. # Also keep a count to ensure two vectors aren't
  190. # being multiplied
  191. for arg in args:
  192. if isinstance(arg, cls._zero_func):
  193. count += 1
  194. zeroflag = True
  195. elif arg == S.Zero:
  196. zeroflag = True
  197. elif isinstance(arg, (cls._base_func, cls._mul_func)):
  198. count += 1
  199. expr = arg._base_instance
  200. measure_number *= arg._measure_number
  201. elif isinstance(arg, cls._add_func):
  202. count += 1
  203. expr = arg
  204. elif isinstance(arg, (Cross, Dot, Curl, Gradient)):
  205. extra_args.append(arg)
  206. else:
  207. measure_number *= arg
  208. # Make sure incompatible types weren't multiplied
  209. if count > 1:
  210. raise ValueError("Invalid multiplication")
  211. elif count == 0:
  212. return Mul(*args, **options)
  213. # Handle zero vector case
  214. if zeroflag:
  215. return cls.zero
  216. # If one of the args was a VectorAdd, return an
  217. # appropriate VectorAdd instance
  218. if isinstance(expr, cls._add_func):
  219. newargs = [cls._mul_func(measure_number, x) for
  220. x in expr.args]
  221. return cls._add_func(*newargs)
  222. obj = super().__new__(cls, measure_number,
  223. expr._base_instance,
  224. *extra_args,
  225. **options)
  226. if isinstance(obj, Add):
  227. return cls._add_func(*obj.args)
  228. obj._base_instance = expr._base_instance
  229. obj._measure_number = measure_number
  230. assumptions = {'commutative': True}
  231. obj._assumptions = StdFactKB(assumptions)
  232. obj._components = {expr._base_instance: measure_number}
  233. obj._sys = expr._base_instance._sys
  234. return obj
  235. def _sympystr(self, printer):
  236. measure_str = printer._print(self._measure_number)
  237. if ('(' in measure_str or '-' in measure_str or
  238. '+' in measure_str):
  239. measure_str = '(' + measure_str + ')'
  240. return measure_str + '*' + printer._print(self._base_instance)
  241. class BasisDependentZero(BasisDependent):
  242. """
  243. Class to denote a zero basis dependent instance.
  244. """
  245. components: dict['BaseVector', Expr] = {}
  246. _latex_form: str
  247. def __new__(cls):
  248. obj = super().__new__(cls)
  249. # Pre-compute a specific hash value for the zero vector
  250. # Use the same one always
  251. obj._hash = (S.Zero, cls).__hash__()
  252. return obj
  253. def __hash__(self):
  254. return self._hash
  255. @call_highest_priority('__req__')
  256. def __eq__(self, other):
  257. return isinstance(other, self._zero_func)
  258. __req__ = __eq__
  259. @call_highest_priority('__radd__')
  260. def __add__(self, other):
  261. if isinstance(other, self._expr_type):
  262. return other
  263. else:
  264. raise TypeError("Invalid argument types for addition")
  265. @call_highest_priority('__add__')
  266. def __radd__(self, other):
  267. if isinstance(other, self._expr_type):
  268. return other
  269. else:
  270. raise TypeError("Invalid argument types for addition")
  271. @call_highest_priority('__rsub__')
  272. def __sub__(self, other):
  273. if isinstance(other, self._expr_type):
  274. return -other
  275. else:
  276. raise TypeError("Invalid argument types for subtraction")
  277. @call_highest_priority('__sub__')
  278. def __rsub__(self, other):
  279. if isinstance(other, self._expr_type):
  280. return other
  281. else:
  282. raise TypeError("Invalid argument types for subtraction")
  283. def __neg__(self):
  284. return self
  285. def normalize(self):
  286. """
  287. Returns the normalized version of this vector.
  288. """
  289. return self
  290. def _sympystr(self, printer):
  291. return '0'