# test_numpy_compat.py
import numpy as np
import pytest

from pandas import (
    CategoricalIndex,
    DatetimeIndex,
    Index,
    PeriodIndex,
    TimedeltaIndex,
    isna,
)
import pandas._testing as tm
from pandas.api.types import (
    is_complex_dtype,
    is_numeric_dtype,
)
from pandas.core.arrays import BooleanArray
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin


  18. def test_numpy_ufuncs_out(index):
  19. result = index == index
  20. out = np.empty(index.shape, dtype=bool)
  21. np.equal(index, index, out=out)
  22. tm.assert_numpy_array_equal(out, result)
  23. if not index._is_multi:
  24. # same thing on the ExtensionArray
  25. out = np.empty(index.shape, dtype=bool)
  26. np.equal(index.array, index.array, out=out)
  27. tm.assert_numpy_array_equal(out, result)
  28. @pytest.mark.parametrize(
  29. "func",
  30. [
  31. np.exp,
  32. np.exp2,
  33. np.expm1,
  34. np.log,
  35. np.log2,
  36. np.log10,
  37. np.log1p,
  38. np.sqrt,
  39. np.sin,
  40. np.cos,
  41. np.tan,
  42. np.arcsin,
  43. np.arccos,
  44. np.arctan,
  45. np.sinh,
  46. np.cosh,
  47. np.tanh,
  48. np.arcsinh,
  49. np.arccosh,
  50. np.arctanh,
  51. np.deg2rad,
  52. np.rad2deg,
  53. ],
  54. ids=lambda x: x.__name__,
  55. )
  56. def test_numpy_ufuncs_basic(index, func):
  57. # test ufuncs of numpy, see:
  58. # https://numpy.org/doc/stable/reference/ufuncs.html
  59. if isinstance(index, DatetimeIndexOpsMixin):
  60. with tm.external_error_raised((TypeError, AttributeError)):
  61. with np.errstate(all="ignore"):
  62. func(index)
  63. elif is_numeric_dtype(index) and not (
  64. is_complex_dtype(index) and func in [np.deg2rad, np.rad2deg]
  65. ):
  66. # coerces to float (e.g. np.sin)
  67. with np.errstate(all="ignore"):
  68. result = func(index)
  69. arr_result = func(index.values)
  70. if arr_result.dtype == np.float16:
  71. arr_result = arr_result.astype(np.float32)
  72. exp = Index(arr_result, name=index.name)
  73. tm.assert_index_equal(result, exp)
  74. if isinstance(index.dtype, np.dtype) and is_numeric_dtype(index):
  75. if is_complex_dtype(index):
  76. assert result.dtype == index.dtype
  77. elif index.dtype in ["bool", "int8", "uint8"]:
  78. assert result.dtype in ["float16", "float32"]
  79. elif index.dtype in ["int16", "uint16", "float32"]:
  80. assert result.dtype == "float32"
  81. else:
  82. assert result.dtype == "float64"
  83. else:
  84. # e.g. np.exp with Int64 -> Float64
  85. assert type(result) is Index
  86. else:
  87. # raise AttributeError or TypeError
  88. if len(index) == 0:
  89. pass
  90. else:
  91. with tm.external_error_raised((TypeError, AttributeError)):
  92. with np.errstate(all="ignore"):
  93. func(index)
  94. @pytest.mark.parametrize(
  95. "func", [np.isfinite, np.isinf, np.isnan, np.signbit], ids=lambda x: x.__name__
  96. )
  97. def test_numpy_ufuncs_other(index, func):
  98. # test ufuncs of numpy, see:
  99. # https://numpy.org/doc/stable/reference/ufuncs.html
  100. if isinstance(index, (DatetimeIndex, TimedeltaIndex)):
  101. if func in (np.isfinite, np.isinf, np.isnan):
  102. # numpy 1.18 changed isinf and isnan to not raise on dt64/td64
  103. result = func(index)
  104. assert isinstance(result, np.ndarray)
  105. out = np.empty(index.shape, dtype=bool)
  106. func(index, out=out)
  107. tm.assert_numpy_array_equal(out, result)
  108. else:
  109. with tm.external_error_raised(TypeError):
  110. func(index)
  111. elif isinstance(index, PeriodIndex):
  112. with tm.external_error_raised(TypeError):
  113. func(index)
  114. elif is_numeric_dtype(index) and not (
  115. is_complex_dtype(index) and func is np.signbit
  116. ):
  117. # Results in bool array
  118. result = func(index)
  119. if not isinstance(index.dtype, np.dtype):
  120. # e.g. Int64 we expect to get BooleanArray back
  121. assert isinstance(result, BooleanArray)
  122. else:
  123. assert isinstance(result, np.ndarray)
  124. out = np.empty(index.shape, dtype=bool)
  125. func(index, out=out)
  126. if not isinstance(index.dtype, np.dtype):
  127. tm.assert_numpy_array_equal(out, result._data)
  128. else:
  129. tm.assert_numpy_array_equal(out, result)
  130. else:
  131. if len(index) == 0:
  132. pass
  133. else:
  134. with tm.external_error_raised(TypeError):
  135. func(index)
  136. @pytest.mark.parametrize("func", [np.maximum, np.minimum])
  137. def test_numpy_ufuncs_reductions(index, func, request):
  138. # TODO: overlap with tests.series.test_ufunc.test_reductions
  139. if len(index) == 0:
  140. return
  141. if isinstance(index, CategoricalIndex) and index.dtype.ordered is False:
  142. with pytest.raises(TypeError, match="is not ordered for"):
  143. func.reduce(index)
  144. return
  145. else:
  146. result = func.reduce(index)
  147. if func is np.maximum:
  148. expected = index.max(skipna=False)
  149. else:
  150. expected = index.min(skipna=False)
  151. # TODO: do we have cases both with and without NAs?
  152. assert type(result) is type(expected)
  153. if isna(result):
  154. assert isna(expected)
  155. else:
  156. assert result == expected
  157. @pytest.mark.parametrize("func", [np.bitwise_and, np.bitwise_or, np.bitwise_xor])
  158. def test_numpy_ufuncs_bitwise(func):
  159. # https://github.com/pandas-dev/pandas/issues/46769
  160. idx1 = Index([1, 2, 3, 4], dtype="int64")
  161. idx2 = Index([3, 4, 5, 6], dtype="int64")
  162. with tm.assert_produces_warning(None):
  163. result = func(idx1, idx2)
  164. expected = Index(func(idx1.values, idx2.values))
  165. tm.assert_index_equal(result, expected)