# test_sample.py (5.0 KB)
  1. import pytest
  2. from pandas import (
  3. DataFrame,
  4. Index,
  5. Series,
  6. )
  7. import pandas._testing as tm
  8. @pytest.mark.parametrize("n, frac", [(2, None), (None, 0.2)])
  9. def test_groupby_sample_balanced_groups_shape(n, frac):
  10. values = [1] * 10 + [2] * 10
  11. df = DataFrame({"a": values, "b": values})
  12. result = df.groupby("a").sample(n=n, frac=frac)
  13. values = [1] * 2 + [2] * 2
  14. expected = DataFrame({"a": values, "b": values}, index=result.index)
  15. tm.assert_frame_equal(result, expected)
  16. result = df.groupby("a")["b"].sample(n=n, frac=frac)
  17. expected = Series(values, name="b", index=result.index)
  18. tm.assert_series_equal(result, expected)
  19. def test_groupby_sample_unbalanced_groups_shape():
  20. values = [1] * 10 + [2] * 20
  21. df = DataFrame({"a": values, "b": values})
  22. result = df.groupby("a").sample(n=5)
  23. values = [1] * 5 + [2] * 5
  24. expected = DataFrame({"a": values, "b": values}, index=result.index)
  25. tm.assert_frame_equal(result, expected)
  26. result = df.groupby("a")["b"].sample(n=5)
  27. expected = Series(values, name="b", index=result.index)
  28. tm.assert_series_equal(result, expected)
  29. def test_groupby_sample_index_value_spans_groups():
  30. values = [1] * 3 + [2] * 3
  31. df = DataFrame({"a": values, "b": values}, index=[1, 2, 2, 2, 2, 2])
  32. result = df.groupby("a").sample(n=2)
  33. values = [1] * 2 + [2] * 2
  34. expected = DataFrame({"a": values, "b": values}, index=result.index)
  35. tm.assert_frame_equal(result, expected)
  36. result = df.groupby("a")["b"].sample(n=2)
  37. expected = Series(values, name="b", index=result.index)
  38. tm.assert_series_equal(result, expected)
  39. def test_groupby_sample_n_and_frac_raises():
  40. df = DataFrame({"a": [1, 2], "b": [1, 2]})
  41. msg = "Please enter a value for `frac` OR `n`, not both"
  42. with pytest.raises(ValueError, match=msg):
  43. df.groupby("a").sample(n=1, frac=1.0)
  44. with pytest.raises(ValueError, match=msg):
  45. df.groupby("a")["b"].sample(n=1, frac=1.0)
  46. def test_groupby_sample_frac_gt_one_without_replacement_raises():
  47. df = DataFrame({"a": [1, 2], "b": [1, 2]})
  48. msg = "Replace has to be set to `True` when upsampling the population `frac` > 1."
  49. with pytest.raises(ValueError, match=msg):
  50. df.groupby("a").sample(frac=1.5, replace=False)
  51. with pytest.raises(ValueError, match=msg):
  52. df.groupby("a")["b"].sample(frac=1.5, replace=False)
  53. @pytest.mark.parametrize("n", [-1, 1.5])
  54. def test_groupby_sample_invalid_n_raises(n):
  55. df = DataFrame({"a": [1, 2], "b": [1, 2]})
  56. if n < 0:
  57. msg = "A negative number of rows requested. Please provide `n` >= 0."
  58. else:
  59. msg = "Only integers accepted as `n` values"
  60. with pytest.raises(ValueError, match=msg):
  61. df.groupby("a").sample(n=n)
  62. with pytest.raises(ValueError, match=msg):
  63. df.groupby("a")["b"].sample(n=n)
  64. def test_groupby_sample_oversample():
  65. values = [1] * 10 + [2] * 10
  66. df = DataFrame({"a": values, "b": values})
  67. result = df.groupby("a").sample(frac=2.0, replace=True)
  68. values = [1] * 20 + [2] * 20
  69. expected = DataFrame({"a": values, "b": values}, index=result.index)
  70. tm.assert_frame_equal(result, expected)
  71. result = df.groupby("a")["b"].sample(frac=2.0, replace=True)
  72. expected = Series(values, name="b", index=result.index)
  73. tm.assert_series_equal(result, expected)
  74. def test_groupby_sample_without_n_or_frac():
  75. values = [1] * 10 + [2] * 10
  76. df = DataFrame({"a": values, "b": values})
  77. result = df.groupby("a").sample(n=None, frac=None)
  78. expected = DataFrame({"a": [1, 2], "b": [1, 2]}, index=result.index)
  79. tm.assert_frame_equal(result, expected)
  80. result = df.groupby("a")["b"].sample(n=None, frac=None)
  81. expected = Series([1, 2], name="b", index=result.index)
  82. tm.assert_series_equal(result, expected)
  83. @pytest.mark.parametrize(
  84. "index, expected_index",
  85. [(["w", "x", "y", "z"], ["w", "w", "y", "y"]), ([3, 4, 5, 6], [3, 3, 5, 5])],
  86. )
  87. def test_groupby_sample_with_weights(index, expected_index):
  88. # GH 39927 - tests for integer index needed
  89. values = [1] * 2 + [2] * 2
  90. df = DataFrame({"a": values, "b": values}, index=Index(index))
  91. result = df.groupby("a").sample(n=2, replace=True, weights=[1, 0, 1, 0])
  92. expected = DataFrame({"a": values, "b": values}, index=Index(expected_index))
  93. tm.assert_frame_equal(result, expected)
  94. result = df.groupby("a")["b"].sample(n=2, replace=True, weights=[1, 0, 1, 0])
  95. expected = Series(values, name="b", index=Index(expected_index))
  96. tm.assert_series_equal(result, expected)
  97. def test_groupby_sample_with_selections():
  98. # GH 39928
  99. values = [1] * 10 + [2] * 10
  100. df = DataFrame({"a": values, "b": values, "c": values})
  101. result = df.groupby("a")[["b", "c"]].sample(n=None, frac=None)
  102. expected = DataFrame({"b": [1, 2], "c": [1, 2]}, index=result.index)
  103. tm.assert_frame_equal(result, expected)
  104. def test_groupby_sample_with_empty_inputs():
  105. # GH48459
  106. df = DataFrame({"a": [], "b": []})
  107. groupby_df = df.groupby("a")
  108. result = groupby_df.sample()
  109. expected = df
  110. tm.assert_frame_equal(result, expected)