ops.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from __future__ import annotations
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. import pandas._testing as tm
  6. from pandas.core import ops
  7. from pandas.tests.extension.base.base import BaseExtensionTests
  class BaseOpsUtil(BaseExtensionTests):
      """
      Shared helpers for checking operators on extension-array-backed objects.

      An op's result is compared against a pointwise reference built with
      ``Series.combine``. When an exception type is supplied instead of
      ``None``, the op is expected to raise that exception.
      """

      def get_op_from_name(self, op_name: str):
          """Return the operator callable named by ``op_name`` (via ``tm``)."""
          return tm.get_op_from_name(op_name)

      def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
          """
          Resolve ``op_name`` to an operator and check ``op(ser, other)``.

          ``exc=None`` means the op must succeed and match the pointwise
          result; otherwise ``op(ser, other)`` must raise ``exc``.
          """
          op = self.get_op_from_name(op_name)
          self._check_op(ser, op, other, op_name, exc)

      def _combine(self, obj, other, op):
          """
          Build the expected result of ``op(obj, other)`` element by element.

          A DataFrame is only supported when it has exactly one column: that
          column is combined and the result converted back to a frame.

          Raises
          ------
          NotImplementedError
              If ``obj`` is a DataFrame with more than one column.
          """
          if isinstance(obj, pd.DataFrame):
              if len(obj.columns) != 1:
                  raise NotImplementedError
              return obj.iloc[:, 0].combine(other, op).to_frame()
          return obj.combine(other, op)

      def _check_op(
          self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
      ):
          """
          Check ``op(ser, other)`` against the pointwise reference.

          With ``exc=None`` the result must have the same type as ``ser`` and
          equal ``_combine(ser, other, op)``; otherwise the op must raise
          ``exc``.
          """
          if exc is None:
              result = op(ser, other)
              expected = self._combine(ser, other, op)
              assert isinstance(result, type(ser))
              self.assert_equal(result, expected)
          else:
              with pytest.raises(exc):
                  op(ser, other)

      def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
          """
          Check ``op(ser, other)`` where ``op`` is ``divmod`` or ``ops.rdivmod``.

          divmod has multiple return values, so the quotient and remainder are
          compared separately against the equivalent ``//`` and ``%`` results.

          Parameters
          ----------
          ser
              Left operand. Despite the annotation, callers in this module
              also pass non-Series values here (the scalar ``1``, or the
              ``data_for_twos`` fixture) when ``op`` is ``ops.rdivmod``.
          op
              ``divmod`` or ``ops.rdivmod``.
          other
              Right operand.
          exc
              ``None`` if the op must succeed; otherwise the exception type
              that ``op(ser, other)`` is expected to raise.
          """
          if exc is not None:
              # Exercise the op actually under test. Calling ``divmod``
              # unconditionally would check the reflected operation whenever
              # ``op`` is ``ops.rdivmod``.
              with pytest.raises(exc):
                  op(ser, other)
              return

          result_div, result_mod = op(ser, other)
          if op is divmod:
              expected_div, expected_mod = ser // other, ser % other
          else:
              # ops.rdivmod swaps its operands, so the reference is reversed.
              expected_div, expected_mod = other // ser, other % ser
          self.assert_series_equal(result_div, expected_div)
          self.assert_series_equal(result_mod, expected_mod)
  class BaseArithmeticOpsTests(BaseOpsUtil):
      """
      Various Series and DataFrame arithmetic ops methods.

      Subclasses declare which kinds of ops their array supports by
      overriding the class variables below. Each holds the exception type
      the op is expected to raise, or ``None`` when the op is supported and
      its result must match the pointwise reference. Defaults:

      * series_scalar_exc = TypeError
      * frame_scalar_exc = TypeError
      * series_array_exc = TypeError
      * divmod_exc = TypeError
      """

      series_scalar_exc: type[Exception] | None = TypeError
      frame_scalar_exc: type[Exception] | None = TypeError
      series_array_exc: type[Exception] | None = TypeError
      divmod_exc: type[Exception] | None = TypeError

      def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
          # series & scalar
          op_name = all_arithmetic_operators
          ser = pd.Series(data)
          self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc)

      def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
          # frame & scalar
          op_name = all_arithmetic_operators
          df = pd.DataFrame({"A": data})
          self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)

      def test_arith_series_with_array(self, data, all_arithmetic_operators):
          # Series & another Series of equal length, every element data[0]
          op_name = all_arithmetic_operators
          ser = pd.Series(data)
          other = pd.Series([ser.iloc[0]] * len(ser))
          self.check_opname(ser, op_name, other, exc=self.series_array_exc)

      def test_divmod(self, data):
          ser = pd.Series(data)
          self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc)
          self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc)

      def test_divmod_series_array(self, data, data_for_twos):
          # Pass ``divmod_exc`` explicitly, as ``test_divmod`` does. The
          # helper's default (``exc=Exception``) would expect failure even for
          # subclasses that declare divmod support via ``divmod_exc = None``.
          ser = pd.Series(data)
          self._check_divmod_op(ser, divmod, data, exc=self.divmod_exc)

          other = data_for_twos
          self._check_divmod_op(other, ops.rdivmod, ser, exc=self.divmod_exc)

          other = pd.Series(other)
          self._check_divmod_op(other, ops.rdivmod, ser, exc=self.divmod_exc)

      def test_add_series_with_extension_array(self, data):
          # Series + array must match array + array wrapped in a Series.
          ser = pd.Series(data)
          result = ser + data
          expected = pd.Series(data + data)
          self.assert_series_equal(result, expected)

      @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
      def test_direct_arith_with_ndframe_returns_not_implemented(
          self, request, data, box
      ):
          # EAs should return NotImplemented for ops with Series/DataFrame
          # Pandas takes care of unboxing the series and calling the EA's op.
          other = pd.Series(data)
          if box is pd.DataFrame:
              other = other.to_frame()

          # Unlike __eq__, __add__ is not inherited from ``object``, so an
          # array type may genuinely lack it; mark that case as expected-fail.
          if not hasattr(data, "__add__"):
              request.node.add_marker(
                  pytest.mark.xfail(
                      reason=f"{type(data).__name__} does not implement add"
                  )
              )
          result = data.__add__(other)
          assert result is NotImplemented
  class BaseComparisonOpsTests(BaseOpsUtil):
      """Various Series and DataFrame comparison ops methods."""

      def _compare_other(self, ser: pd.Series, data, op, other):
          """
          Check ``op(ser, other)`` against the pointwise ``Series.combine`` result.

          ``eq``/``ne`` must always succeed and match the pointwise result.
          Any other comparison may raise, but then the pointwise reference must
          raise the same exception type. ``data`` is not used here; presumably
          it is kept for subclasses that override this method.
          """
          if op.__name__ in ["eq", "ne"]:
              # comparison should match point-wise comparisons
              result = op(ser, other)
              expected = ser.combine(other, op)
              self.assert_series_equal(result, expected)
              return

          exc = None
          try:
              result = op(ser, other)
          except Exception as err:
              exc = err

          if exc is None:
              # Didn't error, then should match pointwise behavior
              expected = ser.combine(other, op)
              self.assert_series_equal(result, expected)
          else:
              # Errored, so the pointwise reference must error the same way.
              with pytest.raises(type(exc)):
                  ser.combine(other, op)

      def test_compare_scalar(self, data, comparison_op):
          ser = pd.Series(data)
          self._compare_other(ser, data, comparison_op, 0)

      def test_compare_array(self, data, comparison_op):
          ser = pd.Series(data)
          other = pd.Series([data[0]] * len(data))
          self._compare_other(ser, data, comparison_op, other)

      @pytest.mark.parametrize("box", [pd.Series, pd.DataFrame])
      def test_direct_arith_with_ndframe_returns_not_implemented(self, data, box):
          # EAs should return NotImplemented for ops with Series/DataFrame
          # Pandas takes care of unboxing the series and calling the EA's op.
          other = pd.Series(data)
          if box is pd.DataFrame:
              other = other.to_frame()

          # Every Python object has __eq__ and __ne__ (inherited from
          # ``object``), so a ``hasattr`` guard could never be False and its
          # skip branch was unreachable; call both dunders directly.
          for dunder in ("__eq__", "__ne__"):
              result = getattr(data, dunder)(other)
              assert result is NotImplemented, dunder
  class BaseUnaryOpsTests(BaseOpsUtil):
      """Unary operators (``~``, ``+``, ``-``, ``abs``) on extension arrays."""

      def test_invert(self, data):
          # Inverting the Series must equal inverting the array, name kept.
          label = "name"
          inverted = ~pd.Series(data, name=label)
          self.assert_series_equal(inverted, pd.Series(~data, name=label))

      @pytest.mark.parametrize("ufunc", [np.positive, np.negative, np.abs])
      def test_unary_ufunc_dunder_equivalence(self, data, ufunc):
          # Each dunder succeeds exactly when its numpy ufunc counterpart does:
          # __pos__ <-> np.positive, __neg__ <-> np.negative, __abs__ <-> np.abs.
          dunder_for = {
              np.positive: "__pos__",
              np.negative: "__neg__",
              np.abs: "__abs__",
          }
          method_name = dunder_for[ufunc]

          try:
              via_dunder = getattr(data, method_name)()
          except Exception as failure:
              # The dunder raised, so the ufunc must raise as well: either the
              # same exception type or a TypeError.
              with pytest.raises((type(failure), TypeError)):
                  ufunc(data)
          else:
              via_ufunc = ufunc(data)
              self.assert_extension_array_equal(via_dunder, via_ufunc)