123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409 |
- from sympy.core.function import diff
- from sympy.core.function import expand
- from sympy.core.numbers import (E, I, Rational, pi)
- from sympy.core.singleton import S
- from sympy.core.symbol import (Symbol, symbols)
- from sympy.functions.elementary.complexes import (Abs, conjugate, im, re, sign)
- from sympy.functions.elementary.exponential import log
- from sympy.functions.elementary.miscellaneous import sqrt
- from sympy.functions.elementary.trigonometric import (acos, asin, cos, sin, atan2, atan)
- from sympy.integrals.integrals import integrate
- from sympy.matrices.dense import Matrix
- from sympy.simplify import simplify
- from sympy.simplify.trigsimp import trigsimp
- from sympy.algebras.quaternion import Quaternion
- from sympy.testing.pytest import raises
- from itertools import permutations, product
- w, x, y, z = symbols('w:z')
- phi = symbols('phi')
- def test_quaternion_construction():
- q = Quaternion(w, x, y, z)
- assert q + q == Quaternion(2*w, 2*x, 2*y, 2*z)
- q2 = Quaternion.from_axis_angle((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3),
- pi*Rational(2, 3))
- assert q2 == Quaternion(S.Half, S.Half,
- S.Half, S.Half)
- M = Matrix([[cos(phi), -sin(phi), 0], [sin(phi), cos(phi), 0], [0, 0, 1]])
- q3 = trigsimp(Quaternion.from_rotation_matrix(M))
- assert q3 == Quaternion(
- sqrt(2)*sqrt(cos(phi) + 1)/2, 0, 0, sqrt(2 - 2*cos(phi))*sign(sin(phi))/2)
- nc = Symbol('nc', commutative=False)
- raises(ValueError, lambda: Quaternion(w, x, nc, z))
- def test_quaternion_construction_norm():
- q1 = Quaternion(*symbols('a:d'))
- q2 = Quaternion(w, x, y, z)
- assert expand((q1*q2).norm()**2 - (q1.norm()**2 * q2.norm()**2)) == 0
- q3 = Quaternion(w, x, y, z, norm=1)
- assert (q1 * q3).norm() == q1.norm()
- def test_to_and_from_Matrix():
- q = Quaternion(w, x, y, z)
- q_full = Quaternion.from_Matrix(q.to_Matrix())
- q_vect = Quaternion.from_Matrix(q.to_Matrix(True))
- assert (q - q_full).is_zero_quaternion()
- assert (q.vector_part() - q_vect).is_zero_quaternion()
- def test_product_matrices():
- q1 = Quaternion(w, x, y, z)
- q2 = Quaternion(*(symbols("a:d")))
- assert (q1 * q2).to_Matrix() == q1.product_matrix_left * q2.to_Matrix()
- assert (q1 * q2).to_Matrix() == q2.product_matrix_right * q1.to_Matrix()
- R1 = (q1.product_matrix_left * q1.product_matrix_right.T)[1:, 1:]
- R2 = simplify(q1.to_rotation_matrix()*q1.norm()**2)
- assert R1 == R2
- def test_quaternion_axis_angle():
- test_data = [ # axis, angle, expected_quaternion
- ((1, 0, 0), 0, (1, 0, 0, 0)),
- ((1, 0, 0), pi/2, (sqrt(2)/2, sqrt(2)/2, 0, 0)),
- ((0, 1, 0), pi/2, (sqrt(2)/2, 0, sqrt(2)/2, 0)),
- ((0, 0, 1), pi/2, (sqrt(2)/2, 0, 0, sqrt(2)/2)),
- ((1, 0, 0), pi, (0, 1, 0, 0)),
- ((0, 1, 0), pi, (0, 0, 1, 0)),
- ((0, 0, 1), pi, (0, 0, 0, 1)),
- ((1, 1, 1), pi, (0, 1/sqrt(3),1/sqrt(3),1/sqrt(3))),
- ((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3), pi*2/3, (S.Half, S.Half, S.Half, S.Half))
- ]
- for axis, angle, expected in test_data:
- assert Quaternion.from_axis_angle(axis, angle) == Quaternion(*expected)
- def test_quaternion_axis_angle_simplification():
- result = Quaternion.from_axis_angle((1, 2, 3), asin(4))
- assert result.a == cos(asin(4)/2)
- assert result.b == sqrt(14)*sin(asin(4)/2)/14
- assert result.c == sqrt(14)*sin(asin(4)/2)/7
- assert result.d == 3*sqrt(14)*sin(asin(4)/2)/14
- def test_quaternion_complex_real_addition():
- a = symbols("a", complex=True)
- b = symbols("b", real=True)
- # This symbol is not complex:
- c = symbols("c", commutative=False)
- q = Quaternion(w, x, y, z)
- assert a + q == Quaternion(w + re(a), x + im(a), y, z)
- assert 1 + q == Quaternion(1 + w, x, y, z)
- assert I + q == Quaternion(w, 1 + x, y, z)
- assert b + q == Quaternion(w + b, x, y, z)
- raises(ValueError, lambda: c + q)
- raises(ValueError, lambda: q * c)
- raises(ValueError, lambda: c * q)
- assert -q == Quaternion(-w, -x, -y, -z)
- q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False)
- q2 = Quaternion(1, 4, 7, 8)
- assert q1 + (2 + 3*I) == Quaternion(5 + 7*I, 2 + 5*I, 0, 7 + 8*I)
- assert q2 + (2 + 3*I) == Quaternion(3, 7, 7, 8)
- assert q1 * (2 + 3*I) == \
- Quaternion((2 + 3*I)*(3 + 4*I), (2 + 3*I)*(2 + 5*I), 0, (2 + 3*I)*(7 + 8*I))
- assert q2 * (2 + 3*I) == Quaternion(-10, 11, 38, -5)
- q1 = Quaternion(1, 2, 3, 4)
- q0 = Quaternion(0, 0, 0, 0)
- assert q1 + q0 == q1
- assert q1 - q0 == q1
- assert q1 - q1 == q0
- def test_quaternion_evalf():
- assert (Quaternion(sqrt(2), 0, 0, sqrt(3)).evalf() ==
- Quaternion(sqrt(2).evalf(), 0, 0, sqrt(3).evalf()))
- assert (Quaternion(1/sqrt(2), 0, 0, 1/sqrt(2)).evalf() ==
- Quaternion((1/sqrt(2)).evalf(), 0, 0, (1/sqrt(2)).evalf()))
- def test_quaternion_functions():
- q = Quaternion(w, x, y, z)
- q1 = Quaternion(1, 2, 3, 4)
- q0 = Quaternion(0, 0, 0, 0)
- assert conjugate(q) == Quaternion(w, -x, -y, -z)
- assert q.norm() == sqrt(w**2 + x**2 + y**2 + z**2)
- assert q.normalize() == Quaternion(w, x, y, z) / sqrt(w**2 + x**2 + y**2 + z**2)
- assert q.inverse() == Quaternion(w, -x, -y, -z) / (w**2 + x**2 + y**2 + z**2)
- assert q.inverse() == q.pow(-1)
- raises(ValueError, lambda: q0.inverse())
- assert q.pow(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z)
- assert q**(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z)
- assert q1.pow(-2) == Quaternion(
- Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225))
- assert q1**(-2) == Quaternion(
- Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225))
- assert q1.pow(-0.5) == NotImplemented
- raises(TypeError, lambda: q1**(-0.5))
- assert q1.exp() == \
- Quaternion(E * cos(sqrt(29)),
- 2 * sqrt(29) * E * sin(sqrt(29)) / 29,
- 3 * sqrt(29) * E * sin(sqrt(29)) / 29,
- 4 * sqrt(29) * E * sin(sqrt(29)) / 29)
- assert q1._ln() == \
- Quaternion(log(sqrt(30)),
- 2 * sqrt(29) * acos(sqrt(30)/30) / 29,
- 3 * sqrt(29) * acos(sqrt(30)/30) / 29,
- 4 * sqrt(29) * acos(sqrt(30)/30) / 29)
- assert q1.pow_cos_sin(2) == \
- Quaternion(30 * cos(2 * acos(sqrt(30)/30)),
- 60 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29,
- 90 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29,
- 120 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29)
- assert diff(Quaternion(x, x, x, x), x) == Quaternion(1, 1, 1, 1)
- assert integrate(Quaternion(x, x, x, x), x) == \
- Quaternion(x**2 / 2, x**2 / 2, x**2 / 2, x**2 / 2)
- assert Quaternion.rotate_point((1, 1, 1), q1) == (S.One / 5, 1, S(7) / 5)
- n = Symbol('n')
- raises(TypeError, lambda: q1**n)
- n = Symbol('n', integer=True)
- raises(TypeError, lambda: q1**n)
- assert Quaternion(22, 23, 55, 8).scalar_part() == 22
- assert Quaternion(w, x, y, z).scalar_part() == w
- assert Quaternion(22, 23, 55, 8).vector_part() == Quaternion(0, 23, 55, 8)
- assert Quaternion(w, x, y, z).vector_part() == Quaternion(0, x, y, z)
- assert q1.axis() == Quaternion(0, 2*sqrt(29)/29, 3*sqrt(29)/29, 4*sqrt(29)/29)
- assert q1.axis().pow(2) == Quaternion(-1, 0, 0, 0)
- assert q0.axis().scalar_part() == 0
- assert (q.axis() == Quaternion(0,
- x/sqrt(x**2 + y**2 + z**2),
- y/sqrt(x**2 + y**2 + z**2),
- z/sqrt(x**2 + y**2 + z**2)))
- assert q0.is_pure() is True
- assert q1.is_pure() is False
- assert Quaternion(0, 0, 0, 3).is_pure() is True
- assert Quaternion(0, 2, 10, 3).is_pure() is True
- assert Quaternion(w, 2, 10, 3).is_pure() is None
- assert q1.angle() == atan(sqrt(29))
- assert q.angle() == atan2(sqrt(x**2 + y**2 + z**2), w)
- assert Quaternion.arc_coplanar(q1, Quaternion(2, 4, 6, 8)) is True
- assert Quaternion.arc_coplanar(q1, Quaternion(1, -2, -3, -4)) is True
- assert Quaternion.arc_coplanar(q1, Quaternion(1, 8, 12, 16)) is True
- assert Quaternion.arc_coplanar(q1, Quaternion(1, 2, 3, 4)) is True
- assert Quaternion.arc_coplanar(q1, Quaternion(w, 4, 6, 8)) is True
- assert Quaternion.arc_coplanar(q1, Quaternion(2, 7, 4, 1)) is False
- assert Quaternion.arc_coplanar(q1, Quaternion(w, x, y, z)) is None
- raises(ValueError, lambda: Quaternion.arc_coplanar(q1, q0))
- assert Quaternion.vector_coplanar(
- Quaternion(0, 8, 12, 16),
- Quaternion(0, 4, 6, 8),
- Quaternion(0, 2, 3, 4)) is True
- assert Quaternion.vector_coplanar(
- Quaternion(0, 0, 0, 0), Quaternion(0, 4, 6, 8), Quaternion(0, 2, 3, 4)) is True
- assert Quaternion.vector_coplanar(
- Quaternion(0, 8, 2, 6), Quaternion(0, 1, 6, 6), Quaternion(0, 0, 3, 4)) is False
- assert Quaternion.vector_coplanar(
- Quaternion(0, 1, 3, 4),
- Quaternion(0, 4, w, 6),
- Quaternion(0, 6, 8, 1)) is None
- raises(ValueError, lambda:
- Quaternion.vector_coplanar(q0, Quaternion(0, 4, 6, 8), q1))
- assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 4, 6)) is True
- assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 2, 6)) is False
- assert Quaternion(0, 1, 2, 3).parallel(Quaternion(w, x, y, 6)) is None
- raises(ValueError, lambda: q0.parallel(q1))
- assert Quaternion(0, 1, 2, 3).orthogonal(Quaternion(0, -2, 1, 0)) is True
- assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(0, 2, 2, 6)) is False
- assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(w, x, y, 6)) is None
- raises(ValueError, lambda: q0.orthogonal(q1))
- assert q1.index_vector() == Quaternion(
- 0, 2*sqrt(870)/29,
- 3*sqrt(870)/29,
- 4*sqrt(870)/29)
- assert Quaternion(0, 3, 9, 4).index_vector() == Quaternion(0, 3, 9, 4)
- assert Quaternion(4, 3, 9, 4).mensor() == log(sqrt(122))
- assert Quaternion(3, 3, 0, 2).mensor() == log(sqrt(22))
- assert q0.is_zero_quaternion() is True
- assert q1.is_zero_quaternion() is False
- assert Quaternion(w, 0, 0, 0).is_zero_quaternion() is None
- def test_quaternion_conversions():
- q1 = Quaternion(1, 2, 3, 4)
- assert q1.to_axis_angle() == ((2 * sqrt(29)/29,
- 3 * sqrt(29)/29,
- 4 * sqrt(29)/29),
- 2 * acos(sqrt(30)/30))
- assert (q1.to_rotation_matrix() ==
- Matrix([[Rational(-2, 3), Rational(2, 15), Rational(11, 15)],
- [Rational(2, 3), Rational(-1, 3), Rational(2, 3)],
- [Rational(1, 3), Rational(14, 15), Rational(2, 15)]]))
- assert (q1.to_rotation_matrix((1, 1, 1)) ==
- Matrix([
- [Rational(-2, 3), Rational(2, 15), Rational(11, 15), Rational(4, 5)],
- [Rational(2, 3), Rational(-1, 3), Rational(2, 3), S.Zero],
- [Rational(1, 3), Rational(14, 15), Rational(2, 15), Rational(-2, 5)],
- [S.Zero, S.Zero, S.Zero, S.One]]))
- theta = symbols("theta", real=True)
- q2 = Quaternion(cos(theta/2), 0, 0, sin(theta/2))
- assert trigsimp(q2.to_rotation_matrix()) == Matrix([
- [cos(theta), -sin(theta), 0],
- [sin(theta), cos(theta), 0],
- [0, 0, 1]])
- assert q2.to_axis_angle() == ((0, 0, sin(theta/2)/Abs(sin(theta/2))),
- 2*acos(cos(theta/2)))
- assert trigsimp(q2.to_rotation_matrix((1, 1, 1))) == Matrix([
- [cos(theta), -sin(theta), 0, sin(theta) - cos(theta) + 1],
- [sin(theta), cos(theta), 0, -sin(theta) - cos(theta) + 1],
- [0, 0, 1, 0],
- [0, 0, 0, 1]])
- def test_rotation_matrix_homogeneous():
- q = Quaternion(w, x, y, z)
- R1 = q.to_rotation_matrix(homogeneous=True) * q.norm()**2
- R2 = simplify(q.to_rotation_matrix(homogeneous=False) * q.norm()**2)
- assert R1 == R2
- def test_quaternion_rotation_iss1593():
- """
- There was a sign mistake in the definition,
- of the rotation matrix. This tests that particular sign mistake.
- See issue 1593 for reference.
- See wikipedia
- https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix
- for the correct definition
- """
- q = Quaternion(cos(phi/2), sin(phi/2), 0, 0)
- assert(trigsimp(q.to_rotation_matrix()) == Matrix([
- [1, 0, 0],
- [0, cos(phi), -sin(phi)],
- [0, sin(phi), cos(phi)]]))
- def test_quaternion_multiplication():
- q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False)
- q2 = Quaternion(1, 2, 3, 5)
- q3 = Quaternion(1, 1, 1, y)
- assert Quaternion._generic_mul(S(4), S.One) == 4
- assert (Quaternion._generic_mul(S(4), q1) ==
- Quaternion(12 + 16*I, 8 + 20*I, 0, 28 + 32*I))
- assert q2.mul(2) == Quaternion(2, 4, 6, 10)
- assert q2.mul(q3) == Quaternion(-5*y - 4, 3*y - 2, 9 - 2*y, y + 4)
- assert q2.mul(q3) == q2*q3
- z = symbols('z', complex=True)
- z_quat = Quaternion(re(z), im(z), 0, 0)
- q = Quaternion(*symbols('q:4', real=True))
- assert z * q == z_quat * q
- assert q * z == q * z_quat
- def test_issue_16318():
- #for rtruediv
- q0 = Quaternion(0, 0, 0, 0)
- raises(ValueError, lambda: 1/q0)
- #for rotate_point
- q = Quaternion(1, 2, 3, 4)
- (axis, angle) = q.to_axis_angle()
- assert Quaternion.rotate_point((1, 1, 1), (axis, angle)) == (S.One / 5, 1, S(7) / 5)
- #test for to_axis_angle
- q = Quaternion(-1, 1, 1, 1)
- axis = (-sqrt(3)/3, -sqrt(3)/3, -sqrt(3)/3)
- angle = 2*pi/3
- assert (axis, angle) == q.to_axis_angle()
- def test_to_euler():
- q = Quaternion(w, x, y, z)
- q_normalized = q.normalize()
- seqs = ['zxy', 'zyx', 'zyz', 'zxz']
- seqs += [seq.upper() for seq in seqs]
- for seq in seqs:
- euler_from_q = q.to_euler(seq)
- q_back = simplify(Quaternion.from_euler(euler_from_q, seq))
- assert q_back == q_normalized
- def test_to_euler_iss24504():
- """
- There was a mistake in the degenerate case testing
- See issue 24504 for reference.
- """
- q = Quaternion.from_euler((phi, 0, 0), 'zyz')
- assert trigsimp(q.to_euler('zyz'), inverse=True) == (phi, 0, 0)
- def test_to_euler_numerical_singilarities():
- def test_one_case(angles, seq):
- q = Quaternion.from_euler(angles, seq)
- assert q.to_euler(seq) == angles
- # symmetric
- test_one_case((pi/2, 0, 0), 'zyz')
- test_one_case((pi/2, 0, 0), 'ZYZ')
- test_one_case((pi/2, pi, 0), 'zyz')
- test_one_case((pi/2, pi, 0), 'ZYZ')
- # asymmetric
- test_one_case((pi/2, pi/2, 0), 'zyx')
- test_one_case((pi/2, -pi/2, 0), 'zyx')
- test_one_case((pi/2, pi/2, 0), 'ZYX')
- test_one_case((pi/2, -pi/2, 0), 'ZYX')
- def test_to_euler_options():
- def test_one_case(q):
- angles1 = Matrix(q.to_euler(seq, True, True))
- angles2 = Matrix(q.to_euler(seq, False, False))
- angle_errors = simplify(angles1-angles2).evalf()
- for angle_error in angle_errors:
- # forcing angles to set {-pi, pi}
- angle_error = (angle_error + pi) % (2 * pi) - pi
- assert angle_error < 10e-7
- for xyz in ('xyz', 'XYZ'):
- for seq_tuple in permutations(xyz):
- for symmetric in (True, False):
- if symmetric:
- seq = ''.join([seq_tuple[0], seq_tuple[1], seq_tuple[0]])
- else:
- seq = ''.join(seq_tuple)
- for elements in product([-1, 0, 1], repeat=4):
- q = Quaternion(*elements)
- if not q.is_zero_quaternion():
- test_one_case(q)
|