test_matrix_linalg.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. """ Test functions for linalg module using the matrix class."""
  2. import numpy as np
  3. from numpy.linalg.tests.test_linalg import (
  4. LinalgCase, apply_tag, TestQR as _TestQR, LinalgTestCase,
  5. _TestNorm2D, _TestNormDoubleBase, _TestNormSingleBase, _TestNormInt64Base,
  6. SolveCases, InvCases, EigvalsCases, EigCases, SVDCases, CondCases,
  7. PinvCases, DetCases, LstsqCases)
  8. CASES = []
  9. # square test cases
  10. CASES += apply_tag('square', [
  11. LinalgCase("0x0_matrix",
  12. np.empty((0, 0), dtype=np.double).view(np.matrix),
  13. np.empty((0, 1), dtype=np.double).view(np.matrix),
  14. tags={'size-0'}),
  15. LinalgCase("matrix_b_only",
  16. np.array([[1., 2.], [3., 4.]]),
  17. np.matrix([2., 1.]).T),
  18. LinalgCase("matrix_a_and_b",
  19. np.matrix([[1., 2.], [3., 4.]]),
  20. np.matrix([2., 1.]).T),
  21. ])
  22. # hermitian test-cases
  23. CASES += apply_tag('hermitian', [
  24. LinalgCase("hmatrix_a_and_b",
  25. np.matrix([[1., 2.], [2., 1.]]),
  26. None),
  27. ])
  28. # No need to make generalized or strided cases for matrices.
  29. class MatrixTestCase(LinalgTestCase):
  30. TEST_CASES = CASES
  31. class TestSolveMatrix(SolveCases, MatrixTestCase):
  32. pass
  33. class TestInvMatrix(InvCases, MatrixTestCase):
  34. pass
  35. class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
  36. pass
  37. class TestEigMatrix(EigCases, MatrixTestCase):
  38. pass
  39. class TestSVDMatrix(SVDCases, MatrixTestCase):
  40. pass
  41. class TestCondMatrix(CondCases, MatrixTestCase):
  42. pass
  43. class TestPinvMatrix(PinvCases, MatrixTestCase):
  44. pass
  45. class TestDetMatrix(DetCases, MatrixTestCase):
  46. pass
  47. class TestLstsqMatrix(LstsqCases, MatrixTestCase):
  48. pass
  49. class _TestNorm2DMatrix(_TestNorm2D):
  50. array = np.matrix
  51. class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
  52. pass
  53. class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
  54. pass
  55. class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
  56. pass
  57. class TestQRMatrix(_TestQR):
  58. array = np.matrix