test_arithmetic.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import operator
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. import pandas._testing as tm
  6. @pytest.fixture
  7. def data():
  8. """Fixture returning boolean array with valid and missing values."""
  9. return pd.array(
  10. [True, False] * 4 + [np.nan] + [True, False] * 44 + [np.nan] + [True, False],
  11. dtype="boolean",
  12. )
  13. @pytest.fixture
  14. def left_array():
  15. """Fixture returning boolean array with valid and missing values."""
  16. return pd.array([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean")
  17. @pytest.fixture
  18. def right_array():
  19. """Fixture returning boolean array with valid and missing values."""
  20. return pd.array([True, False, None] * 3, dtype="boolean")
# Basic tests for the arithmetic array ops
  22. # -----------------------------------------------------------------------------
  23. @pytest.mark.parametrize(
  24. "opname, exp",
  25. [
  26. ("add", [True, True, None, True, False, None, None, None, None]),
  27. ("mul", [True, False, None, False, False, None, None, None, None]),
  28. ],
  29. ids=["add", "mul"],
  30. )
  31. def test_add_mul(left_array, right_array, opname, exp):
  32. op = getattr(operator, opname)
  33. result = op(left_array, right_array)
  34. expected = pd.array(exp, dtype="boolean")
  35. tm.assert_extension_array_equal(result, expected)
  36. def test_sub(left_array, right_array):
  37. msg = (
  38. r"numpy boolean subtract, the `-` operator, is (?:deprecated|not supported), "
  39. r"use the bitwise_xor, the `\^` operator, or the logical_xor function instead\."
  40. )
  41. with pytest.raises(TypeError, match=msg):
  42. left_array - right_array
  43. def test_div(left_array, right_array):
  44. msg = "operator '.*' not implemented for bool dtypes"
  45. with pytest.raises(NotImplementedError, match=msg):
  46. # check that we are matching the non-masked Series behavior
  47. pd.Series(left_array._data) / pd.Series(right_array._data)
  48. with pytest.raises(NotImplementedError, match=msg):
  49. left_array / right_array
  50. @pytest.mark.parametrize(
  51. "opname",
  52. [
  53. "floordiv",
  54. "mod",
  55. "pow",
  56. ],
  57. )
  58. def test_op_int8(left_array, right_array, opname):
  59. op = getattr(operator, opname)
  60. if opname != "mod":
  61. msg = "operator '.*' not implemented for bool dtypes"
  62. with pytest.raises(NotImplementedError, match=msg):
  63. result = op(left_array, right_array)
  64. return
  65. result = op(left_array, right_array)
  66. expected = op(left_array.astype("Int8"), right_array.astype("Int8"))
  67. tm.assert_extension_array_equal(result, expected)
  68. # Test generic characteristics / errors
  69. # -----------------------------------------------------------------------------
  70. def test_error_invalid_values(data, all_arithmetic_operators):
  71. # invalid ops
  72. op = all_arithmetic_operators
  73. s = pd.Series(data)
  74. ops = getattr(s, op)
  75. # invalid scalars
  76. msg = (
  77. "did not contain a loop with signature matching types|"
  78. "BooleanArray cannot perform the operation|"
  79. "not supported for the input types, and the inputs could not be safely coerced "
  80. "to any supported types according to the casting rule ''safe''"
  81. )
  82. with pytest.raises(TypeError, match=msg):
  83. ops("foo")
  84. msg = "|".join(
  85. [
  86. r"unsupported operand type\(s\) for",
  87. "Concatenation operation is not implemented for NumPy arrays",
  88. ]
  89. )
  90. with pytest.raises(TypeError, match=msg):
  91. ops(pd.Timestamp("20180101"))
  92. # invalid array-likes
  93. if op not in ("__mul__", "__rmul__"):
  94. # TODO(extension) numpy's mul with object array sees booleans as numbers
  95. msg = "|".join(
  96. [
  97. r"unsupported operand type\(s\) for",
  98. "can only concatenate str",
  99. "not all arguments converted during string formatting",
  100. ]
  101. )
  102. with pytest.raises(TypeError, match=msg):
  103. ops(pd.Series("foo", index=s.index))