# test_rotation_groups.py: tests for Rotation.create_group and Rotation.reduce
  1. import pytest
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal
  4. from scipy.spatial.transform import Rotation
  5. from scipy.optimize import linear_sum_assignment
  6. from scipy.spatial.distance import cdist
  7. from scipy.constants import golden as phi
  8. from scipy.spatial import cKDTree
  9. TOL = 1E-12
  10. NS = range(1, 13)
  11. NAMES = ["I", "O", "T"] + ["C%d" % n for n in NS] + ["D%d" % n for n in NS]
  12. SIZES = [60, 24, 12] + list(NS) + [2 * n for n in NS]
  13. def _calculate_rmsd(P, Q):
  14. """Calculates the root-mean-square distance between the points of P and Q.
  15. The distance is taken as the minimum over all possible matchings. It is
  16. zero if P and Q are identical and non-zero if not.
  17. """
  18. distance_matrix = cdist(P, Q, metric='sqeuclidean')
  19. matching = linear_sum_assignment(distance_matrix)
  20. return np.sqrt(distance_matrix[matching].sum())
  21. def _generate_pyramid(n, axis):
  22. thetas = np.linspace(0, 2 * np.pi, n + 1)[:-1]
  23. P = np.vstack([np.zeros(n), np.cos(thetas), np.sin(thetas)]).T
  24. P = np.concatenate((P, [[1, 0, 0]]))
  25. return np.roll(P, axis, axis=1)
  26. def _generate_prism(n, axis):
  27. thetas = np.linspace(0, 2 * np.pi, n + 1)[:-1]
  28. bottom = np.vstack([-np.ones(n), np.cos(thetas), np.sin(thetas)]).T
  29. top = np.vstack([+np.ones(n), np.cos(thetas), np.sin(thetas)]).T
  30. P = np.concatenate((bottom, top))
  31. return np.roll(P, axis, axis=1)
  32. def _generate_icosahedron():
  33. x = np.array([[0, -1, -phi],
  34. [0, -1, +phi],
  35. [0, +1, -phi],
  36. [0, +1, +phi]])
  37. return np.concatenate([np.roll(x, i, axis=1) for i in range(3)])
  38. def _generate_octahedron():
  39. return np.array([[-1, 0, 0], [+1, 0, 0], [0, -1, 0],
  40. [0, +1, 0], [0, 0, -1], [0, 0, +1]])
  41. def _generate_tetrahedron():
  42. return np.array([[1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1]])
  43. @pytest.mark.parametrize("name", [-1, None, True, np.array(['C3'])])
  44. def test_group_type(name):
  45. with pytest.raises(ValueError,
  46. match="must be a string"):
  47. Rotation.create_group(name)
  48. @pytest.mark.parametrize("name", ["Q", " ", "CA", "C ", "DA", "D ", "I2", ""])
  49. def test_group_name(name):
  50. with pytest.raises(ValueError,
  51. match="must be one of 'I', 'O', 'T', 'Dn', 'Cn'"):
  52. Rotation.create_group(name)
  53. @pytest.mark.parametrize("name", ["C0", "D0"])
  54. def test_group_order_positive(name):
  55. with pytest.raises(ValueError,
  56. match="Group order must be positive"):
  57. Rotation.create_group(name)
  58. @pytest.mark.parametrize("axis", ['A', 'b', 0, 1, 2, 4, False, None])
  59. def test_axis_valid(axis):
  60. with pytest.raises(ValueError,
  61. match="`axis` must be one of"):
  62. Rotation.create_group("C1", axis)
  63. def test_icosahedral():
  64. """The icosahedral group fixes the rotations of an icosahedron. Here we
  65. test that the icosahedron is invariant after application of the elements
  66. of the rotation group."""
  67. P = _generate_icosahedron()
  68. for g in Rotation.create_group("I"):
  69. g = Rotation.from_quat(g.as_quat())
  70. assert _calculate_rmsd(P, g.apply(P)) < TOL
  71. def test_octahedral():
  72. """Test that the octahedral group correctly fixes the rotations of an
  73. octahedron."""
  74. P = _generate_octahedron()
  75. for g in Rotation.create_group("O"):
  76. assert _calculate_rmsd(P, g.apply(P)) < TOL
  77. def test_tetrahedral():
  78. """Test that the tetrahedral group correctly fixes the rotations of a
  79. tetrahedron."""
  80. P = _generate_tetrahedron()
  81. for g in Rotation.create_group("T"):
  82. assert _calculate_rmsd(P, g.apply(P)) < TOL
  83. @pytest.mark.parametrize("n", NS)
  84. @pytest.mark.parametrize("axis", 'XYZ')
  85. def test_dicyclic(n, axis):
  86. """Test that the dicyclic group correctly fixes the rotations of a
  87. prism."""
  88. P = _generate_prism(n, axis='XYZ'.index(axis))
  89. for g in Rotation.create_group("D%d" % n, axis=axis):
  90. assert _calculate_rmsd(P, g.apply(P)) < TOL
  91. @pytest.mark.parametrize("n", NS)
  92. @pytest.mark.parametrize("axis", 'XYZ')
  93. def test_cyclic(n, axis):
  94. """Test that the cyclic group correctly fixes the rotations of a
  95. pyramid."""
  96. P = _generate_pyramid(n, axis='XYZ'.index(axis))
  97. for g in Rotation.create_group("C%d" % n, axis=axis):
  98. assert _calculate_rmsd(P, g.apply(P)) < TOL
  99. @pytest.mark.parametrize("name, size", zip(NAMES, SIZES))
  100. def test_group_sizes(name, size):
  101. assert len(Rotation.create_group(name)) == size
  102. @pytest.mark.parametrize("name, size", zip(NAMES, SIZES))
  103. def test_group_no_duplicates(name, size):
  104. g = Rotation.create_group(name)
  105. kdtree = cKDTree(g.as_quat())
  106. assert len(kdtree.query_pairs(1E-3)) == 0
  107. @pytest.mark.parametrize("name, size", zip(NAMES, SIZES))
  108. def test_group_symmetry(name, size):
  109. g = Rotation.create_group(name)
  110. q = np.concatenate((-g.as_quat(), g.as_quat()))
  111. distance = np.sort(cdist(q, q))
  112. deltas = np.max(distance, axis=0) - np.min(distance, axis=0)
  113. assert (deltas < TOL).all()
  114. @pytest.mark.parametrize("name", NAMES)
  115. def test_reduction(name):
  116. """Test that the elements of the rotation group are correctly
  117. mapped onto the identity rotation."""
  118. g = Rotation.create_group(name)
  119. f = g.reduce(g)
  120. assert_array_almost_equal(f.magnitude(), np.zeros(len(g)))
  121. @pytest.mark.parametrize("name", NAMES)
  122. def test_single_reduction(name):
  123. g = Rotation.create_group(name)
  124. f = g[-1].reduce(g)
  125. assert_array_almost_equal(f.magnitude(), 0)
  126. assert f.as_quat().shape == (4,)