# test_crosstab.py -- tests for scipy.stats.contingency.crosstab

  1. import pytest
  2. import numpy as np
  3. from numpy.testing import assert_array_equal, assert_equal
  4. from scipy.stats.contingency import crosstab
  5. @pytest.mark.parametrize('sparse', [False, True])
  6. def test_crosstab_basic(sparse):
  7. a = [0, 0, 9, 9, 0, 0, 9]
  8. b = [2, 1, 3, 1, 2, 3, 3]
  9. expected_avals = [0, 9]
  10. expected_bvals = [1, 2, 3]
  11. expected_count = np.array([[1, 2, 1],
  12. [1, 0, 2]])
  13. (avals, bvals), count = crosstab(a, b, sparse=sparse)
  14. assert_array_equal(avals, expected_avals)
  15. assert_array_equal(bvals, expected_bvals)
  16. if sparse:
  17. assert_array_equal(count.A, expected_count)
  18. else:
  19. assert_array_equal(count, expected_count)
  20. def test_crosstab_basic_1d():
  21. # Verify that a single input sequence works as expected.
  22. x = [1, 2, 3, 1, 2, 3, 3]
  23. expected_xvals = [1, 2, 3]
  24. expected_count = np.array([2, 2, 3])
  25. (xvals,), count = crosstab(x)
  26. assert_array_equal(xvals, expected_xvals)
  27. assert_array_equal(count, expected_count)
  28. def test_crosstab_basic_3d():
  29. # Verify the function for three input sequences.
  30. a = 'a'
  31. b = 'b'
  32. x = [0, 0, 9, 9, 0, 0, 9, 9]
  33. y = [a, a, a, a, b, b, b, a]
  34. z = [1, 2, 3, 1, 2, 3, 3, 1]
  35. expected_xvals = [0, 9]
  36. expected_yvals = [a, b]
  37. expected_zvals = [1, 2, 3]
  38. expected_count = np.array([[[1, 1, 0],
  39. [0, 1, 1]],
  40. [[2, 0, 1],
  41. [0, 0, 1]]])
  42. (xvals, yvals, zvals), count = crosstab(x, y, z)
  43. assert_array_equal(xvals, expected_xvals)
  44. assert_array_equal(yvals, expected_yvals)
  45. assert_array_equal(zvals, expected_zvals)
  46. assert_array_equal(count, expected_count)
  47. @pytest.mark.parametrize('sparse', [False, True])
  48. def test_crosstab_levels(sparse):
  49. a = [0, 0, 9, 9, 0, 0, 9]
  50. b = [1, 2, 3, 1, 2, 3, 3]
  51. expected_avals = [0, 9]
  52. expected_bvals = [0, 1, 2, 3]
  53. expected_count = np.array([[0, 1, 2, 1],
  54. [0, 1, 0, 2]])
  55. (avals, bvals), count = crosstab(a, b, levels=[None, [0, 1, 2, 3]],
  56. sparse=sparse)
  57. assert_array_equal(avals, expected_avals)
  58. assert_array_equal(bvals, expected_bvals)
  59. if sparse:
  60. assert_array_equal(count.A, expected_count)
  61. else:
  62. assert_array_equal(count, expected_count)
  63. @pytest.mark.parametrize('sparse', [False, True])
  64. def test_crosstab_extra_levels(sparse):
  65. # The pair of values (-1, 3) will be ignored, because we explicitly
  66. # request the counted `a` values to be [0, 9].
  67. a = [0, 0, 9, 9, 0, 0, 9, -1]
  68. b = [1, 2, 3, 1, 2, 3, 3, 3]
  69. expected_avals = [0, 9]
  70. expected_bvals = [0, 1, 2, 3]
  71. expected_count = np.array([[0, 1, 2, 1],
  72. [0, 1, 0, 2]])
  73. (avals, bvals), count = crosstab(a, b, levels=[[0, 9], [0, 1, 2, 3]],
  74. sparse=sparse)
  75. assert_array_equal(avals, expected_avals)
  76. assert_array_equal(bvals, expected_bvals)
  77. if sparse:
  78. assert_array_equal(count.A, expected_count)
  79. else:
  80. assert_array_equal(count, expected_count)
  81. def test_validation_at_least_one():
  82. with pytest.raises(TypeError, match='At least one'):
  83. crosstab()
  84. def test_validation_same_lengths():
  85. with pytest.raises(ValueError, match='must have the same length'):
  86. crosstab([1, 2], [1, 2, 3, 4])
  87. def test_validation_sparse_only_two_args():
  88. with pytest.raises(ValueError, match='only two input sequences'):
  89. crosstab([0, 1, 1], [8, 8, 9], [1, 3, 3], sparse=True)
  90. def test_validation_len_levels_matches_args():
  91. with pytest.raises(ValueError, match='number of input sequences'):
  92. crosstab([0, 1, 1], [8, 8, 9], levels=([0, 1, 2, 3],))
  93. def test_result():
  94. res = crosstab([0, 1], [1, 2])
  95. assert_equal((res.elements, res.count), res)