operators.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. import collections
  2. from sympy.core.expr import Expr
  3. from sympy.core import sympify, S, preorder_traversal
  4. from sympy.vector.coordsysrect import CoordSys3D
  5. from sympy.vector.vector import Vector, VectorMul, VectorAdd, Cross, Dot
  6. from sympy.core.function import Derivative
  7. from sympy.core.add import Add
  8. from sympy.core.mul import Mul
  def _get_coord_systems(expr):
      """
      Collect every CoordSys3D instance that occurs inside ``expr``.

      The expression tree is walked in pre-order. When a coordinate
      system node is met it is recorded and its own sub-tree is skipped.

      Returns
      =======

      frozenset
          The distinct coordinate systems found (possibly empty).
      """
      walker = preorder_traversal(expr)
      found = set()
      for node in walker:
          if not isinstance(node, CoordSys3D):
              continue
          found.add(node)
          # Do not descend into the coordinate system's own arguments.
          walker.skip()
      return frozenset(found)
  def _split_mul_args_wrt_coordsys(expr):
      """
      Group the factors of the product ``expr`` by the set of coordinate
      systems each factor involves.

      Factors sharing the same set of coordinate systems are multiplied
      together (starting from ``S.One``). The partial products are
      returned in first-seen order.
      """
      groups = {}
      for factor in expr.args:
          key = _get_coord_systems(factor)
          groups[key] = groups.get(key, S.One) * factor
      return list(groups.values())
  class Gradient(Expr):
      """
      Represents unevaluated Gradient.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, Gradient
      >>> R = CoordSys3D('R')
      >>> s = R.x*R.y*R.z
      >>> Gradient(s)
      Gradient(R.x*R.y*R.z)

      """

      def __new__(cls, expr):
          # Store the sympified operand both as the Basic argument and
          # as a convenience attribute used by doit().
          operand = sympify(expr)
          instance = Expr.__new__(cls, operand)
          instance._expr = operand
          return instance

      def doit(self, **hints):
          """Evaluate by calling ``gradient`` on the wrapped operand."""
          return gradient(self._expr, doit=True)
  class Divergence(Expr):
      """
      Represents unevaluated Divergence.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, Divergence
      >>> R = CoordSys3D('R')
      >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k
      >>> Divergence(v)
      Divergence(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k)

      """

      def __new__(cls, expr):
          # Store the sympified operand both as the Basic argument and
          # as a convenience attribute used by doit().
          operand = sympify(expr)
          instance = Expr.__new__(cls, operand)
          instance._expr = operand
          return instance

      def doit(self, **hints):
          """Evaluate by calling ``divergence`` on the wrapped operand."""
          return divergence(self._expr, doit=True)
  class Curl(Expr):
      """
      Represents unevaluated Curl.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, Curl
      >>> R = CoordSys3D('R')
      >>> v = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k
      >>> Curl(v)
      Curl(R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k)

      """

      def __new__(cls, expr):
          # Store the sympified operand both as the Basic argument and
          # as a convenience attribute used by doit().
          operand = sympify(expr)
          instance = Expr.__new__(cls, operand)
          instance._expr = operand
          return instance

      def doit(self, **hints):
          """Evaluate by calling ``curl`` on the wrapped operand."""
          return curl(self._expr, doit=True)
  def curl(vect, doit=True):
      """
      Returns the curl of a vector field computed wrt the base scalars
      of the given coordinate system.

      Parameters
      ==========

      vect : Vector
          The vector operand

      doit : bool
          If True, the result is returned after calling .doit() on
          each component. Else, the returned expression contains
          Derivative instances

      Raises
      ======

      ValueError
          If ``vect`` involves several coordinate systems and is not an
          Add/VectorAdd, Mul/VectorMul, Cross, Curl or Gradient.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, curl
      >>> R = CoordSys3D('R')
      >>> v1 = R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k
      >>> curl(v1)
      0
      >>> v2 = R.x*R.y*R.z*R.i
      >>> curl(v2)
      R.x*R.y*R.j + (-R.x*R.z)*R.k

      """
      coord_sys = _get_coord_systems(vect)

      if len(coord_sys) == 0:
          # No coordinate system appears, so there is nothing to differentiate.
          return Vector.zero
      elif len(coord_sys) == 1:
          coord_sys = next(iter(coord_sys))
          i, j, k = coord_sys.base_vectors()
          x, y, z = coord_sys.base_scalars()
          h1, h2, h3 = coord_sys.lame_coefficients()
          vectx = vect.dot(i)
          vecty = vect.dot(j)
          vectz = vect.dot(k)
          # Component-wise curl formula weighted by the system's Lame
          # coefficients h1, h2, h3.
          outvec = Vector.zero
          outvec += (Derivative(vectz * h3, y) -
                     Derivative(vecty * h2, z)) * i / (h2 * h3)
          outvec += (Derivative(vectx * h1, z) -
                     Derivative(vectz * h3, x)) * j / (h1 * h3)
          outvec += (Derivative(vecty * h2, x) -
                     Derivative(vectx * h1, y)) * k / (h2 * h1)

          if doit:
              return outvec.doit()
          return outvec
      else:
          if isinstance(vect, (Add, VectorAdd)):
              from sympy.vector import express
              try:
                  # coord_sys is a frozenset, so this picks one system
                  # arbitrarily to re-express every term in.
                  cs = next(iter(coord_sys))
                  args = [express(i, cs, variables=True) for i in vect.args]
              except ValueError:
                  args = vect.args
              # Linearity: curl(a + b) = curl(a) + curl(b)
              return VectorAdd.fromiter(curl(i, doit=doit) for i in args)
          elif isinstance(vect, (Mul, VectorMul)):
              vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0]
              scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient)))
              # Product rule: curl(f*v) = grad(f) x v + f*curl(v)
              res = Cross(gradient(scalar), vector).doit() + scalar*curl(vector, doit=doit)
              if doit:
                  return res.doit()
              return res
          elif isinstance(vect, (Cross, Curl, Gradient)):
              return Curl(vect)
          else:
              # Previously ``raise Curl(vect)``, which is not an exception
              # and surfaced as an unrelated TypeError.
              raise ValueError("Invalid argument for curl")
  def divergence(vect, doit=True):
      """
      Returns the divergence of a vector field computed wrt the base
      scalars of the given coordinate system.

      Parameters
      ==========

      vect : Vector
          The vector operand

      doit : bool
          If True, the result is returned after calling .doit() on
          each component. Else, the returned expression contains
          Derivative instances

      Raises
      ======

      ValueError
          If ``vect`` involves several coordinate systems and is not an
          Add/VectorAdd, Mul/VectorMul, Cross, Curl or Gradient.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, divergence
      >>> R = CoordSys3D('R')
      >>> v1 = R.x*R.y*R.z * (R.i+R.j+R.k)

      >>> divergence(v1)
      R.x*R.y + R.x*R.z + R.y*R.z
      >>> v2 = 2*R.y*R.z*R.j
      >>> divergence(v2)
      2*R.z

      """
      coord_sys = _get_coord_systems(vect)

      if len(coord_sys) == 0:
          return S.Zero
      elif len(coord_sys) == 1:
          if isinstance(vect, (Cross, Curl, Gradient)):
              return Divergence(vect)
          # Exactly one coordinate system is involved in this branch.
          coord_sys = next(iter(coord_sys))
          i, j, k = coord_sys.base_vectors()
          x, y, z = coord_sys.base_scalars()
          h1, h2, h3 = coord_sys.lame_coefficients()
          # Each component is scaled by the other two Lame coefficients,
          # differentiated, then divided by the product of all three.
          vx = (_diff_conditional(vect.dot(i), x, h2, h3)
                / (h1 * h2 * h3))
          vy = (_diff_conditional(vect.dot(j), y, h3, h1)
                / (h1 * h2 * h3))
          vz = (_diff_conditional(vect.dot(k), z, h1, h2)
                / (h1 * h2 * h3))
          res = vx + vy + vz
          if doit:
              return res.doit()
          return res
      else:
          if isinstance(vect, (Add, VectorAdd)):
              # Linearity: div(a + b) = div(a) + div(b)
              return Add.fromiter(divergence(i, doit=doit) for i in vect.args)
          elif isinstance(vect, (Mul, VectorMul)):
              vector = [i for i in vect.args if isinstance(i, (Vector, Cross, Gradient))][0]
              scalar = Mul.fromiter(i for i in vect.args if not isinstance(i, (Vector, Cross, Gradient)))
              # Product rule: div(f*v) = v . grad(f) + f*div(v)
              res = Dot(vector, gradient(scalar)) + scalar*divergence(vector, doit=doit)
              if doit:
                  return res.doit()
              return res
          elif isinstance(vect, (Cross, Curl, Gradient)):
              return Divergence(vect)
          else:
              # Previously ``raise Divergence(vect)``, which is not an
              # exception and surfaced as an unrelated TypeError.
              raise ValueError("Invalid argument for divergence")
  def gradient(scalar_field, doit=True):
      """
      Returns the vector gradient of a scalar field computed wrt the
      base scalars of the given coordinate system.

      Parameters
      ==========

      scalar_field : SymPy Expr
          The scalar field to compute the gradient of

      doit : bool
          If True, the result is returned after calling .doit() on
          each component. Else, the returned expression contains
          Derivative instances. The flag is also forwarded to the
          recursive calls made for fields spanning several coordinate
          systems.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, gradient
      >>> R = CoordSys3D('R')
      >>> s1 = R.x*R.y*R.z
      >>> gradient(s1)
      R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k
      >>> s2 = 5*R.x**2*R.z
      >>> gradient(s2)
      10*R.x*R.z*R.i + 5*R.x**2*R.k

      """
      coord_sys = _get_coord_systems(scalar_field)

      if len(coord_sys) == 0:
          return Vector.zero
      elif len(coord_sys) == 1:
          coord_sys = next(iter(coord_sys))
          h1, h2, h3 = coord_sys.lame_coefficients()
          i, j, k = coord_sys.base_vectors()
          x, y, z = coord_sys.base_scalars()
          # Each partial derivative is divided by the matching Lame coefficient.
          vx = Derivative(scalar_field, x) / h1
          vy = Derivative(scalar_field, y) / h2
          vz = Derivative(scalar_field, z) / h3
          result = vx * i + vy * j + vz * k
          if doit:
              return result.doit()
          return result
      else:
          if isinstance(scalar_field, (Add, VectorAdd)):
              # Linearity: grad(a + b) = grad(a) + grad(b); forward doit so
              # doit=False is honoured (previously it was dropped here).
              return VectorAdd.fromiter(gradient(i, doit=doit)
                                        for i in scalar_field.args)
          if isinstance(scalar_field, (Mul, VectorMul)):
              # Product rule over factors grouped by coordinate system:
              # grad(f*g) = g*grad(f) + f*grad(g)
              s = _split_mul_args_wrt_coordsys(scalar_field)
              return VectorAdd.fromiter(scalar_field / i * gradient(i, doit=doit)
                                        for i in s)
          return Gradient(scalar_field)
  class Laplacian(Expr):
      """
      Represents unevaluated Laplacian.

      Examples
      ========

      >>> from sympy.vector import CoordSys3D, Laplacian
      >>> R = CoordSys3D('R')
      >>> v = 3*R.x**3*R.y**2*R.z**3
      >>> Laplacian(v)
      Laplacian(3*R.x**3*R.y**2*R.z**3)

      """

      def __new__(cls, expr):
          # Store the sympified operand both as the Basic argument and
          # as a convenience attribute used by doit().
          operand = sympify(expr)
          instance = Expr.__new__(cls, operand)
          instance._expr = operand
          return instance

      def doit(self, **hints):
          """Evaluate via ``sympy.vector.functions.laplacian``."""
          # Imported locally, as in the original, rather than at module level.
          from sympy.vector.functions import laplacian
          return laplacian(self._expr)
  def _diff_conditional(expr, base_scalar, coeff_1, coeff_2):
      """
      Differentiate a scaled, re-expressed scalar wrt ``base_scalar``.

      ``expr`` is first re-expressed in ``base_scalar.system``, then
      multiplied by ``coeff_1 * coeff_2``. If that product is truthy, an
      unevaluated ``Derivative`` of it wrt ``base_scalar`` is returned.
      Otherwise ``S.Zero`` is returned. For a SymPy number, falsy means
      zero.

      Note that this does not check whether ``base_scalar`` actually
      occurs in the product.
      """
      from sympy.vector.functions import express
      reexpressed = express(expr, base_scalar.system, variables=True)
      scaled = coeff_1 * coeff_2 * reexpressed
      if not scaled:
          return S.Zero
      return Derivative(scaled, base_scalar)