test_procrustes.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. from itertools import product, permutations
  2. import numpy as np
  3. from numpy.testing import assert_array_less, assert_allclose
  4. from pytest import raises as assert_raises
  5. from scipy.linalg import inv, eigh, norm
  6. from scipy.linalg import orthogonal_procrustes
  7. from scipy.sparse._sputils import matrix
  8. def test_orthogonal_procrustes_ndim_too_large():
  9. np.random.seed(1234)
  10. A = np.random.randn(3, 4, 5)
  11. B = np.random.randn(3, 4, 5)
  12. assert_raises(ValueError, orthogonal_procrustes, A, B)
  13. def test_orthogonal_procrustes_ndim_too_small():
  14. np.random.seed(1234)
  15. A = np.random.randn(3)
  16. B = np.random.randn(3)
  17. assert_raises(ValueError, orthogonal_procrustes, A, B)
  18. def test_orthogonal_procrustes_shape_mismatch():
  19. np.random.seed(1234)
  20. shapes = ((3, 3), (3, 4), (4, 3), (4, 4))
  21. for a, b in permutations(shapes, 2):
  22. A = np.random.randn(*a)
  23. B = np.random.randn(*b)
  24. assert_raises(ValueError, orthogonal_procrustes, A, B)
  25. def test_orthogonal_procrustes_checkfinite_exception():
  26. np.random.seed(1234)
  27. m, n = 2, 3
  28. A_good = np.random.randn(m, n)
  29. B_good = np.random.randn(m, n)
  30. for bad_value in np.inf, -np.inf, np.nan:
  31. A_bad = A_good.copy()
  32. A_bad[1, 2] = bad_value
  33. B_bad = B_good.copy()
  34. B_bad[1, 2] = bad_value
  35. for A, B in ((A_good, B_bad), (A_bad, B_good), (A_bad, B_bad)):
  36. assert_raises(ValueError, orthogonal_procrustes, A, B)
  37. def test_orthogonal_procrustes_scale_invariance():
  38. np.random.seed(1234)
  39. m, n = 4, 3
  40. for i in range(3):
  41. A_orig = np.random.randn(m, n)
  42. B_orig = np.random.randn(m, n)
  43. R_orig, s = orthogonal_procrustes(A_orig, B_orig)
  44. for A_scale in np.square(np.random.randn(3)):
  45. for B_scale in np.square(np.random.randn(3)):
  46. R, s = orthogonal_procrustes(A_orig * A_scale, B_orig * B_scale)
  47. assert_allclose(R, R_orig)
  48. def test_orthogonal_procrustes_array_conversion():
  49. np.random.seed(1234)
  50. for m, n in ((6, 4), (4, 4), (4, 6)):
  51. A_arr = np.random.randn(m, n)
  52. B_arr = np.random.randn(m, n)
  53. As = (A_arr, A_arr.tolist(), matrix(A_arr))
  54. Bs = (B_arr, B_arr.tolist(), matrix(B_arr))
  55. R_arr, s = orthogonal_procrustes(A_arr, B_arr)
  56. AR_arr = A_arr.dot(R_arr)
  57. for A, B in product(As, Bs):
  58. R, s = orthogonal_procrustes(A, B)
  59. AR = A_arr.dot(R)
  60. assert_allclose(AR, AR_arr)
  61. def test_orthogonal_procrustes():
  62. np.random.seed(1234)
  63. for m, n in ((6, 4), (4, 4), (4, 6)):
  64. # Sample a random target matrix.
  65. B = np.random.randn(m, n)
  66. # Sample a random orthogonal matrix
  67. # by computing eigh of a sampled symmetric matrix.
  68. X = np.random.randn(n, n)
  69. w, V = eigh(X.T + X)
  70. assert_allclose(inv(V), V.T)
  71. # Compute a matrix with a known orthogonal transformation that gives B.
  72. A = np.dot(B, V.T)
  73. # Check that an orthogonal transformation from A to B can be recovered.
  74. R, s = orthogonal_procrustes(A, B)
  75. assert_allclose(inv(R), R.T)
  76. assert_allclose(A.dot(R), B)
  77. # Create a perturbed input matrix.
  78. A_perturbed = A + 1e-2 * np.random.randn(m, n)
  79. # Check that the orthogonal procrustes function can find an orthogonal
  80. # transformation that is better than the orthogonal transformation
  81. # computed from the original input matrix.
  82. R_prime, s = orthogonal_procrustes(A_perturbed, B)
  83. assert_allclose(inv(R_prime), R_prime.T)
  84. # Compute the naive and optimal transformations of the perturbed input.
  85. naive_approx = A_perturbed.dot(R)
  86. optim_approx = A_perturbed.dot(R_prime)
  87. # Compute the Frobenius norm errors of the matrix approximations.
  88. naive_approx_error = norm(naive_approx - B, ord='fro')
  89. optim_approx_error = norm(optim_approx - B, ord='fro')
  90. # Check that the orthogonal Procrustes approximation is better.
  91. assert_array_less(optim_approx_error, naive_approx_error)
  92. def _centered(A):
  93. mu = A.mean(axis=0)
  94. return A - mu, mu
  95. def test_orthogonal_procrustes_exact_example():
  96. # Check a small application.
  97. # It uses translation, scaling, reflection, and rotation.
  98. #
  99. # |
  100. # a b |
  101. # |
  102. # d c | w
  103. # |
  104. # --------+--- x ----- z ---
  105. # |
  106. # | y
  107. # |
  108. #
  109. A_orig = np.array([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=float)
  110. B_orig = np.array([[3, 2], [1, 0], [3, -2], [5, 0]], dtype=float)
  111. A, A_mu = _centered(A_orig)
  112. B, B_mu = _centered(B_orig)
  113. R, s = orthogonal_procrustes(A, B)
  114. scale = s / np.square(norm(A))
  115. B_approx = scale * np.dot(A, R) + B_mu
  116. assert_allclose(B_approx, B_orig, atol=1e-8)
  117. def test_orthogonal_procrustes_stretched_example():
  118. # Try again with a target with a stretched y axis.
  119. A_orig = np.array([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=float)
  120. B_orig = np.array([[3, 40], [1, 0], [3, -40], [5, 0]], dtype=float)
  121. A, A_mu = _centered(A_orig)
  122. B, B_mu = _centered(B_orig)
  123. R, s = orthogonal_procrustes(A, B)
  124. scale = s / np.square(norm(A))
  125. B_approx = scale * np.dot(A, R) + B_mu
  126. expected = np.array([[3, 21], [-18, 0], [3, -21], [24, 0]], dtype=float)
  127. assert_allclose(B_approx, expected, atol=1e-8)
  128. # Check disparity symmetry.
  129. expected_disparity = 0.4501246882793018
  130. AB_disparity = np.square(norm(B_approx - B_orig) / norm(B))
  131. assert_allclose(AB_disparity, expected_disparity)
  132. R, s = orthogonal_procrustes(B, A)
  133. scale = s / np.square(norm(B))
  134. A_approx = scale * np.dot(B, R) + A_mu
  135. BA_disparity = np.square(norm(A_approx - A_orig) / norm(A))
  136. assert_allclose(BA_disparity, expected_disparity)
  137. def test_orthogonal_procrustes_skbio_example():
  138. # This transformation is also exact.
  139. # It uses translation, scaling, and reflection.
  140. #
  141. # |
  142. # | a
  143. # | b
  144. # | c d
  145. # --+---------
  146. # |
  147. # | w
  148. # |
  149. # | x
  150. # |
  151. # | z y
  152. # |
  153. #
  154. A_orig = np.array([[4, -2], [4, -4], [4, -6], [2, -6]], dtype=float)
  155. B_orig = np.array([[1, 3], [1, 2], [1, 1], [2, 1]], dtype=float)
  156. B_standardized = np.array([
  157. [-0.13363062, 0.6681531],
  158. [-0.13363062, 0.13363062],
  159. [-0.13363062, -0.40089186],
  160. [0.40089186, -0.40089186]])
  161. A, A_mu = _centered(A_orig)
  162. B, B_mu = _centered(B_orig)
  163. R, s = orthogonal_procrustes(A, B)
  164. scale = s / np.square(norm(A))
  165. B_approx = scale * np.dot(A, R) + B_mu
  166. assert_allclose(B_approx, B_orig)
  167. assert_allclose(B / norm(B), B_standardized)