test_json.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. import collections
  2. import operator
  3. import sys
  4. import pytest
  5. import pandas as pd
  6. import pandas._testing as tm
  7. from pandas.tests.extension import base
  8. from pandas.tests.extension.json.array import (
  9. JSONArray,
  10. JSONDtype,
  11. make_data,
  12. )
  13. @pytest.fixture
  14. def dtype():
  15. return JSONDtype()
@pytest.fixture
def data():
    """Length-100 JSONArray for semantics tests.

    NOTE(review): the original docstring said "PeriodArray" -- a copy/paste
    leftover from another extension test module; this fixture returns a
    JSONArray built from ``make_data()``.
    """
    data = make_data()

    # Why the while loop? NumPy is unable to construct an ndarray from
    # equal-length ndarrays. Many of our operations involve coercing the
    # EA to an ndarray of objects. To avoid random test failures, we ensure
    # that our data is coercible to an ndarray. Several tests deal with only
    # the first two elements, so that's what we'll check.
    while len(data[0]) == len(data[1]):
        data = make_data()

    return JSONArray(data)
  28. @pytest.fixture
  29. def data_missing():
  30. """Length 2 array with [NA, Valid]"""
  31. return JSONArray([{}, {"a": 10}])
  32. @pytest.fixture
  33. def data_for_sorting():
  34. return JSONArray([{"b": 1}, {"c": 4}, {"a": 2, "c": 3}])
  35. @pytest.fixture
  36. def data_missing_for_sorting():
  37. return JSONArray([{"b": 1}, {}, {"a": 4}])
  38. @pytest.fixture
  39. def na_value(dtype):
  40. return dtype.na_value
  41. @pytest.fixture
  42. def na_cmp():
  43. return operator.eq
  44. @pytest.fixture
  45. def data_for_grouping():
  46. return JSONArray(
  47. [
  48. {"b": 1},
  49. {"b": 1},
  50. {},
  51. {},
  52. {"a": 0, "c": 2},
  53. {"a": 0, "c": 2},
  54. {"b": 1},
  55. {"c": 2},
  56. ]
  57. )
  58. class BaseJSON:
  59. # NumPy doesn't handle an array of equal-length UserDicts.
  60. # The default assert_series_equal eventually does a
  61. # Series.values, which raises. We work around it by
  62. # converting the UserDicts to dicts.
  63. @classmethod
  64. def assert_series_equal(cls, left, right, *args, **kwargs):
  65. if left.dtype.name == "json":
  66. assert left.dtype == right.dtype
  67. left = pd.Series(
  68. JSONArray(left.values.astype(object)), index=left.index, name=left.name
  69. )
  70. right = pd.Series(
  71. JSONArray(right.values.astype(object)),
  72. index=right.index,
  73. name=right.name,
  74. )
  75. tm.assert_series_equal(left, right, *args, **kwargs)
  76. @classmethod
  77. def assert_frame_equal(cls, left, right, *args, **kwargs):
  78. obj_type = kwargs.get("obj", "DataFrame")
  79. tm.assert_index_equal(
  80. left.columns,
  81. right.columns,
  82. exact=kwargs.get("check_column_type", "equiv"),
  83. check_names=kwargs.get("check_names", True),
  84. check_exact=kwargs.get("check_exact", False),
  85. check_categorical=kwargs.get("check_categorical", True),
  86. obj=f"{obj_type}.columns",
  87. )
  88. jsons = (left.dtypes == "json").index
  89. for col in jsons:
  90. cls.assert_series_equal(left[col], right[col], *args, **kwargs)
  91. left = left.drop(columns=jsons)
  92. right = right.drop(columns=jsons)
  93. tm.assert_frame_equal(left, right, *args, **kwargs)
class TestDtype(BaseJSON, base.BaseDtypeTests):
    """Run the shared dtype tests unchanged against the JSON dtype."""

    pass
class TestInterface(BaseJSON, base.BaseInterfaceTests):
    def test_custom_asserts(self):
        # Sanity-check the custom assert helpers from BaseJSON directly.
        # This would always trigger the KeyError from trying to put
        # an array of equal-length UserDicts inside an ndarray.
        data = JSONArray(
            [
                collections.UserDict({"a": 1}),
                collections.UserDict({"b": 2}),
                collections.UserDict({"c": 3}),
            ]
        )
        a = pd.Series(data)
        self.assert_series_equal(a, a)
        self.assert_frame_equal(a.to_frame(), a.to_frame())

        # A genuinely different Series must still raise AssertionError.
        b = pd.Series(data.take([0, 0, 1]))
        msg = r"Series are different"
        with pytest.raises(AssertionError, match=msg):
            self.assert_series_equal(a, b)

        with pytest.raises(AssertionError, match=msg):
            self.assert_frame_equal(a.to_frame(), b.to_frame())

    @pytest.mark.xfail(
        reason="comparison method not implemented for JSONArray (GH-37867)"
    )
    def test_contains(self, data):
        # GH-37867
        super().test_contains(data)
class TestConstructors(BaseJSON, base.BaseConstructorsTests):
    @pytest.mark.xfail(reason="not implemented constructor from dtype")
    def test_from_dtype(self, data):
        # construct from our dtype & string dtype
        super().test_from_dtype(data)

    @pytest.mark.xfail(reason="RecursionError, GH-33900")
    def test_series_constructor_no_data_with_index(self, dtype, na_value):
        # RecursionError: maximum recursion depth exceeded in comparison
        rec_limit = sys.getrecursionlimit()
        try:
            # Limit to avoid stack overflow on Windows CI
            sys.setrecursionlimit(100)
            super().test_series_constructor_no_data_with_index(dtype, na_value)
        finally:
            # Always restore the process-wide recursion limit.
            sys.setrecursionlimit(rec_limit)

    @pytest.mark.xfail(reason="RecursionError, GH-33900")
    def test_series_constructor_scalar_na_with_index(self, dtype, na_value):
        # RecursionError: maximum recursion depth exceeded in comparison
        rec_limit = sys.getrecursionlimit()
        try:
            # Limit to avoid stack overflow on Windows CI
            sys.setrecursionlimit(100)
            super().test_series_constructor_scalar_na_with_index(dtype, na_value)
        finally:
            # Always restore the process-wide recursion limit.
            sys.setrecursionlimit(rec_limit)

    @pytest.mark.xfail(reason="collection as scalar, GH-33901")
    def test_series_constructor_scalar_with_index(self, data, dtype):
        # TypeError: All values must be of type <class 'collections.abc.Mapping'>
        rec_limit = sys.getrecursionlimit()
        try:
            # Limit to avoid stack overflow on Windows CI
            sys.setrecursionlimit(100)
            super().test_series_constructor_scalar_with_index(data, dtype)
        finally:
            # Always restore the process-wide recursion limit.
            sys.setrecursionlimit(rec_limit)
class TestReshaping(BaseJSON, base.BaseReshapingTests):
    @pytest.mark.xfail(reason="Different definitions of NA")
    def test_stack(self):
        """
        The test does .astype(object).stack(). If we happen to have
        any missing values in `data`, then we'll end up with different
        rows since we consider `{}` NA, but `.astype(object)` doesn't.
        """
        super().test_stack()

    @pytest.mark.xfail(reason="dict for NA")
    def test_unstack(self, data, index):
        # The base test has NaN for the expected NA value.
        # this matches otherwise
        return super().test_unstack(data, index)
class TestGetitem(BaseJSON, base.BaseGetitemTests):
    """Run the shared __getitem__ tests unchanged against JSONArray."""

    pass
class TestIndex(BaseJSON, base.BaseIndexTests):
    """Run the shared Index-construction tests unchanged against JSONArray."""

    pass
class TestMissing(BaseJSON, base.BaseMissingTests):
    @pytest.mark.xfail(reason="Setting a dict as a scalar")
    def test_fillna_series(self):
        """We treat dictionaries as a mapping in fillna, not a scalar."""
        super().test_fillna_series()

    @pytest.mark.xfail(reason="Setting a dict as a scalar")
    def test_fillna_frame(self):
        """We treat dictionaries as a mapping in fillna, not a scalar."""
        super().test_fillna_frame()
# Shared marker for tests that fail because dict values are unhashable.
unhashable = pytest.mark.xfail(reason="Unhashable")
class TestReduce(base.BaseNoReduceTests):
    """Assert that reductions raise; JSONArray implements no reductions."""

    pass
class TestMethods(BaseJSON, base.BaseMethodsTests):
    @unhashable
    def test_value_counts(self, all_data, dropna):
        super().test_value_counts(all_data, dropna)

    @unhashable
    def test_value_counts_with_normalize(self, data):
        super().test_value_counts_with_normalize(data)

    @unhashable
    def test_sort_values_frame(self):
        # TODO (EA.factorize): see if _values_for_factorize allows this.
        super().test_sort_values_frame()

    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values(self, data_for_sorting, ascending, sort_by_key):
        super().test_sort_values(data_for_sorting, ascending, sort_by_key)

    @pytest.mark.parametrize("ascending", [True, False])
    def test_sort_values_missing(
        self, data_missing_for_sorting, ascending, sort_by_key
    ):
        super().test_sort_values_missing(
            data_missing_for_sorting, ascending, sort_by_key
        )

    @pytest.mark.xfail(reason="combine for JSONArray not supported")
    def test_combine_le(self, data_repeated):
        super().test_combine_le(data_repeated)

    @pytest.mark.xfail(reason="combine for JSONArray not supported")
    def test_combine_add(self, data_repeated):
        super().test_combine_add(data_repeated)

    # Non-strict xfail: the randomly generated `data` fixture may happen to
    # avoid the failure path, so the test is allowed to pass.
    @pytest.mark.xfail(
        reason="combine for JSONArray not supported - "
        "may pass depending on random data",
        strict=False,
        raises=AssertionError,
    )
    def test_combine_first(self, data):
        super().test_combine_first(data)

    @unhashable
    def test_hash_pandas_object_works(self, data, kind):
        super().test_hash_pandas_object_works(data, kind)

    @pytest.mark.xfail(reason="broadcasting error")
    def test_where_series(self, data, na_value):
        # Fails with
        # *** ValueError: operands could not be broadcast together
        # with shapes (4,) (4,) (0,)
        super().test_where_series(data, na_value)

    @pytest.mark.xfail(reason="Can't compare dicts.")
    def test_searchsorted(self, data_for_sorting):
        super().test_searchsorted(data_for_sorting)

    @pytest.mark.xfail(reason="Can't compare dicts.")
    def test_equals(self, data, na_value, as_series):
        super().test_equals(data, na_value, as_series)

    @pytest.mark.skip("fill-value is interpreted as a dict of values")
    def test_fillna_copy_frame(self, data_missing):
        super().test_fillna_copy_frame(data_missing)
class TestCasting(BaseJSON, base.BaseCastingTests):
    @pytest.mark.xfail(reason="failing on np.array(self, dtype=str)")
    def test_astype_str(self):
        """This currently fails in NumPy on np.array(self, dtype=str) with

        *** ValueError: setting an array element with a sequence
        """
        super().test_astype_str()
  247. # We intentionally don't run base.BaseSetitemTests because pandas'
  248. # internals has trouble setting sequences of values into scalar positions.
class TestGroupby(BaseJSON, base.BaseGroupbyTests):
    @unhashable
    def test_groupby_extension_transform(self):
        """
        This currently fails in Series.name.setter, since the
        name must be hashable, but the value is a dictionary.
        I think this is what we want, i.e. `.name` should be the original
        values, and not the values for factorization.
        """
        super().test_groupby_extension_transform()

    @unhashable
    def test_groupby_extension_apply(self):
        """
        This fails in Index._do_unique_check with

        >   hash(val)
        E   TypeError: unhashable type: 'UserDict' with

        I suspect that once we support Index[ExtensionArray],
        we'll be able to dispatch unique.
        """
        super().test_groupby_extension_apply()

    @unhashable
    def test_groupby_extension_agg(self):
        """
        This fails when we get to tm.assert_series_equal when left.index
        contains dictionaries, which are not hashable.
        """
        super().test_groupby_extension_agg()

    @unhashable
    def test_groupby_extension_no_sort(self):
        """
        This fails when we get to tm.assert_series_equal when left.index
        contains dictionaries, which are not hashable.
        """
        super().test_groupby_extension_no_sort()

    @pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
    def test_groupby_agg_extension(self, data_for_grouping):
        super().test_groupby_agg_extension(data_for_grouping)
class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests):
    def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
        # Only single-key first elements survive coercion; otherwise mark
        # the (randomly generated) case as an expected failure at runtime.
        if len(data[0]) != 1:
            mark = pytest.mark.xfail(reason="raises in coercing to Series")
            request.node.add_marker(mark)
        super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

    def test_add_series_with_extension_array(self, data):
        # Addition is unsupported for JSONArray; assert it raises TypeError.
        ser = pd.Series(data)
        with pytest.raises(TypeError, match="unsupported"):
            ser + data

    @pytest.mark.xfail(reason="not implemented")
    def test_divmod_series_array(self):
        # GH 23287
        # skipping because it is not implemented
        super().test_divmod_series_array()

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        # Override the expected exception: divmod here raises TypeError,
        # not the base class's default NotImplementedError.
        return super()._check_divmod_op(s, op, other, exc=TypeError)
class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
    """Run the shared comparison-operator tests unchanged against JSONArray."""

    pass
class TestPrinting(BaseJSON, base.BasePrintingTests):
    """Run the shared repr/printing tests unchanged against JSONArray."""

    pass