test_functions.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from sympy.vector.vector import Vector
  2. from sympy.vector.coordsysrect import CoordSys3D
  3. from sympy.vector.functions import express, matrix_to_vector, orthogonalize
  4. from sympy.core.numbers import Rational
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import symbols
  7. from sympy.functions.elementary.miscellaneous import sqrt
  8. from sympy.functions.elementary.trigonometric import (cos, sin)
  9. from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix
  10. from sympy.testing.pytest import raises
  # Frames shared by the tests below: N is the base system, and A, B, C are
  # each derived from the previous one through orient_new_axis with their
  # own symbolic angle (q1, q2, q3) and axis (N.k, A.i, B.j).
  N = CoordSys3D('N')
  q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5')
  A = N.orient_new_axis('A', angle=q1, axis=N.k)  # type: ignore
  B = A.orient_new_axis('B', angle=q2, axis=A.i)
  C = B.orient_new_axis('C', angle=q3, axis=B.j)
  def test_express():
      """Check express() for every pairing of the frames N, A, B and C.

      The first part re-expresses each base vector of each frame in every
      frame (including itself) and compares against the closed-form result.
      The second part feeds those closed-form combinations back through
      express() and checks that simplify() collapses them to a single base
      vector again.
      """
      # Zero inputs come back unchanged.
      assert express(Vector.zero, N) == Vector.zero
      assert express(S.Zero, N) is S.Zero

      # Check to make sure UnitVectors get converted properly
      # -- base vectors of N
      assert express(N.i, N) == N.i
      assert express(N.j, N) == N.j
      assert express(N.k, N) == N.k
      assert express(N.i, A) == (cos(q1)*A.i - sin(q1)*A.j)
      assert express(N.j, A) == (sin(q1)*A.i + cos(q1)*A.j)
      assert express(N.k, A) == A.k
      assert express(N.i, B) == (cos(q1)*B.i - sin(q1)*cos(q2)*B.j +
                                 sin(q1)*sin(q2)*B.k)
      assert express(N.j, B) == (sin(q1)*B.i + cos(q1)*cos(q2)*B.j -
                                 sin(q2)*cos(q1)*B.k)
      assert express(N.k, B) == (sin(q2)*B.j + cos(q2)*B.k)
      assert express(N.i, C) == (
          (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.i -
          sin(q1)*cos(q2)*C.j +
          (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.k)
      assert express(N.j, C) == (
          (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.i +
          cos(q1)*cos(q2)*C.j +
          (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.k)
      assert express(N.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j +
                                 cos(q2)*cos(q3)*C.k)

      # -- base vectors of A
      assert express(A.i, N) == (cos(q1)*N.i + sin(q1)*N.j)
      assert express(A.j, N) == (-sin(q1)*N.i + cos(q1)*N.j)
      assert express(A.k, N) == N.k
      assert express(A.i, A) == A.i
      assert express(A.j, A) == A.j
      assert express(A.k, A) == A.k
      assert express(A.i, B) == B.i
      assert express(A.j, B) == (cos(q2)*B.j - sin(q2)*B.k)
      assert express(A.k, B) == (sin(q2)*B.j + cos(q2)*B.k)
      assert express(A.i, C) == (cos(q3)*C.i + sin(q3)*C.k)
      assert express(A.j, C) == (sin(q2)*sin(q3)*C.i + cos(q2)*C.j -
                                 sin(q2)*cos(q3)*C.k)
      assert express(A.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j +
                                 cos(q2)*cos(q3)*C.k)

      # -- base vectors of B
      assert express(B.i, N) == (cos(q1)*N.i + sin(q1)*N.j)
      assert express(B.j, N) == (-sin(q1)*cos(q2)*N.i +
                                 cos(q1)*cos(q2)*N.j + sin(q2)*N.k)
      assert express(B.k, N) == (sin(q1)*sin(q2)*N.i -
                                 sin(q2)*cos(q1)*N.j + cos(q2)*N.k)
      assert express(B.i, A) == A.i
      assert express(B.j, A) == (cos(q2)*A.j + sin(q2)*A.k)
      assert express(B.k, A) == (-sin(q2)*A.j + cos(q2)*A.k)
      assert express(B.i, B) == B.i
      assert express(B.j, B) == B.j
      assert express(B.k, B) == B.k
      assert express(B.i, C) == (cos(q3)*C.i + sin(q3)*C.k)
      assert express(B.j, C) == C.j
      assert express(B.k, C) == (-sin(q3)*C.i + cos(q3)*C.k)

      # -- base vectors of C
      assert express(C.i, N) == (
          (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.i +
          (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.j -
          sin(q3)*cos(q2)*N.k)
      assert express(C.j, N) == (
          -sin(q1)*cos(q2)*N.i + cos(q1)*cos(q2)*N.j + sin(q2)*N.k)
      assert express(C.k, N) == (
          (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.i +
          (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.j +
          cos(q2)*cos(q3)*N.k)
      assert express(C.i, A) == (cos(q3)*A.i + sin(q2)*sin(q3)*A.j -
                                 sin(q3)*cos(q2)*A.k)
      assert express(C.j, A) == (cos(q2)*A.j + sin(q2)*A.k)
      assert express(C.k, A) == (sin(q3)*A.i - sin(q2)*cos(q3)*A.j +
                                 cos(q2)*cos(q3)*A.k)
      assert express(C.i, B) == (cos(q3)*B.i - sin(q3)*B.k)
      assert express(C.j, B) == B.j
      assert express(C.k, B) == (sin(q3)*B.i + cos(q3)*B.k)
      assert express(C.i, C) == C.i
      assert express(C.j, C) == C.j
      assert express(C.k, C) == C.k

      # Check to make sure Vectors get converted back to UnitVectors
      # -- results expected in N
      assert N.i == express((cos(q1)*A.i - sin(q1)*A.j), N).simplify()
      assert N.j == express((sin(q1)*A.i + cos(q1)*A.j), N).simplify()
      assert N.i == express((cos(q1)*B.i - sin(q1)*cos(q2)*B.j +
                             sin(q1)*sin(q2)*B.k), N).simplify()
      assert N.j == express((sin(q1)*B.i + cos(q1)*cos(q2)*B.j -
                             sin(q2)*cos(q1)*B.k), N).simplify()
      assert N.k == express((sin(q2)*B.j + cos(q2)*B.k), N).simplify()

      # -- results expected in A
      assert A.i == express((cos(q1)*N.i + sin(q1)*N.j), A).simplify()
      assert A.j == express((-sin(q1)*N.i + cos(q1)*N.j), A).simplify()
      assert A.j == express((cos(q2)*B.j - sin(q2)*B.k), A).simplify()
      assert A.k == express((sin(q2)*B.j + cos(q2)*B.k), A).simplify()
      assert A.i == express((cos(q3)*C.i + sin(q3)*C.k), A).simplify()
      assert A.j == express((sin(q2)*sin(q3)*C.i + cos(q2)*C.j -
                             sin(q2)*cos(q3)*C.k), A).simplify()
      assert A.k == express((-sin(q3)*cos(q2)*C.i + sin(q2)*C.j +
                             cos(q2)*cos(q3)*C.k), A).simplify()

      # -- results expected in B
      assert B.i == express((cos(q1)*N.i + sin(q1)*N.j), B).simplify()
      assert B.j == express((-sin(q1)*cos(q2)*N.i +
                             cos(q1)*cos(q2)*N.j + sin(q2)*N.k), B).simplify()
      assert B.k == express((sin(q1)*sin(q2)*N.i -
                             sin(q2)*cos(q1)*N.j + cos(q2)*N.k), B).simplify()
      assert B.j == express((cos(q2)*A.j + sin(q2)*A.k), B).simplify()
      assert B.k == express((-sin(q2)*A.j + cos(q2)*A.k), B).simplify()
      assert B.i == express((cos(q3)*C.i + sin(q3)*C.k), B).simplify()
      assert B.k == express((-sin(q3)*C.i + cos(q3)*C.k), B).simplify()

      # -- results expected in C
      assert C.i == express((cos(q3)*A.i + sin(q2)*sin(q3)*A.j -
                             sin(q3)*cos(q2)*A.k), C).simplify()
      assert C.j == express((cos(q2)*A.j + sin(q2)*A.k), C).simplify()
      assert C.k == express((sin(q3)*A.i - sin(q2)*cos(q3)*A.j +
                             cos(q2)*cos(q3)*A.k), C).simplify()
      assert C.i == express((cos(q3)*B.i - sin(q3)*B.k), C).simplify()
      assert C.k == express((sin(q3)*B.i + cos(q3)*B.k), C).simplify()
  def test_matrix_to_vector():
      """Check matrix_to_vector() on numeric, zero and symbolic columns."""
      # Numeric column: entries become the i, j, k coefficients in C.
      numeric = Matrix([[1], [2], [3]])
      assert matrix_to_vector(numeric, C) == C.i + 2*C.j + 3*C.k

      # All-zero column: the same zero vector whichever frame is given.
      zeros = Matrix([[0], [0], [0]])
      zero_in_n = matrix_to_vector(zeros, N)
      zero_in_c = matrix_to_vector(zeros, C)
      assert zero_in_n == zero_in_c == Vector.zero

      # Symbolic column.
      symbolic = Matrix([[q1], [q2], [q3]])
      assert matrix_to_vector(symbolic, N) == q1*N.i + q2*N.j + q3*N.k
  def test_orthogonalize():
      """Check orthogonalize() results and its ValueError cases."""
      frame = CoordSys3D('C')
      a, b = symbols('a b', integer=True)
      i, j, k = frame.base_vectors()

      # All inputs are built up front, before anything is asserted.
      first = i + 2*j
      second = 2*i + 3*j
      third = 3*i + 5*j
      wiki_u = 3*i + j
      wiki_v = 2*i + 2*j
      sym = a*i + b*j
      sym_times_four = 4*a*i + 4*b*j

      # Plain orthogonalisation: the first vector is returned unchanged.
      assert orthogonalize(first, second) == [
          frame.i + 2*frame.j,
          frame.i*Rational(2, 5) + -frame.j/5,
      ]

      # from wikipedia
      assert orthogonalize(wiki_u, wiki_v, orthonormal=True) == [
          (3*sqrt(10))*frame.i/10 + (sqrt(10))*frame.j/10,
          (-sqrt(10))*frame.i/10 + (3*sqrt(10))*frame.j/10,
      ]

      # Three vectors with only i and j components, and a vector paired with
      # four times itself, must both be rejected.
      with raises(ValueError):
          orthogonalize(first, second, third)
      with raises(ValueError):
          orthogonalize(sym, sym_times_four)