common.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. """
  2. Assertion helpers for arithmetic tests.
  3. """
  4. import numpy as np
  5. import pytest
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. Series,
  10. array,
  11. )
  12. import pandas._testing as tm
  13. from pandas.core.arrays import (
  14. BooleanArray,
  15. PandasArray,
  16. )
  17. def assert_cannot_add(left, right, msg="cannot add"):
  18. """
  19. Helper to assert that left and right cannot be added.
  20. Parameters
  21. ----------
  22. left : object
  23. right : object
  24. msg : str, default "cannot add"
  25. """
  26. with pytest.raises(TypeError, match=msg):
  27. left + right
  28. with pytest.raises(TypeError, match=msg):
  29. right + left
  30. def assert_invalid_addsub_type(left, right, msg=None):
  31. """
  32. Helper to assert that left and right can be neither added nor subtracted.
  33. Parameters
  34. ----------
  35. left : object
  36. right : object
  37. msg : str or None, default None
  38. """
  39. with pytest.raises(TypeError, match=msg):
  40. left + right
  41. with pytest.raises(TypeError, match=msg):
  42. right + left
  43. with pytest.raises(TypeError, match=msg):
  44. left - right
  45. with pytest.raises(TypeError, match=msg):
  46. right - left
  47. def get_upcast_box(left, right, is_cmp: bool = False):
  48. """
  49. Get the box to use for 'expected' in an arithmetic or comparison operation.
  50. Parameters
  51. left : Any
  52. right : Any
  53. is_cmp : bool, default False
  54. Whether the operation is a comparison method.
  55. """
  56. if isinstance(left, DataFrame) or isinstance(right, DataFrame):
  57. return DataFrame
  58. if isinstance(left, Series) or isinstance(right, Series):
  59. if is_cmp and isinstance(left, Index):
  60. # Index does not defer for comparisons
  61. return np.array
  62. return Series
  63. if isinstance(left, Index) or isinstance(right, Index):
  64. if is_cmp:
  65. return np.array
  66. return Index
  67. return tm.to_array
  68. def assert_invalid_comparison(left, right, box):
  69. """
  70. Assert that comparison operations with mismatched types behave correctly.
  71. Parameters
  72. ----------
  73. left : np.ndarray, ExtensionArray, Index, or Series
  74. right : object
  75. box : {pd.DataFrame, pd.Series, pd.Index, pd.array, tm.to_array}
  76. """
  77. # Not for tznaive-tzaware comparison
  78. # Note: not quite the same as how we do this for tm.box_expected
  79. xbox = box if box not in [Index, array] else np.array
  80. def xbox2(x):
  81. # Eventually we'd like this to be tighter, but for now we'll
  82. # just exclude PandasArray[bool]
  83. if isinstance(x, PandasArray):
  84. return x._ndarray
  85. if isinstance(x, BooleanArray):
  86. # NB: we are assuming no pd.NAs for now
  87. return x.astype(bool)
  88. return x
  89. # rev_box: box to use for reversed comparisons
  90. rev_box = xbox
  91. if isinstance(right, Index) and isinstance(left, Series):
  92. rev_box = np.array
  93. result = xbox2(left == right)
  94. expected = xbox(np.zeros(result.shape, dtype=np.bool_))
  95. tm.assert_equal(result, expected)
  96. result = xbox2(right == left)
  97. tm.assert_equal(result, rev_box(expected))
  98. result = xbox2(left != right)
  99. tm.assert_equal(result, ~expected)
  100. result = xbox2(right != left)
  101. tm.assert_equal(result, rev_box(~expected))
  102. msg = "|".join(
  103. [
  104. "Invalid comparison between",
  105. "Cannot compare type",
  106. "not supported between",
  107. "invalid type promotion",
  108. (
  109. # GH#36706 npdev 1.20.0 2020-09-28
  110. r"The DTypes <class 'numpy.dtype\[datetime64\]'> and "
  111. r"<class 'numpy.dtype\[int64\]'> do not have a common DType. "
  112. "For example they cannot be stored in a single array unless the "
  113. "dtype is `object`."
  114. ),
  115. ]
  116. )
  117. with pytest.raises(TypeError, match=msg):
  118. left < right
  119. with pytest.raises(TypeError, match=msg):
  120. left <= right
  121. with pytest.raises(TypeError, match=msg):
  122. left > right
  123. with pytest.raises(TypeError, match=msg):
  124. left >= right
  125. with pytest.raises(TypeError, match=msg):
  126. right < left
  127. with pytest.raises(TypeError, match=msg):
  128. right <= left
  129. with pytest.raises(TypeError, match=msg):
  130. right > left
  131. with pytest.raises(TypeError, match=msg):
  132. right >= left