test_arithmetic.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. from __future__ import annotations
  2. from typing import Any
  3. import numpy as np
  4. import pytest
  5. import pandas as pd
  6. import pandas._testing as tm
  7. # integer dtypes
  8. arrays = [pd.array([1, 2, 3, None], dtype=dtype) for dtype in tm.ALL_INT_EA_DTYPES]
  9. scalars: list[Any] = [2] * len(arrays)
  10. # floating dtypes
  11. arrays += [pd.array([0.1, 0.2, 0.3, None], dtype=dtype) for dtype in tm.FLOAT_EA_DTYPES]
  12. scalars += [0.2, 0.2]
  13. # boolean
  14. arrays += [pd.array([True, False, True, None], dtype="boolean")]
  15. scalars += [False]
  16. @pytest.fixture(params=zip(arrays, scalars), ids=[a.dtype.name for a in arrays])
  17. def data(request):
  18. """Fixture returning parametrized (array, scalar) tuple.
  19. Used to test equivalence of scalars, numpy arrays with array ops, and the
  20. equivalence of DataFrame and Series ops.
  21. """
  22. return request.param
  23. def check_skip(data, op_name):
  24. if isinstance(data.dtype, pd.BooleanDtype) and "sub" in op_name:
  25. pytest.skip("subtract not implemented for boolean")
  26. def is_bool_not_implemented(data, op_name):
  27. # match non-masked behavior
  28. return data.dtype.kind == "b" and op_name.strip("_").lstrip("r") in [
  29. "pow",
  30. "truediv",
  31. "floordiv",
  32. ]
  33. # Test equivalence of scalars, numpy arrays with array ops
  34. # -----------------------------------------------------------------------------
  35. def test_array_scalar_like_equivalence(data, all_arithmetic_operators):
  36. data, scalar = data
  37. op = tm.get_op_from_name(all_arithmetic_operators)
  38. check_skip(data, all_arithmetic_operators)
  39. scalar_array = pd.array([scalar] * len(data), dtype=data.dtype)
  40. # TODO also add len-1 array (np.array([scalar], dtype=data.dtype.numpy_dtype))
  41. for scalar in [scalar, data.dtype.type(scalar)]:
  42. if is_bool_not_implemented(data, all_arithmetic_operators):
  43. msg = "operator '.*' not implemented for bool dtypes"
  44. with pytest.raises(NotImplementedError, match=msg):
  45. op(data, scalar)
  46. with pytest.raises(NotImplementedError, match=msg):
  47. op(data, scalar_array)
  48. else:
  49. result = op(data, scalar)
  50. expected = op(data, scalar_array)
  51. tm.assert_extension_array_equal(result, expected)
  52. def test_array_NA(data, all_arithmetic_operators):
  53. data, _ = data
  54. op = tm.get_op_from_name(all_arithmetic_operators)
  55. check_skip(data, all_arithmetic_operators)
  56. scalar = pd.NA
  57. scalar_array = pd.array([pd.NA] * len(data), dtype=data.dtype)
  58. mask = data._mask.copy()
  59. if is_bool_not_implemented(data, all_arithmetic_operators):
  60. msg = "operator '.*' not implemented for bool dtypes"
  61. with pytest.raises(NotImplementedError, match=msg):
  62. op(data, scalar)
  63. # GH#45421 check op doesn't alter data._mask inplace
  64. tm.assert_numpy_array_equal(mask, data._mask)
  65. return
  66. result = op(data, scalar)
  67. # GH#45421 check op doesn't alter data._mask inplace
  68. tm.assert_numpy_array_equal(mask, data._mask)
  69. expected = op(data, scalar_array)
  70. tm.assert_numpy_array_equal(mask, data._mask)
  71. tm.assert_extension_array_equal(result, expected)
  72. def test_numpy_array_equivalence(data, all_arithmetic_operators):
  73. data, scalar = data
  74. op = tm.get_op_from_name(all_arithmetic_operators)
  75. check_skip(data, all_arithmetic_operators)
  76. numpy_array = np.array([scalar] * len(data), dtype=data.dtype.numpy_dtype)
  77. pd_array = pd.array(numpy_array, dtype=data.dtype)
  78. if is_bool_not_implemented(data, all_arithmetic_operators):
  79. msg = "operator '.*' not implemented for bool dtypes"
  80. with pytest.raises(NotImplementedError, match=msg):
  81. op(data, numpy_array)
  82. with pytest.raises(NotImplementedError, match=msg):
  83. op(data, pd_array)
  84. return
  85. result = op(data, numpy_array)
  86. expected = op(data, pd_array)
  87. tm.assert_extension_array_equal(result, expected)
  88. # Test equivalence with Series and DataFrame ops
  89. # -----------------------------------------------------------------------------
  90. def test_frame(data, all_arithmetic_operators):
  91. data, scalar = data
  92. op = tm.get_op_from_name(all_arithmetic_operators)
  93. check_skip(data, all_arithmetic_operators)
  94. # DataFrame with scalar
  95. df = pd.DataFrame({"A": data})
  96. if is_bool_not_implemented(data, all_arithmetic_operators):
  97. msg = "operator '.*' not implemented for bool dtypes"
  98. with pytest.raises(NotImplementedError, match=msg):
  99. op(df, scalar)
  100. with pytest.raises(NotImplementedError, match=msg):
  101. op(data, scalar)
  102. return
  103. result = op(df, scalar)
  104. expected = pd.DataFrame({"A": op(data, scalar)})
  105. tm.assert_frame_equal(result, expected)
  106. def test_series(data, all_arithmetic_operators):
  107. data, scalar = data
  108. op = tm.get_op_from_name(all_arithmetic_operators)
  109. check_skip(data, all_arithmetic_operators)
  110. ser = pd.Series(data)
  111. others = [
  112. scalar,
  113. np.array([scalar] * len(data), dtype=data.dtype.numpy_dtype),
  114. pd.array([scalar] * len(data), dtype=data.dtype),
  115. pd.Series([scalar] * len(data), dtype=data.dtype),
  116. ]
  117. for other in others:
  118. if is_bool_not_implemented(data, all_arithmetic_operators):
  119. msg = "operator '.*' not implemented for bool dtypes"
  120. with pytest.raises(NotImplementedError, match=msg):
  121. op(ser, other)
  122. else:
  123. result = op(ser, other)
  124. expected = pd.Series(op(data, other))
  125. tm.assert_series_equal(result, expected)
  126. # Test generic characteristics / errors
  127. # -----------------------------------------------------------------------------
  128. def test_error_invalid_object(data, all_arithmetic_operators):
  129. data, _ = data
  130. op = all_arithmetic_operators
  131. opa = getattr(data, op)
  132. # 2d -> return NotImplemented
  133. result = opa(pd.DataFrame({"A": data}))
  134. assert result is NotImplemented
  135. msg = r"can only perform ops with 1-d structures"
  136. with pytest.raises(NotImplementedError, match=msg):
  137. opa(np.arange(len(data)).reshape(-1, len(data)))
  138. def test_error_len_mismatch(data, all_arithmetic_operators):
  139. # operating with a list-like with non-matching length raises
  140. data, scalar = data
  141. op = tm.get_op_from_name(all_arithmetic_operators)
  142. other = [scalar] * (len(data) - 1)
  143. err = ValueError
  144. msg = "|".join(
  145. [
  146. r"operands could not be broadcast together with shapes \(3,\) \(4,\)",
  147. r"operands could not be broadcast together with shapes \(4,\) \(3,\)",
  148. ]
  149. )
  150. if data.dtype.kind == "b" and all_arithmetic_operators.strip("_") in [
  151. "sub",
  152. "rsub",
  153. ]:
  154. err = TypeError
  155. msg = (
  156. r"numpy boolean subtract, the `\-` operator, is not supported, use "
  157. r"the bitwise_xor, the `\^` operator, or the logical_xor function instead"
  158. )
  159. elif is_bool_not_implemented(data, all_arithmetic_operators):
  160. msg = "operator '.*' not implemented for bool dtypes"
  161. err = NotImplementedError
  162. for other in [other, np.array(other)]:
  163. with pytest.raises(err, match=msg):
  164. op(data, other)
  165. s = pd.Series(data)
  166. with pytest.raises(err, match=msg):
  167. op(s, other)
  168. @pytest.mark.parametrize("op", ["__neg__", "__abs__", "__invert__"])
  169. def test_unary_op_does_not_propagate_mask(data, op):
  170. # https://github.com/pandas-dev/pandas/issues/39943
  171. data, _ = data
  172. ser = pd.Series(data)
  173. if op == "__invert__" and data.dtype.kind == "f":
  174. # we follow numpy in raising
  175. msg = "ufunc 'invert' not supported for the input types"
  176. with pytest.raises(TypeError, match=msg):
  177. getattr(ser, op)()
  178. with pytest.raises(TypeError, match=msg):
  179. getattr(data, op)()
  180. with pytest.raises(TypeError, match=msg):
  181. # Check that this is still the numpy behavior
  182. getattr(data._data, op)()
  183. return
  184. result = getattr(ser, op)()
  185. expected = result.copy(deep=True)
  186. ser[0] = None
  187. tm.assert_series_equal(result, expected)