test_groupby_subclass.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Index,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. from pandas.tests.groupby import get_groupby_method_args
  11. @pytest.mark.parametrize(
  12. "obj",
  13. [
  14. tm.SubclassedDataFrame({"A": np.arange(0, 10)}),
  15. tm.SubclassedSeries(np.arange(0, 10), name="A"),
  16. ],
  17. )
  18. def test_groupby_preserves_subclass(obj, groupby_func):
  19. # GH28330 -- preserve subclass through groupby operations
  20. if isinstance(obj, Series) and groupby_func in {"corrwith"}:
  21. pytest.skip(f"Not applicable for Series and {groupby_func}")
  22. grouped = obj.groupby(np.arange(0, 10))
  23. # Groups should preserve subclass type
  24. assert isinstance(grouped.get_group(0), type(obj))
  25. args = get_groupby_method_args(groupby_func, obj)
  26. result1 = getattr(grouped, groupby_func)(*args)
  27. result2 = grouped.agg(groupby_func, *args)
  28. # Reduction or transformation kernels should preserve type
  29. slices = {"ngroup", "cumcount", "size"}
  30. if isinstance(obj, DataFrame) and groupby_func in slices:
  31. assert isinstance(result1, tm.SubclassedSeries)
  32. else:
  33. assert isinstance(result1, type(obj))
  34. # Confirm .agg() groupby operations return same results
  35. if isinstance(result1, DataFrame):
  36. tm.assert_frame_equal(result1, result2)
  37. else:
  38. tm.assert_series_equal(result1, result2)
  39. def test_groupby_preserves_metadata():
  40. # GH-37343
  41. custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]})
  42. assert "testattr" in custom_df._metadata
  43. custom_df.testattr = "hello"
  44. for _, group_df in custom_df.groupby("c"):
  45. assert group_df.testattr == "hello"
  46. # GH-45314
  47. def func(group):
  48. assert isinstance(group, tm.SubclassedDataFrame)
  49. assert hasattr(group, "testattr")
  50. return group.testattr
  51. result = custom_df.groupby("c").apply(func)
  52. expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c"))
  53. tm.assert_series_equal(result, expected)
  54. def func2(group):
  55. assert isinstance(group, tm.SubclassedSeries)
  56. assert hasattr(group, "testattr")
  57. return group.testattr
  58. custom_series = tm.SubclassedSeries([1, 2, 3])
  59. custom_series.testattr = "hello"
  60. result = custom_series.groupby(custom_df["c"]).apply(func2)
  61. tm.assert_series_equal(result, expected)
  62. result = custom_series.groupby(custom_df["c"]).agg(func2)
  63. tm.assert_series_equal(result, expected)
  64. @pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame])
  65. def test_groupby_resample_preserves_subclass(obj):
  66. # GH28330 -- preserve subclass through groupby.resample()
  67. df = obj(
  68. {
  69. "Buyer": "Carl Carl Carl Carl Joe Carl".split(),
  70. "Quantity": [18, 3, 5, 1, 9, 3],
  71. "Date": [
  72. datetime(2013, 9, 1, 13, 0),
  73. datetime(2013, 9, 1, 13, 5),
  74. datetime(2013, 10, 1, 20, 0),
  75. datetime(2013, 10, 3, 10, 0),
  76. datetime(2013, 12, 2, 12, 0),
  77. datetime(2013, 9, 2, 14, 0),
  78. ],
  79. }
  80. )
  81. df = df.set_index("Date")
  82. # Confirm groupby.resample() preserves dataframe type
  83. result = df.groupby("Buyer").resample("5D").sum()
  84. assert isinstance(result, obj)