test_cythonized_array_utils.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import numpy as np
  2. from scipy.linalg import bandwidth, issymmetric, ishermitian
  3. import pytest
  4. from pytest import raises
  def test_bandwidth_dtypes():
      """Run ``bandwidth`` on a zero matrix of every NumPy typecode.

      Typecodes in ``'eUVOMm'`` must raise ``TypeError``, ``'G'`` is left
      unchecked, and every other typecode only has to return without raising.
      """
      size = 5
      rejected_codes = 'eUVOMm'
      for code in np.typecodes['All']:
          zeros = np.zeros([size, size], dtype=code)
          if code == 'G':
              # Intentionally a no-op: per the original note these pass on
              # Windows and fail on other platforms.
              continue
          if code in rejected_codes:
              with raises(TypeError):
                  bandwidth(zeros)
          else:
              bandwidth(zeros)
  def test_bandwidth_non2d_input():
      """``bandwidth`` must raise ``ValueError`` for 1-D and 3-D arrays."""
      non_2d_arrays = (
          np.array([1, 2, 3]),
          np.array([[[1, 2, 3], [4, 5, 6]]]),
      )
      for arr in non_2d_arrays:
          with raises(ValueError):
              bandwidth(arr)
  @pytest.mark.parametrize('T', [code for code in np.typecodes['All']
                                 if code not in 'eGUVOMm'])
  def test_bandwidth_square_inputs(T):
      """Check ``bandwidth`` on a square banded matrix of dtype ``T``.

      Ones are written on the main diagonal, the first subdiagonal and the
      k-th sub- and superdiagonals of a Fortran-ordered zero matrix; the
      reported bandwidth must equal ``(k, k)``.
      """
      n, k = 20, 4
      R = np.zeros([n, n], dtype=T, order='F')
      idx = list(range(n))
      # Fill the bands in place via paired row/column index lists.
      R[idx, idx] = 1                # main diagonal
      R[idx[:n-k], idx[k:]] = 1      # k-th superdiagonal
      R[idx[1:], idx[:n-1]] = 1      # first subdiagonal
      R[idx[k:], idx[:n-k]] = 1      # k-th subdiagonal
      assert bandwidth(R) == (k, k)
  @pytest.mark.parametrize('T', [code for code in np.typecodes['All']
                                 if code not in 'eGUVOMm'])
  def test_bandwidth_rect_inputs(T):
      """Check ``bandwidth`` on a rectangular banded matrix of dtype ``T``.

      Only the leading n-by-n part of the n-by-m Fortran-ordered zero matrix
      receives ones (main diagonal, first subdiagonal, k-th sub- and
      superdiagonals); the reported bandwidth must equal ``(k, k)``.
      """
      n, m = 10, 20
      k = 5
      R = np.zeros([n, m], dtype=T, order='F')
      idx = list(range(n))
      # Fill the bands in place via paired row/column index lists.
      R[idx, idx] = 1                # main diagonal
      R[idx[:n-k], idx[k:]] = 1      # k-th superdiagonal
      R[idx[1:], idx[:n-1]] = 1      # first subdiagonal
      R[idx[k:], idx[:n-k]] = 1      # k-th subdiagonal
      assert bandwidth(R) == (k, k)
  def test_issymetric_ishermitian_dtypes():
      """Run ``issymmetric``/``ishermitian`` on zero matrices of every typecode.

      Typecodes in ``'eUVOMm'`` must raise ``TypeError`` from both functions,
      ``'G'`` is left unchecked, and for every other typecode the zero matrix
      must be reported as both symmetric and Hermitian.
      """
      size = 5
      rejected_codes = 'eUVOMm'
      for code in np.typecodes['All']:
          zeros = np.zeros([size, size], dtype=code)
          if code == 'G':
              # Intentionally a no-op: per the original note these pass on
              # Windows and fail on other platforms.
              continue
          if code in rejected_codes:
              for check in (issymmetric, ishermitian):
                  with raises(TypeError):
                      check(zeros)
          else:
              assert issymmetric(zeros)
              assert ishermitian(zeros)
  def test_issymmetric_ishermitian_invalid_input():
      """Both checks must raise ``ValueError`` on non-square-2D input."""
      invalid_arrays = (
          np.array([1, 2, 3]),                   # 1-D
          np.array([[[1, 2, 3], [4, 5, 6]]]),    # 3-D
          np.array([[1, 2, 3], [4, 5, 6]]),      # 2-D but 2x3
      )
      for arr in invalid_arrays:
          for check in (issymmetric, ishermitian):
              with raises(ValueError):
                  check(arr)
  def test_issymetric_complex_decimals():
      """``issymmetric`` accepts a complex matrix with non-integer entries."""
      real_part = np.arange(1, 10).astype(complex).reshape(3, 3)
      imag_part = np.arange(-4, 5).astype(complex).reshape(3, 3)
      base = real_part + imag_part*1j
      # Dividing by pi turns the integer-valued entries into decimals.
      base = base / np.pi
      # Adding the plain transpose makes the matrix symmetric.
      symmetric = base + base.T
      assert issymmetric(symmetric)
  def test_ishermitian_complex_decimals():
      """``ishermitian`` accepts a complex matrix with non-integer entries."""
      real_part = np.arange(1, 10).astype(complex).reshape(3, 3)
      imag_part = np.arange(-4, 5).astype(complex).reshape(3, 3)
      base = real_part + imag_part*1j
      # Dividing by pi turns the integer-valued entries into decimals.
      base = base / np.pi
      # Adding the conjugate transpose makes the matrix Hermitian.
      hermitian = base + base.T.conj()
      assert ishermitian(hermitian)
  def test_issymmetric_approximate_results():
      """``issymmetric`` with tolerances accepts a nearly symmetric product.

      The matrix ``p @ (x @ x.T) @ p.T`` is built from seeded random draws
      and then checked under several ``atol``/``rtol`` combinations.
      """
      size = 20
      rng = np.random.RandomState(123456789)
      # Draw order matters for reproducibility: uniform first, then normal.
      base = rng.uniform(high=5., size=[size, size])
      gram = base @ base.T  # symmetric
      mixer = rng.standard_normal([size, size])
      product = mixer @ gram @ mixer.T
      tolerance_sets = (
          {'atol': 1e-10},
          {'atol': 1e-10, 'rtol': 0.},
          {'atol': 0., 'rtol': 1e-12},
          {'atol': 1e-13, 'rtol': 1e-12},
      )
      for tolerances in tolerance_sets:
          assert issymmetric(product, **tolerances)
  def test_ishermitian_approximate_results():
      """``ishermitian`` with tolerances accepts a nearly Hermitian product.

      The matrix ``p @ (x @ x.T) @ p.conj().T`` is built from seeded random
      draws, with complex ``p``, and then checked under several
      ``atol``/``rtol`` combinations.
      """
      size = 20
      rng = np.random.RandomState(987654321)
      # Draw order matters for reproducibility: uniform, then real and
      # imaginary normal parts in that order.
      base = rng.uniform(high=5., size=[size, size])
      gram = base @ base.T  # symmetric
      mixer = (rng.standard_normal([size, size])
               + rng.standard_normal([size, size])*1j)
      product = mixer @ gram @ mixer.conj().T
      tolerance_sets = (
          {'atol': 1e-10},
          {'atol': 1e-10, 'rtol': 0.},
          {'atol': 0., 'rtol': 1e-12},
          {'atol': 1e-13, 'rtol': 1e-12},
      )
      for tolerances in tolerance_sets:
          assert ishermitian(product, **tolerances)