test_decomp_polar.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. from numpy.linalg import norm
  3. from numpy.testing import (assert_, assert_allclose, assert_equal)
  4. from scipy.linalg import polar, eigh
  # Inputs whose polar factors can be written down by hand.
  diag2 = np.array([[2, 0], [0, 3]])
  a13 = np.array([[1, 2, 2]])  # norm(a13[0]) == 3

  # Each entry is [a, side, expected_u, expected_p], where the expected factors
  # satisfy a == expected_u @ expected_p for side='right' and
  # a == expected_p @ expected_u for side='left'.  Every input is listed for
  # both sides.
  precomputed_cases = [
      # 1x1 zero matrix: p is zero and the expected u is [[1]].
      [[[0]], 'right', [[1]], [[0]]],
      [[[0]], 'left', [[1]], [[0]]],
      # 1x1 positive matrix: u is [[1]] and p is the matrix itself.
      [[[9]], 'right', [[1]], [[9]]],
      [[[9]], 'left', [[1]], [[9]]],
      # Diagonal matrix with positive entries: u is the identity, p is diag2.
      [diag2, 'right', np.eye(2), diag2],
      [diag2, 'left', np.eye(2), diag2],
      # Wide 1x3 matrix: u = a / ||a|| on both sides; the right factor p is
      # the 3x3 matrix a^T a / ||a||, the left factor p is the 1x1 [[||a||]].
      [a13, 'right', a13/norm(a13[0]), a13.T.dot(a13)/norm(a13[0])],
      [a13, 'left', a13/norm(a13[0]), [[norm(a13[0])]]],
  ]
  # Matrices whose polar decompositions are checked property-by-property
  # (rather than against precomputed factors).  Square, wide (m < n) and
  # tall (m > n) shapes are covered, first real-valued, then complex-valued.
  verify_cases = [
      # real-valued
      [[1, 2], [3, 4]],
      [[1, 2, 3]],
      [[1], [2], [3]],
      [[1, 2, 3], [3, 4, 0]],
      [[1, 2], [3, 4], [5, 5]],
      # complex-valued
      [[1, 2], [3, 4 + 5j]],
      [[1, 2, 3j]],
      [[1], [2], [3j]],
      [[1, 2, 3 + 2j], [3, 4 - 1j, -4j]],
      [[1, 2], [3 - 2j, 4 + 0.5j], [5, 5]],
      # complex, with entries spanning several orders of magnitude
      [[10000, 10, 1], [-1, 2, 3j], [0, 1, 2]],
  ]
  def check_precomputed_polar(a, side, expected_u, expected_p):
      """Assert that ``polar(a, side=side)`` matches a known decomposition.

      The unitary factor is compared first, then the positive semidefinite
      factor. Each comparison uses ``assert_allclose`` with ``atol=1e-15``.
      """
      tight_atol = 1e-15
      actual_u, actual_p = polar(a, side=side)
      pairs = ((actual_u, expected_u), (actual_p, expected_p))
      for actual, expected in pairs:
          assert_allclose(actual, expected, atol=tight_atol)
  def _assert_orthonormal_factor(u, m, n):
      """Assert the unitary polar factor ``u`` of an ``(m, n)`` input is orthonormal.

      For ``m >= n`` the columns must be orthonormal (``u^H u == I_n``); for
      ``m < n`` the rows must be (``u u^H == I_m``).
      """
      if m >= n:
          assert_allclose(u.conj().T.dot(u), np.eye(n), atol=1e-15)
      else:
          assert_allclose(u.dot(u.conj().T), np.eye(m), atol=1e-15)


  def _assert_hermitian_psd(p):
      """Assert that ``p`` is Hermitian positive semidefinite.

      Eigenvalues whose magnitude is at most 1e-14 are treated as round-off
      zeros and skipped in the sign check, so a singular ``p`` whose zero
      eigenvalues come out slightly negative still passes.
      """
      assert_allclose(p.conj().T, p)
      evals = eigh(p, eigvals_only=True)
      nonzero_evals = evals[abs(evals) > 1e-14]
      assert_((nonzero_evals >= 0).all())


  def verify_polar(a):
      """Check the structural properties of both polar decompositions of ``a``.

      ``a`` is a 2-D array-like of shape ``(m, n)``. For ``side='right'``
      (``a = u p``) and then ``side='left'`` (``a = p u``) this asserts that:

      * ``u`` has shape ``(m, n)``, and ``p`` has shape ``(n, n)`` (right) or
        ``(m, m)`` (left);
      * the product of the factors reproduces ``a`` to within an absolute
        tolerance of ``sqrt(eps)``;
      * ``u`` has orthonormal columns (``m >= n``) or rows (``m < n``);
      * ``p`` is Hermitian positive semidefinite.

      Raises ``AssertionError`` on the first violated property, and
      ``ValueError`` (from the shape unpacking) if ``a`` is not 2-D.
      """
      # The reconstruction is checked more loosely than the factor properties.
      product_atol = np.sqrt(np.finfo(float).eps)
      m, n = np.asarray(a).shape

      # Right decomposition: a = u p.
      u, p = polar(a, side='right')
      assert_equal(u.shape, (m, n))
      assert_equal(p.shape, (n, n))
      assert_allclose(u.dot(p), a, atol=product_atol)
      _assert_orthonormal_factor(u, m, n)
      _assert_hermitian_psd(p)

      # Left decomposition: a = p u.
      u, p = polar(a, side='left')
      assert_equal(u.shape, (m, n))
      assert_equal(p.shape, (m, m))
      assert_allclose(p.dot(u), a, atol=product_atol)
      _assert_orthonormal_factor(u, m, n)
      _assert_hermitian_psd(p)
  def test_precomputed_cases():
      """Every entry of ``precomputed_cases`` decomposes into its recorded factors."""
      for matrix, which_side, u_ref, p_ref in precomputed_cases:
          check_precomputed_polar(matrix, which_side, expected_u=u_ref, expected_p=p_ref)
  def test_verify_cases():
      """Every matrix in ``verify_cases`` passes the property checks of ``verify_polar``."""
      for matrix in verify_cases:
          verify_polar(matrix)