test_diffgeom.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. from sympy.core import Lambda, Symbol, symbols
  2. from sympy.diffgeom.rn import R2, R2_p, R2_r, R3_r, R3_c, R3_s, R2_origin
  3. from sympy.diffgeom import (Manifold, Patch, CoordSystem, Commutator, Differential, TensorProduct,
  4. WedgeProduct, BaseCovarDerivativeOp, CovarDerivativeOp, LieDerivative,
  5. covariant_order, contravariant_order, twoform_to_matrix, metric_to_Christoffel_1st,
  6. metric_to_Christoffel_2nd, metric_to_Riemann_components,
  7. metric_to_Ricci_components, intcurve_diffequ, intcurve_series)
  8. from sympy.simplify import trigsimp, simplify
  9. from sympy.functions import sqrt, atan2, sin
  10. from sympy.matrices import Matrix
  11. from sympy.testing.pytest import raises, nocache_fail
  12. from sympy.testing.pytest import warns_deprecated_sympy
  13. TP = TensorProduct
  14. def test_coordsys_transform():
  15. # test inverse transforms
  16. p, q, r, s = symbols('p q r s')
  17. rel = {('first', 'second'): [(p, q), (q, -p)]}
  18. R2_pq = CoordSystem('first', R2_origin, [p, q], rel)
  19. R2_rs = CoordSystem('second', R2_origin, [r, s], rel)
  20. r, s = R2_rs.symbols
  21. assert R2_rs.transform(R2_pq) == Matrix([[-s], [r]])
  22. # inverse transform impossible case
  23. a, b = symbols('a b', positive=True)
  24. rel = {('first', 'second'): [(a,), (-a,)]}
  25. R2_a = CoordSystem('first', R2_origin, [a], rel)
  26. R2_b = CoordSystem('second', R2_origin, [b], rel)
  27. # This transformation is uninvertible because there is no positive a, b satisfying a = -b
  28. with raises(NotImplementedError):
  29. R2_b.transform(R2_a)
  30. # inverse transform ambiguous case
  31. c, d = symbols('c d')
  32. rel = {('first', 'second'): [(c,), (c**2,)]}
  33. R2_c = CoordSystem('first', R2_origin, [c], rel)
  34. R2_d = CoordSystem('second', R2_origin, [d], rel)
  35. # The transform method should throw if it finds multiple inverses for a coordinate transformation.
  36. with raises(ValueError):
  37. R2_d.transform(R2_c)
  38. # test indirect transformation
  39. a, b, c, d, e, f = symbols('a, b, c, d, e, f')
  40. rel = {('C1', 'C2'): [(a, b), (2*a, 3*b)],
  41. ('C2', 'C3'): [(c, d), (3*c, 2*d)]}
  42. C1 = CoordSystem('C1', R2_origin, (a, b), rel)
  43. C2 = CoordSystem('C2', R2_origin, (c, d), rel)
  44. C3 = CoordSystem('C3', R2_origin, (e, f), rel)
  45. a, b = C1.symbols
  46. c, d = C2.symbols
  47. e, f = C3.symbols
  48. assert C2.transform(C1) == Matrix([c/2, d/3])
  49. assert C1.transform(C3) == Matrix([6*a, 6*b])
  50. assert C3.transform(C1) == Matrix([e/6, f/6])
  51. assert C3.transform(C2) == Matrix([e/3, f/2])
  52. a, b, c, d, e, f = symbols('a, b, c, d, e, f')
  53. rel = {('C1', 'C2'): [(a, b), (2*a, 3*b + 1)],
  54. ('C3', 'C2'): [(e, f), (-e - 2, 2*f)]}
  55. C1 = CoordSystem('C1', R2_origin, (a, b), rel)
  56. C2 = CoordSystem('C2', R2_origin, (c, d), rel)
  57. C3 = CoordSystem('C3', R2_origin, (e, f), rel)
  58. a, b = C1.symbols
  59. c, d = C2.symbols
  60. e, f = C3.symbols
  61. assert C2.transform(C1) == Matrix([c/2, (d - 1)/3])
  62. assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2])
  63. assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3])
  64. assert C3.transform(C2) == Matrix([-e - 2, 2*f])
  65. # old signature uses Lambda
  66. a, b, c, d, e, f = symbols('a, b, c, d, e, f')
  67. rel = {('C1', 'C2'): Lambda((a, b), (2*a, 3*b + 1)),
  68. ('C3', 'C2'): Lambda((e, f), (-e - 2, 2*f))}
  69. C1 = CoordSystem('C1', R2_origin, (a, b), rel)
  70. C2 = CoordSystem('C2', R2_origin, (c, d), rel)
  71. C3 = CoordSystem('C3', R2_origin, (e, f), rel)
  72. a, b = C1.symbols
  73. c, d = C2.symbols
  74. e, f = C3.symbols
  75. assert C2.transform(C1) == Matrix([c/2, (d - 1)/3])
  76. assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2])
  77. assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3])
  78. assert C3.transform(C2) == Matrix([-e - 2, 2*f])
  79. def test_R2():
  80. x0, y0, r0, theta0 = symbols('x0, y0, r0, theta0', real=True)
  81. point_r = R2_r.point([x0, y0])
  82. point_p = R2_p.point([r0, theta0])
  83. # r**2 = x**2 + y**2
  84. assert (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_r) == 0
  85. assert trigsimp( (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_p) ) == 0
  86. assert trigsimp(R2.e_r(R2.x**2 + R2.y**2).rcall(point_p).doit()) == 2*r0
  87. # polar->rect->polar == Id
  88. a, b = symbols('a b', positive=True)
  89. m = Matrix([[a], [b]])
  90. #TODO assert m == R2_r.transform(R2_p, R2_p.transform(R2_r, [a, b])).applyfunc(simplify)
  91. assert m == R2_p.transform(R2_r, R2_r.transform(R2_p, m)).applyfunc(simplify)
  92. # deprecated method
  93. with warns_deprecated_sympy():
  94. assert m == R2_p.coord_tuple_transform_to(
  95. R2_r, R2_r.coord_tuple_transform_to(R2_p, m)).applyfunc(simplify)
  96. def test_R3():
  97. a, b, c = symbols('a b c', positive=True)
  98. m = Matrix([[a], [b], [c]])
  99. assert m == R3_c.transform(R3_r, R3_r.transform(R3_c, m)).applyfunc(simplify)
  100. #TODO assert m == R3_r.transform(R3_c, R3_c.transform(R3_r, m)).applyfunc(simplify)
  101. assert m == R3_s.transform(
  102. R3_r, R3_r.transform(R3_s, m)).applyfunc(simplify)
  103. #TODO assert m == R3_r.transform(R3_s, R3_s.transform(R3_r, m)).applyfunc(simplify)
  104. assert m == R3_s.transform(
  105. R3_c, R3_c.transform(R3_s, m)).applyfunc(simplify)
  106. #TODO assert m == R3_c.transform(R3_s, R3_s.transform(R3_c, m)).applyfunc(simplify)
  107. with warns_deprecated_sympy():
  108. assert m == R3_c.coord_tuple_transform_to(
  109. R3_r, R3_r.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify)
  110. #TODO assert m == R3_r.coord_tuple_transform_to(R3_c, R3_c.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify)
  111. assert m == R3_s.coord_tuple_transform_to(
  112. R3_r, R3_r.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify)
  113. #TODO assert m == R3_r.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify)
  114. assert m == R3_s.coord_tuple_transform_to(
  115. R3_c, R3_c.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify)
  116. #TODO assert m == R3_c.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify)
  117. def test_CoordinateSymbol():
  118. x, y = R2_r.symbols
  119. r, theta = R2_p.symbols
  120. assert y.rewrite(R2_p) == r*sin(theta)
  121. def test_point():
  122. x, y = symbols('x, y')
  123. p = R2_r.point([x, y])
  124. assert p.free_symbols == {x, y}
  125. assert p.coords(R2_r) == p.coords() == Matrix([x, y])
  126. assert p.coords(R2_p) == Matrix([sqrt(x**2 + y**2), atan2(y, x)])
  127. def test_commutator():
  128. assert Commutator(R2.e_x, R2.e_y) == 0
  129. assert Commutator(R2.x*R2.e_x, R2.x*R2.e_x) == 0
  130. assert Commutator(R2.x*R2.e_x, R2.x*R2.e_y) == R2.x*R2.e_y
  131. c = Commutator(R2.e_x, R2.e_r)
  132. assert c(R2.x) == R2.y*(R2.x**2 + R2.y**2)**(-1)*sin(R2.theta)
  133. def test_differential():
  134. xdy = R2.x*R2.dy
  135. dxdy = Differential(xdy)
  136. assert xdy.rcall(None) == xdy
  137. assert dxdy(R2.e_x, R2.e_y) == 1
  138. assert dxdy(R2.e_x, R2.x*R2.e_y) == R2.x
  139. assert Differential(dxdy) == 0
  140. def test_products():
  141. assert TensorProduct(
  142. R2.dx, R2.dy)(R2.e_x, R2.e_y) == R2.dx(R2.e_x)*R2.dy(R2.e_y) == 1
  143. assert TensorProduct(R2.dx, R2.dy)(None, R2.e_y) == R2.dx
  144. assert TensorProduct(R2.dx, R2.dy)(R2.e_x, None) == R2.dy
  145. assert TensorProduct(R2.dx, R2.dy)(R2.e_x) == R2.dy
  146. assert TensorProduct(R2.x, R2.dx) == R2.x*R2.dx
  147. assert TensorProduct(
  148. R2.e_x, R2.e_y)(R2.x, R2.y) == R2.e_x(R2.x) * R2.e_y(R2.y) == 1
  149. assert TensorProduct(R2.e_x, R2.e_y)(None, R2.y) == R2.e_x
  150. assert TensorProduct(R2.e_x, R2.e_y)(R2.x, None) == R2.e_y
  151. assert TensorProduct(R2.e_x, R2.e_y)(R2.x) == R2.e_y
  152. assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x
  153. assert TensorProduct(
  154. R2.dx, R2.e_y)(R2.e_x, R2.y) == R2.dx(R2.e_x) * R2.e_y(R2.y) == 1
  155. assert TensorProduct(R2.dx, R2.e_y)(None, R2.y) == R2.dx
  156. assert TensorProduct(R2.dx, R2.e_y)(R2.e_x, None) == R2.e_y
  157. assert TensorProduct(R2.dx, R2.e_y)(R2.e_x) == R2.e_y
  158. assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x
  159. assert TensorProduct(
  160. R2.e_x, R2.dy)(R2.x, R2.e_y) == R2.e_x(R2.x) * R2.dy(R2.e_y) == 1
  161. assert TensorProduct(R2.e_x, R2.dy)(None, R2.e_y) == R2.e_x
  162. assert TensorProduct(R2.e_x, R2.dy)(R2.x, None) == R2.dy
  163. assert TensorProduct(R2.e_x, R2.dy)(R2.x) == R2.dy
  164. assert TensorProduct(R2.e_y,R2.e_x)(R2.x**2 + R2.y**2,R2.x**2 + R2.y**2) == 4*R2.x*R2.y
  165. assert WedgeProduct(R2.dx, R2.dy)(R2.e_x, R2.e_y) == 1
  166. assert WedgeProduct(R2.e_x, R2.e_y)(R2.x, R2.y) == 1
  167. def test_lie_derivative():
  168. assert LieDerivative(R2.e_x, R2.y) == R2.e_x(R2.y) == 0
  169. assert LieDerivative(R2.e_x, R2.x) == R2.e_x(R2.x) == 1
  170. assert LieDerivative(R2.e_x, R2.e_x) == Commutator(R2.e_x, R2.e_x) == 0
  171. assert LieDerivative(R2.e_x, R2.e_r) == Commutator(R2.e_x, R2.e_r)
  172. assert LieDerivative(R2.e_x + R2.e_y, R2.x) == 1
  173. assert LieDerivative(
  174. R2.e_x, TensorProduct(R2.dx, R2.dy))(R2.e_x, R2.e_y) == 0
  175. @nocache_fail
  176. def test_covar_deriv():
  177. ch = metric_to_Christoffel_2nd(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
  178. cvd = BaseCovarDerivativeOp(R2_r, 0, ch)
  179. assert cvd(R2.x) == 1
  180. # This line fails if the cache is disabled:
  181. assert cvd(R2.x*R2.e_x) == R2.e_x
  182. cvd = CovarDerivativeOp(R2.x*R2.e_x, ch)
  183. assert cvd(R2.x) == R2.x
  184. assert cvd(R2.x*R2.e_x) == R2.x*R2.e_x
  185. def test_intcurve_diffequ():
  186. t = symbols('t')
  187. start_point = R2_r.point([1, 0])
  188. vector_field = -R2.y*R2.e_x + R2.x*R2.e_y
  189. equations, init_cond = intcurve_diffequ(vector_field, t, start_point)
  190. assert str(equations) == '[f_1(t) + Derivative(f_0(t), t), -f_0(t) + Derivative(f_1(t), t)]'
  191. assert str(init_cond) == '[f_0(0) - 1, f_1(0)]'
  192. equations, init_cond = intcurve_diffequ(vector_field, t, start_point, R2_p)
  193. assert str(
  194. equations) == '[Derivative(f_0(t), t), Derivative(f_1(t), t) - 1]'
  195. assert str(init_cond) == '[f_0(0) - 1, f_1(0)]'
  196. def test_helpers_and_coordinate_dependent():
  197. one_form = R2.dr + R2.dx
  198. two_form = Differential(R2.x*R2.dr + R2.r*R2.dx)
  199. three_form = Differential(
  200. R2.y*two_form) + Differential(R2.x*Differential(R2.r*R2.dr))
  201. metric = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dy, R2.dy)
  202. metric_ambig = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dr, R2.dr)
  203. misform_a = TensorProduct(R2.dr, R2.dr) + R2.dr
  204. misform_b = R2.dr**4
  205. misform_c = R2.dx*R2.dy
  206. twoform_not_sym = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dx, R2.dy)
  207. twoform_not_TP = WedgeProduct(R2.dx, R2.dy)
  208. one_vector = R2.e_x + R2.e_y
  209. two_vector = TensorProduct(R2.e_x, R2.e_y)
  210. three_vector = TensorProduct(R2.e_x, R2.e_y, R2.e_x)
  211. two_wp = WedgeProduct(R2.e_x,R2.e_y)
  212. assert covariant_order(one_form) == 1
  213. assert covariant_order(two_form) == 2
  214. assert covariant_order(three_form) == 3
  215. assert covariant_order(two_form + metric) == 2
  216. assert covariant_order(two_form + metric_ambig) == 2
  217. assert covariant_order(two_form + twoform_not_sym) == 2
  218. assert covariant_order(two_form + twoform_not_TP) == 2
  219. assert contravariant_order(one_vector) == 1
  220. assert contravariant_order(two_vector) == 2
  221. assert contravariant_order(three_vector) == 3
  222. assert contravariant_order(two_vector + two_wp) == 2
  223. raises(ValueError, lambda: covariant_order(misform_a))
  224. raises(ValueError, lambda: covariant_order(misform_b))
  225. raises(ValueError, lambda: covariant_order(misform_c))
  226. assert twoform_to_matrix(metric) == Matrix([[1, 0], [0, 1]])
  227. assert twoform_to_matrix(twoform_not_sym) == Matrix([[1, 0], [1, 0]])
  228. assert twoform_to_matrix(twoform_not_TP) == Matrix([[0, -1], [1, 0]])
  229. raises(ValueError, lambda: twoform_to_matrix(one_form))
  230. raises(ValueError, lambda: twoform_to_matrix(three_form))
  231. raises(ValueError, lambda: twoform_to_matrix(metric_ambig))
  232. raises(ValueError, lambda: metric_to_Christoffel_1st(twoform_not_sym))
  233. raises(ValueError, lambda: metric_to_Christoffel_2nd(twoform_not_sym))
  234. raises(ValueError, lambda: metric_to_Riemann_components(twoform_not_sym))
  235. raises(ValueError, lambda: metric_to_Ricci_components(twoform_not_sym))
  236. def test_correct_arguments():
  237. raises(ValueError, lambda: R2.e_x(R2.e_x))
  238. raises(ValueError, lambda: R2.e_x(R2.dx))
  239. raises(ValueError, lambda: Commutator(R2.e_x, R2.x))
  240. raises(ValueError, lambda: Commutator(R2.dx, R2.e_x))
  241. raises(ValueError, lambda: Differential(Differential(R2.e_x)))
  242. raises(ValueError, lambda: R2.dx(R2.x))
  243. raises(ValueError, lambda: LieDerivative(R2.dx, R2.dx))
  244. raises(ValueError, lambda: LieDerivative(R2.x, R2.dx))
  245. raises(ValueError, lambda: CovarDerivativeOp(R2.dx, []))
  246. raises(ValueError, lambda: CovarDerivativeOp(R2.x, []))
  247. a = Symbol('a')
  248. raises(ValueError, lambda: intcurve_series(R2.dx, a, R2_r.point([1, 2])))
  249. raises(ValueError, lambda: intcurve_series(R2.x, a, R2_r.point([1, 2])))
  250. raises(ValueError, lambda: intcurve_diffequ(R2.dx, a, R2_r.point([1, 2])))
  251. raises(ValueError, lambda: intcurve_diffequ(R2.x, a, R2_r.point([1, 2])))
  252. raises(ValueError, lambda: contravariant_order(R2.e_x + R2.dx))
  253. raises(ValueError, lambda: covariant_order(R2.e_x + R2.dx))
  254. raises(ValueError, lambda: contravariant_order(R2.e_x*R2.e_y))
  255. raises(ValueError, lambda: covariant_order(R2.dx*R2.dy))
  256. def test_simplify():
  257. x, y = R2_r.coord_functions()
  258. dx, dy = R2_r.base_oneforms()
  259. ex, ey = R2_r.base_vectors()
  260. assert simplify(x) == x
  261. assert simplify(x*y) == x*y
  262. assert simplify(dx*dy) == dx*dy
  263. assert simplify(ex*ey) == ex*ey
  264. assert ((1-x)*dx)/(1-x)**2 == dx/(1-x)
  265. def test_issue_17917():
  266. X = R2.x*R2.e_x - R2.y*R2.e_y
  267. Y = (R2.x**2 + R2.y**2)*R2.e_x - R2.x*R2.y*R2.e_y
  268. assert LieDerivative(X, Y).expand() == (
  269. R2.x**2*R2.e_x - 3*R2.y**2*R2.e_x - R2.x*R2.y*R2.e_y)
  270. def test_deprecations():
  271. m = Manifold('M', 2)
  272. p = Patch('P', m)
  273. with warns_deprecated_sympy():
  274. CoordSystem('Car2d', p, names=['x', 'y'])
  275. with warns_deprecated_sympy():
  276. c = CoordSystem('Car2d', p, ['x', 'y'])
  277. with warns_deprecated_sympy():
  278. list(m.patches)
  279. with warns_deprecated_sympy():
  280. list(c.transforms)