# test_decomp_cossin.py (5.6 KB)
  1. import pytest
  2. import numpy as np
  3. from numpy.random import seed
  4. from numpy.testing import assert_allclose
  5. from scipy.linalg.lapack import _compute_lwork
  6. from scipy.stats import ortho_group, unitary_group
  7. from scipy.linalg import cossin, get_lapack_funcs
  8. REAL_DTYPES = (np.float32, np.float64)
  9. COMPLEX_DTYPES = (np.complex64, np.complex128)
  10. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  11. @pytest.mark.parametrize('dtype_', DTYPES)
  12. @pytest.mark.parametrize('m, p, q',
  13. [
  14. (2, 1, 1),
  15. (3, 2, 1),
  16. (3, 1, 2),
  17. (4, 2, 2),
  18. (4, 1, 2),
  19. (40, 12, 20),
  20. (40, 30, 1),
  21. (40, 1, 30),
  22. (100, 50, 1),
  23. (100, 50, 50),
  24. ])
  25. @pytest.mark.parametrize('swap_sign', [True, False])
  26. def test_cossin(dtype_, m, p, q, swap_sign):
  27. seed(1234)
  28. if dtype_ in COMPLEX_DTYPES:
  29. x = np.array(unitary_group.rvs(m), dtype=dtype_)
  30. else:
  31. x = np.array(ortho_group.rvs(m), dtype=dtype_)
  32. u, cs, vh = cossin(x, p, q,
  33. swap_sign=swap_sign)
  34. assert_allclose(x, u @ cs @ vh, rtol=0., atol=m*1e3*np.finfo(dtype_).eps)
  35. assert u.dtype == dtype_
  36. # Test for float32 or float 64
  37. assert cs.dtype == np.real(u).dtype
  38. assert vh.dtype == dtype_
  39. u, cs, vh = cossin([x[:p, :q], x[:p, q:], x[p:, :q], x[p:, q:]],
  40. swap_sign=swap_sign)
  41. assert_allclose(x, u @ cs @ vh, rtol=0., atol=m*1e3*np.finfo(dtype_).eps)
  42. assert u.dtype == dtype_
  43. assert cs.dtype == np.real(u).dtype
  44. assert vh.dtype == dtype_
  45. _, cs2, vh2 = cossin(x, p, q,
  46. compute_u=False,
  47. swap_sign=swap_sign)
  48. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  49. assert_allclose(vh, vh2, rtol=0., atol=10*np.finfo(dtype_).eps)
  50. u2, cs2, _ = cossin(x, p, q,
  51. compute_vh=False,
  52. swap_sign=swap_sign)
  53. assert_allclose(u, u2, rtol=0., atol=10*np.finfo(dtype_).eps)
  54. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  55. _, cs2, _ = cossin(x, p, q,
  56. compute_u=False,
  57. compute_vh=False,
  58. swap_sign=swap_sign)
  59. assert_allclose(cs, cs2, rtol=0., atol=10*np.finfo(dtype_).eps)
  60. def test_cossin_mixed_types():
  61. seed(1234)
  62. x = np.array(ortho_group.rvs(4), dtype=np.float64)
  63. u, cs, vh = cossin([x[:2, :2],
  64. np.array(x[:2, 2:], dtype=np.complex128),
  65. x[2:, :2],
  66. x[2:, 2:]])
  67. assert u.dtype == np.complex128
  68. assert cs.dtype == np.float64
  69. assert vh.dtype == np.complex128
  70. assert_allclose(x, u @ cs @ vh, rtol=0.,
  71. atol=1e4 * np.finfo(np.complex128).eps)
  72. def test_cossin_error_incorrect_subblocks():
  73. with pytest.raises(ValueError, match="be due to missing p, q arguments."):
  74. cossin(([1, 2], [3, 4, 5], [6, 7], [8, 9, 10]))
  75. def test_cossin_error_empty_subblocks():
  76. with pytest.raises(ValueError, match="x11.*empty"):
  77. cossin(([], [], [], []))
  78. with pytest.raises(ValueError, match="x12.*empty"):
  79. cossin(([1, 2], [], [6, 7], [8, 9, 10]))
  80. with pytest.raises(ValueError, match="x21.*empty"):
  81. cossin(([1, 2], [3, 4, 5], [], [8, 9, 10]))
  82. with pytest.raises(ValueError, match="x22.*empty"):
  83. cossin(([1, 2], [3, 4, 5], [2], []))
  84. def test_cossin_error_missing_partitioning():
  85. with pytest.raises(ValueError, match=".*exactly four arrays.* got 2"):
  86. cossin(unitary_group.rvs(2))
  87. with pytest.raises(ValueError, match=".*might be due to missing p, q"):
  88. cossin(unitary_group.rvs(4))
  89. def test_cossin_error_non_iterable():
  90. with pytest.raises(ValueError, match="containing the subblocks of X"):
  91. cossin(12j)
  92. def test_cossin_error_non_square():
  93. with pytest.raises(ValueError, match="only supports square"):
  94. cossin(np.array([[1, 2]]), 1, 1)
  95. def test_cossin_error_partitioning():
  96. x = np.array(ortho_group.rvs(4), dtype=np.float64)
  97. with pytest.raises(ValueError, match="invalid p=0.*0<p<4.*"):
  98. cossin(x, 0, 1)
  99. with pytest.raises(ValueError, match="invalid p=4.*0<p<4.*"):
  100. cossin(x, 4, 1)
  101. with pytest.raises(ValueError, match="invalid q=-2.*0<q<4.*"):
  102. cossin(x, 1, -2)
  103. with pytest.raises(ValueError, match="invalid q=5.*0<q<4.*"):
  104. cossin(x, 1, 5)
  105. @pytest.mark.parametrize("dtype_", DTYPES)
  106. def test_cossin_separate(dtype_):
  107. seed(1234)
  108. m, p, q = 250, 80, 170
  109. pfx = 'or' if dtype_ in REAL_DTYPES else 'un'
  110. X = ortho_group.rvs(m) if pfx == 'or' else unitary_group.rvs(m)
  111. X = np.array(X, dtype=dtype_)
  112. drv, dlw = get_lapack_funcs((pfx + 'csd', pfx + 'csd_lwork'),[X])
  113. lwval = _compute_lwork(dlw, m, p, q)
  114. lwvals = {'lwork': lwval} if pfx == 'or' else dict(zip(['lwork',
  115. 'lrwork'],
  116. lwval))
  117. *_, theta, u1, u2, v1t, v2t, _ = \
  118. drv(X[:p, :q], X[:p, q:], X[p:, :q], X[p:, q:], **lwvals)
  119. (u1_2, u2_2), theta2, (v1t_2, v2t_2) = cossin(X, p, q, separate=True)
  120. assert_allclose(u1_2, u1, rtol=0., atol=10*np.finfo(dtype_).eps)
  121. assert_allclose(u2_2, u2, rtol=0., atol=10*np.finfo(dtype_).eps)
  122. assert_allclose(v1t_2, v1t, rtol=0., atol=10*np.finfo(dtype_).eps)
  123. assert_allclose(v2t_2, v2t, rtol=0., atol=10*np.finfo(dtype_).eps)
  124. assert_allclose(theta2, theta, rtol=0., atol=10*np.finfo(dtype_).eps)