test_online.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import numpy as np
  2. import pytest
  3. from pandas.compat import (
  4. is_ci_environment,
  5. is_platform_mac,
  6. is_platform_windows,
  7. )
  8. import pandas.util._test_decorators as td
  9. from pandas import (
  10. DataFrame,
  11. Series,
  12. )
  13. import pandas._testing as tm
  14. # TODO(GH#44584): Mark these as pytest.mark.single_cpu
  15. pytestmark = pytest.mark.skipif(
  16. is_ci_environment() and (is_platform_windows() or is_platform_mac()),
  17. reason="On GHA CI, Windows can fail with "
  18. "'Windows fatal exception: stack overflow' "
  19. "and macOS can timeout",
  20. )
  21. @td.skip_if_no("numba")
  22. @pytest.mark.filterwarnings("ignore")
  23. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  24. class TestEWM:
  25. def test_invalid_update(self):
  26. df = DataFrame({"a": range(5), "b": range(5)})
  27. online_ewm = df.head(2).ewm(0.5).online()
  28. with pytest.raises(
  29. ValueError,
  30. match="Must call mean with update=None first before passing update",
  31. ):
  32. online_ewm.mean(update=df.head(1))
  33. @pytest.mark.slow
  34. @pytest.mark.parametrize(
  35. "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
  36. )
  37. def test_online_vs_non_online_mean(
  38. self, obj, nogil, parallel, nopython, adjust, ignore_na
  39. ):
  40. expected = obj.ewm(0.5, adjust=adjust, ignore_na=ignore_na).mean()
  41. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  42. online_ewm = (
  43. obj.head(2)
  44. .ewm(0.5, adjust=adjust, ignore_na=ignore_na)
  45. .online(engine_kwargs=engine_kwargs)
  46. )
  47. # Test resetting once
  48. for _ in range(2):
  49. result = online_ewm.mean()
  50. tm.assert_equal(result, expected.head(2))
  51. result = online_ewm.mean(update=obj.tail(3))
  52. tm.assert_equal(result, expected.tail(3))
  53. online_ewm.reset()
  54. @pytest.mark.xfail(raises=NotImplementedError)
  55. @pytest.mark.parametrize(
  56. "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")]
  57. )
  58. def test_update_times_mean(
  59. self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times
  60. ):
  61. times = Series(
  62. np.array(
  63. ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"],
  64. dtype="datetime64[ns]",
  65. )
  66. )
  67. expected = obj.ewm(
  68. 0.5,
  69. adjust=adjust,
  70. ignore_na=ignore_na,
  71. times=times,
  72. halflife=halflife_with_times,
  73. ).mean()
  74. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  75. online_ewm = (
  76. obj.head(2)
  77. .ewm(
  78. 0.5,
  79. adjust=adjust,
  80. ignore_na=ignore_na,
  81. times=times.head(2),
  82. halflife=halflife_with_times,
  83. )
  84. .online(engine_kwargs=engine_kwargs)
  85. )
  86. # Test resetting once
  87. for _ in range(2):
  88. result = online_ewm.mean()
  89. tm.assert_equal(result, expected.head(2))
  90. result = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3))
  91. tm.assert_equal(result, expected.tail(3))
  92. online_ewm.reset()
  93. @pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"])
  94. def test_ewm_notimplementederror_raises(self, method):
  95. ser = Series(range(10))
  96. kwargs = {}
  97. if method == "aggregate":
  98. kwargs["func"] = lambda x: x
  99. with pytest.raises(NotImplementedError, match=".* is not implemented."):
  100. getattr(ser.ewm(1).online(), method)(**kwargs)