test_extract.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. """test sparse matrix construction functions"""
  2. from numpy.testing import assert_equal
  3. from scipy.sparse import csr_matrix
  4. import numpy as np
  5. from scipy.sparse import _extract
  6. class TestExtract:
  7. def setup_method(self):
  8. self.cases = [
  9. csr_matrix([[1,2]]),
  10. csr_matrix([[1,0]]),
  11. csr_matrix([[0,0]]),
  12. csr_matrix([[1],[2]]),
  13. csr_matrix([[1],[0]]),
  14. csr_matrix([[0],[0]]),
  15. csr_matrix([[1,2],[3,4]]),
  16. csr_matrix([[0,1],[0,0]]),
  17. csr_matrix([[0,0],[1,0]]),
  18. csr_matrix([[0,0],[0,0]]),
  19. csr_matrix([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]),
  20. csr_matrix([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]).T,
  21. ]
  22. def find(self):
  23. for A in self.cases:
  24. I,J,V = _extract.find(A)
  25. assert_equal(A.toarray(), csr_matrix(((I,J),V), shape=A.shape))
  26. def test_tril(self):
  27. for A in self.cases:
  28. B = A.toarray()
  29. for k in [-3,-2,-1,0,1,2,3]:
  30. assert_equal(_extract.tril(A,k=k).toarray(), np.tril(B,k=k))
  31. def test_triu(self):
  32. for A in self.cases:
  33. B = A.toarray()
  34. for k in [-3,-2,-1,0,1,2,3]:
  35. assert_equal(_extract.triu(A,k=k).toarray(), np.triu(B,k=k))