test_linear_assignment.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Author: Brian M. Clapper, G. Varoquaux, Lars Buitinck
  2. # License: BSD
  3. from numpy.testing import assert_array_equal
  4. import pytest
  5. import numpy as np
  6. from scipy.optimize import linear_sum_assignment
  7. from scipy.sparse import random
  8. from scipy.sparse._sputils import matrix
  9. from scipy.sparse.csgraph import min_weight_full_bipartite_matching
  10. from scipy.sparse.csgraph.tests.test_matching import (
  11. linear_sum_assignment_assertions, linear_sum_assignment_test_cases
  12. )
  13. def test_linear_sum_assignment_input_shape():
  14. with pytest.raises(ValueError, match="expected a matrix"):
  15. linear_sum_assignment([1, 2, 3])
  16. def test_linear_sum_assignment_input_object():
  17. C = [[1, 2, 3], [4, 5, 6]]
  18. assert_array_equal(linear_sum_assignment(C),
  19. linear_sum_assignment(np.asarray(C)))
  20. assert_array_equal(linear_sum_assignment(C),
  21. linear_sum_assignment(matrix(C)))
  22. def test_linear_sum_assignment_input_bool():
  23. I = np.identity(3)
  24. assert_array_equal(linear_sum_assignment(I.astype(np.bool_)),
  25. linear_sum_assignment(I))
  26. def test_linear_sum_assignment_input_string():
  27. I = np.identity(3)
  28. with pytest.raises(TypeError, match="Cannot cast array data"):
  29. linear_sum_assignment(I.astype(str))
  30. def test_linear_sum_assignment_input_nan():
  31. I = np.diag([np.nan, 1, 1])
  32. with pytest.raises(ValueError, match="contains invalid numeric entries"):
  33. linear_sum_assignment(I)
  34. def test_linear_sum_assignment_input_neginf():
  35. I = np.diag([1, -np.inf, 1])
  36. with pytest.raises(ValueError, match="contains invalid numeric entries"):
  37. linear_sum_assignment(I)
  38. def test_linear_sum_assignment_input_inf():
  39. I = np.identity(3)
  40. I[:, 0] = np.inf
  41. with pytest.raises(ValueError, match="cost matrix is infeasible"):
  42. linear_sum_assignment(I)
  43. def test_constant_cost_matrix():
  44. # Fixes #11602
  45. n = 8
  46. C = np.ones((n, n))
  47. row_ind, col_ind = linear_sum_assignment(C)
  48. assert_array_equal(row_ind, np.arange(n))
  49. assert_array_equal(col_ind, np.arange(n))
  50. @pytest.mark.parametrize('num_rows,num_cols', [(0, 0), (2, 0), (0, 3)])
  51. def test_linear_sum_assignment_trivial_cost(num_rows, num_cols):
  52. C = np.empty(shape=(num_cols, num_rows))
  53. row_ind, col_ind = linear_sum_assignment(C)
  54. assert len(row_ind) == 0
  55. assert len(col_ind) == 0
  56. @pytest.mark.parametrize('sign,test_case', linear_sum_assignment_test_cases)
  57. def test_linear_sum_assignment_small_inputs(sign, test_case):
  58. linear_sum_assignment_assertions(
  59. linear_sum_assignment, np.array, sign, test_case)
  60. # Tests that combine scipy.optimize.linear_sum_assignment and
  61. # scipy.sparse.csgraph.min_weight_full_bipartite_matching
  62. def test_two_methods_give_same_result_on_many_sparse_inputs():
  63. # As opposed to the test above, here we do not spell out the expected
  64. # output; only assert that the two methods give the same result.
  65. # Concretely, the below tests 100 cases of size 100x100, out of which
  66. # 36 are infeasible.
  67. np.random.seed(1234)
  68. for _ in range(100):
  69. lsa_raises = False
  70. mwfbm_raises = False
  71. sparse = random(100, 100, density=0.06,
  72. data_rvs=lambda size: np.random.randint(1, 100, size))
  73. # In csgraph, zeros correspond to missing edges, so we explicitly
  74. # replace those with infinities
  75. dense = np.full(sparse.shape, np.inf)
  76. dense[sparse.row, sparse.col] = sparse.data
  77. sparse = sparse.tocsr()
  78. try:
  79. row_ind, col_ind = linear_sum_assignment(dense)
  80. lsa_cost = dense[row_ind, col_ind].sum()
  81. except ValueError:
  82. lsa_raises = True
  83. try:
  84. row_ind, col_ind = min_weight_full_bipartite_matching(sparse)
  85. mwfbm_cost = sparse[row_ind, col_ind].sum()
  86. except ValueError:
  87. mwfbm_raises = True
  88. # Ensure that if one method raises, so does the other one.
  89. assert lsa_raises == mwfbm_raises
  90. if not lsa_raises:
  91. assert lsa_cost == mwfbm_cost