__init__.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168
  1. from __future__ import annotations
  2. import collections
  3. from datetime import datetime
  4. from decimal import Decimal
  5. import operator
  6. import os
  7. import re
  8. import string
  9. from sys import byteorder
  10. from typing import (
  11. TYPE_CHECKING,
  12. Callable,
  13. ContextManager,
  14. Counter,
  15. Iterable,
  16. cast,
  17. )
  18. import numpy as np
  19. from pandas._config.localization import (
  20. can_set_locale,
  21. get_locales,
  22. set_locale,
  23. )
  24. from pandas._typing import (
  25. Dtype,
  26. Frequency,
  27. NpDtype,
  28. )
  29. from pandas.compat import pa_version_under7p0
  30. from pandas.core.dtypes.common import (
  31. is_float_dtype,
  32. is_integer_dtype,
  33. is_sequence,
  34. is_signed_integer_dtype,
  35. is_unsigned_integer_dtype,
  36. pandas_dtype,
  37. )
  38. import pandas as pd
  39. from pandas import (
  40. ArrowDtype,
  41. Categorical,
  42. CategoricalIndex,
  43. DataFrame,
  44. DatetimeIndex,
  45. Index,
  46. IntervalIndex,
  47. MultiIndex,
  48. RangeIndex,
  49. Series,
  50. bdate_range,
  51. )
  52. from pandas._testing._io import (
  53. close,
  54. network,
  55. round_trip_localpath,
  56. round_trip_pathlib,
  57. round_trip_pickle,
  58. write_to_compressed,
  59. )
  60. from pandas._testing._random import (
  61. rands,
  62. rands_array,
  63. )
  64. from pandas._testing._warnings import (
  65. assert_produces_warning,
  66. maybe_produces_warning,
  67. )
  68. from pandas._testing.asserters import (
  69. assert_almost_equal,
  70. assert_attr_equal,
  71. assert_categorical_equal,
  72. assert_class_equal,
  73. assert_contains_all,
  74. assert_copy,
  75. assert_datetime_array_equal,
  76. assert_dict_equal,
  77. assert_equal,
  78. assert_extension_array_equal,
  79. assert_frame_equal,
  80. assert_index_equal,
  81. assert_indexing_slices_equivalent,
  82. assert_interval_array_equal,
  83. assert_is_sorted,
  84. assert_is_valid_plot_return_object,
  85. assert_metadata_equivalent,
  86. assert_numpy_array_equal,
  87. assert_period_array_equal,
  88. assert_series_equal,
  89. assert_sp_array_equal,
  90. assert_timedelta_array_equal,
  91. raise_assert_detail,
  92. )
  93. from pandas._testing.compat import (
  94. get_dtype,
  95. get_obj,
  96. )
  97. from pandas._testing.contexts import (
  98. decompress_file,
  99. ensure_clean,
  100. ensure_safe_environment_variables,
  101. raises_chained_assignment_error,
  102. set_timezone,
  103. use_numexpr,
  104. with_csv_dialect,
  105. )
  106. from pandas.core.arrays import (
  107. BaseMaskedArray,
  108. ExtensionArray,
  109. PandasArray,
  110. )
  111. from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
  112. from pandas.core.construction import extract_array
  113. if TYPE_CHECKING:
  114. from pandas import (
  115. PeriodIndex,
  116. TimedeltaIndex,
  117. )
  118. from pandas.core.arrays import ArrowExtensionArray
# Default dimensions for the random test objects created below:
# _N rows / index entries, _K columns.
_N = 30
_K = 4

# dtype groups used to parametrize tests: plain numpy dtypes vs pandas
# extension-array ("EA") dtypes.
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]

SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]

ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]

FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
STRING_DTYPES: list[Dtype] = [str, "str", "U"]

DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]

BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]

ALL_NUMPY_DTYPES = (
    ALL_REAL_NUMPY_DTYPES
    + COMPLEX_DTYPES
    + STRING_DTYPES
    + DATETIME64_DTYPES
    + TIMEDELTA64_DTYPES
    + BOOL_DTYPES
    + OBJECT_DTYPES
    + BYTES_DTYPES
)

# Sub-64-bit ("narrow") numpy dtypes, e.g. for casting/overflow tests.
NARROW_NP_DTYPES = [
    np.float16,
    np.float32,
    np.int8,
    np.int16,
    np.int32,
    np.uint8,
    np.uint16,
    np.uint32,
]

# "<" or ">": the byte-order prefix matching this machine's native order.
ENDIAN = {"little": "<", "big": ">"}[byteorder]

# Scalar missing-value sentinels recognized by pandas as NA.
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
# numpy NaT scalars for every datetime64/timedelta64 resolution.
NP_NAT_OBJECTS = [
    cls("NaT", unit)
    for cls in [np.datetime64, np.timedelta64]
    for unit in [
        "Y",
        "M",
        "W",
        "D",
        "h",
        "m",
        "s",
        "ms",
        "us",
        "ns",
        "ps",
        "fs",
        "as",
    ]
]

if not pa_version_under7p0:
    import pyarrow as pa

    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
    ALL_INT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
    ]

    # pa.float16 doesn't seem supported
    # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
    FLOAT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
    ]
    DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
    STRING_PYARROW_DTYPES = [pa.string()]
    BINARY_PYARROW_DTYPES = [pa.binary()]

    TIME_PYARROW_DTYPES = [
        pa.time32("s"),
        pa.time32("ms"),
        pa.time64("us"),
        pa.time64("ns"),
    ]
    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
    DATETIME_PYARROW_DTYPES = [
        pa.timestamp(unit=unit, tz=tz)
        for unit in ["s", "ms", "us", "ns"]
        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
    ]
    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]

    BOOL_PYARROW_DTYPES = [pa.bool_()]

    # TODO: Add container like pyarrow types:
    # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
    ALL_PYARROW_DTYPES = (
        ALL_INT_PYARROW_DTYPES
        + FLOAT_PYARROW_DTYPES
        + DECIMAL_PYARROW_DTYPES
        + STRING_PYARROW_DTYPES
        + BINARY_PYARROW_DTYPES
        + TIME_PYARROW_DTYPES
        + DATE_PYARROW_DTYPES
        + DATETIME_PYARROW_DTYPES
        + TIMEDELTA_PYARROW_DTYPES
        + BOOL_PYARROW_DTYPES
    )
else:
    # pyarrow missing or too old: expose empty fallbacks so test
    # parametrizations that import these names still work.
    FLOAT_PYARROW_DTYPES_STR_REPR = []
    ALL_INT_PYARROW_DTYPES_STR_REPR = []
    ALL_PYARROW_DTYPES = []

# Matches the empty string only.
EMPTY_STRING_PATTERN = re.compile("^$")
  233. def reset_display_options() -> None:
  234. """
  235. Reset the display options for printing and representing objects.
  236. """
  237. pd.reset_option("^display.", silent=True)
  238. # -----------------------------------------------------------------------------
  239. # Comparators
  240. def equalContents(arr1, arr2) -> bool:
  241. """
  242. Checks if the set of unique elements of arr1 and arr2 are equivalent.
  243. """
  244. return frozenset(arr1) == frozenset(arr2)
def box_expected(expected, box_cls, transpose: bool = True):
    """
    Helper function to wrap the expected output of a test in a given box_class.

    Parameters
    ----------
    expected : np.ndarray, Index, Series
        The values to wrap.
    box_cls : {pd.array, Index, Series, DataFrame, np.ndarray, np.array, to_array}
        The container to wrap ``expected`` in.
    transpose : bool, default True
        Only used when box_cls is DataFrame: transpose to a single-row frame
        and duplicate the row (see inline comment below).

    Returns
    -------
    subclass of box_cls

    Raises
    ------
    NotImplementedError
        If ``box_cls`` is not one of the supported containers.
    """
    if box_cls is pd.array:
        if isinstance(expected, RangeIndex):
            # pd.array would return an IntegerArray
            expected = PandasArray(np.asarray(expected._values))
        else:
            expected = pd.array(expected, copy=False)
    elif box_cls is Index:
        expected = Index(expected)
    elif box_cls is Series:
        expected = Series(expected)
    elif box_cls is DataFrame:
        expected = Series(expected).to_frame()
        if transpose:
            # for vector operations, we need a DataFrame to be a single-row,
            # not a single-column, in order to operate against non-DataFrame
            # vectors of the same length. But convert to two rows to avoid
            # single-row special cases in datetime arithmetic
            expected = expected.T
            expected = pd.concat([expected] * 2, ignore_index=True)
    elif box_cls is np.ndarray or box_cls is np.array:
        expected = np.array(expected)
    elif box_cls is to_array:
        # module-level helper below: like pd.array but without casting
        # numpy dtypes to nullable EA dtypes
        expected = to_array(expected)
    else:
        raise NotImplementedError(box_cls)
    return expected
  282. def to_array(obj):
  283. """
  284. Similar to pd.array, but does not cast numpy dtypes to nullable dtypes.
  285. """
  286. # temporary implementation until we get pd.array in place
  287. dtype = getattr(obj, "dtype", None)
  288. if dtype is None:
  289. return np.asarray(obj)
  290. return extract_array(obj, extract_numpy=True)
  291. # -----------------------------------------------------------------------------
  292. # Others
  293. def getCols(k) -> str:
  294. return string.ascii_uppercase[:k]
# make index
def makeStringIndex(k: int = 10, name=None) -> Index:
    """Return an Index of k random 10-character strings."""
    return Index(rands_array(nchars=10, size=k), name=name)
def makeCategoricalIndex(
    k: int = 10, n: int = 3, name=None, **kwargs
) -> CategoricalIndex:
    """Make a length-k CategoricalIndex with n distinct categories."""
    # n unique 4-character category labels; codes cycle 0..n-1 to fill k entries
    x = rands_array(nchars=4, size=n, replace=False)
    return CategoricalIndex(
        Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs
    )
  306. def makeIntervalIndex(k: int = 10, name=None, **kwargs) -> IntervalIndex:
  307. """make a length k IntervalIndex"""
  308. x = np.linspace(0, 100, num=(k + 1))
  309. return IntervalIndex.from_breaks(x, name=name, **kwargs)
  310. def makeBoolIndex(k: int = 10, name=None) -> Index:
  311. if k == 1:
  312. return Index([True], name=name)
  313. elif k == 2:
  314. return Index([False, True], name=name)
  315. return Index([False, True] + [False] * (k - 2), name=name)
def makeNumericIndex(k: int = 10, *, name=None, dtype: Dtype | None) -> Index:
    """Return a numeric Index of length k with the given numpy dtype.

    Integer dtypes count up from 0 (unsigned dtypes are shifted into the
    upper half of their range); float dtypes get sorted random values.
    """
    dtype = pandas_dtype(dtype)
    # only plain numpy dtypes are supported here (not EA dtypes)
    assert isinstance(dtype, np.dtype)
    if is_integer_dtype(dtype):
        values = np.arange(k, dtype=dtype)
        if is_unsigned_integer_dtype(dtype):
            # shift so values land in the unsigned-only upper half of the range
            values += 2 ** (dtype.itemsize * 8 - 1)
    elif is_float_dtype(dtype):
        # random values (possibly negative), sorted, scaled by a random power of 10
        values = np.random.random_sample(k) - np.random.random_sample(1)
        values.sort()
        values = values * (10 ** np.random.randint(0, 9))
    else:
        raise NotImplementedError(f"wrong dtype {dtype}")
    return Index(values, dtype=dtype, name=name)
  330. def makeIntIndex(k: int = 10, *, name=None, dtype: Dtype = "int64") -> Index:
  331. dtype = pandas_dtype(dtype)
  332. if not is_signed_integer_dtype(dtype):
  333. raise TypeError(f"Wrong dtype {dtype}")
  334. return makeNumericIndex(k, name=name, dtype=dtype)
  335. def makeUIntIndex(k: int = 10, *, name=None, dtype: Dtype = "uint64") -> Index:
  336. dtype = pandas_dtype(dtype)
  337. if not is_unsigned_integer_dtype(dtype):
  338. raise TypeError(f"Wrong dtype {dtype}")
  339. return makeNumericIndex(k, name=name, dtype=dtype)
  340. def makeRangeIndex(k: int = 10, name=None, **kwargs) -> RangeIndex:
  341. return RangeIndex(0, k, 1, name=name, **kwargs)
  342. def makeFloatIndex(k: int = 10, *, name=None, dtype: Dtype = "float64") -> Index:
  343. dtype = pandas_dtype(dtype)
  344. if not is_float_dtype(dtype):
  345. raise TypeError(f"Wrong dtype {dtype}")
  346. return makeNumericIndex(k, name=name, dtype=dtype)
  347. def makeDateIndex(
  348. k: int = 10, freq: Frequency = "B", name=None, **kwargs
  349. ) -> DatetimeIndex:
  350. dt = datetime(2000, 1, 1)
  351. dr = bdate_range(dt, periods=k, freq=freq, name=name)
  352. return DatetimeIndex(dr, name=name, **kwargs)
  353. def makeTimedeltaIndex(
  354. k: int = 10, freq: Frequency = "D", name=None, **kwargs
  355. ) -> TimedeltaIndex:
  356. return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)
  357. def makePeriodIndex(k: int = 10, name=None, **kwargs) -> PeriodIndex:
  358. dt = datetime(2000, 1, 1)
  359. return pd.period_range(start=dt, periods=k, freq="B", name=name, **kwargs)
  360. def makeMultiIndex(k: int = 10, names=None, **kwargs):
  361. N = (k // 2) + 1
  362. rng = range(N)
  363. mi = MultiIndex.from_product([("foo", "bar"), rng], names=names, **kwargs)
  364. assert len(mi) >= k # GH#38795
  365. return mi[:k]
def index_subclass_makers_generator():
    """Yield the maker function for each Index subclass defined above."""
    make_index_funcs = [
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
        makeRangeIndex,
        makeIntervalIndex,
        makeCategoricalIndex,
        makeMultiIndex,
    ]
    yield from make_index_funcs
def all_timeseries_index_generator(k: int = 10) -> Iterable[Index]:
    """
    Generator which can be iterated over to get instances of all the classes
    which represent time-series.

    Parameters
    ----------
    k : int
        Length of each of the index instances.

    Yields
    ------
    Index
        One instance each of DatetimeIndex, PeriodIndex and TimedeltaIndex.
    """
    make_index_funcs: list[Callable[..., Index]] = [
        makeDateIndex,
        makePeriodIndex,
        makeTimedeltaIndex,
    ]
    for make_index_func in make_index_funcs:
        yield make_index_func(k=k)
# make series
def make_rand_series(name=None, dtype=np.float64) -> Series:
    """Return a Series of length _N of random data with a random string index."""
    index = makeStringIndex(_N)
    data = np.random.randn(_N)
    with np.errstate(invalid="ignore"):
        # casting randn output to other dtypes can raise invalid-value
        # warnings; suppress them for the cast only
        data = data.astype(dtype, copy=False)
    return Series(data, index=index, name=name)


def makeFloatSeries(name=None) -> Series:
    """Return a random float64 Series with a string index."""
    return make_rand_series(name=name)


def makeStringSeries(name=None) -> Series:
    """Return a random float64 Series with a string index.

    Note: despite the name, the *values* are floats; only the index is strings.
    """
    return make_rand_series(name=name)


def makeObjectSeries(name=None) -> Series:
    """Return an object-dtype Series of random strings with a string index."""
    data = makeStringIndex(_N)
    data = Index(data, dtype=object)
    index = makeStringIndex(_N)
    return Series(data, index=index, name=name)


def getSeriesData() -> dict[str, Series]:
    """Return a dict of _K random Series sharing a single string index."""
    index = makeStringIndex(_N)
    return {c: Series(np.random.randn(_N), index=index) for c in getCols(_K)}


def makeTimeSeries(nper=None, freq: Frequency = "B", name=None) -> Series:
    """Return a random Series on a DatetimeIndex of nper periods (default _N)."""
    if nper is None:
        nper = _N
    return Series(
        np.random.randn(nper), index=makeDateIndex(nper, freq=freq), name=name
    )


def makePeriodSeries(nper=None, name=None) -> Series:
    """Return a random Series on a business-day PeriodIndex (default _N periods)."""
    if nper is None:
        nper = _N
    return Series(np.random.randn(nper), index=makePeriodIndex(nper), name=name)


def getTimeSeriesData(nper=None, freq: Frequency = "B") -> dict[str, Series]:
    """Return a dict of _K random time series keyed by column letter."""
    return {c: makeTimeSeries(nper, freq) for c in getCols(_K)}


def getPeriodData(nper=None) -> dict[str, Series]:
    """Return a dict of _K random period-indexed Series keyed by column letter."""
    return {c: makePeriodSeries(nper) for c in getCols(_K)}
# make frame
def makeTimeDataFrame(nper=None, freq: Frequency = "B") -> DataFrame:
    """Return a DataFrame of _K random columns on a DatetimeIndex."""
    data = getTimeSeriesData(nper, freq)
    return DataFrame(data)


def makeDataFrame() -> DataFrame:
    """Return an _N x _K DataFrame of random floats with a string index."""
    data = getSeriesData()
    return DataFrame(data)


def getMixedTypeDict():
    """Return (index, data) for a small frame with mixed column types."""
    index = Index(["a", "b", "c", "d", "e"])
    data = {
        "A": [0.0, 1.0, 2.0, 3.0, 4.0],
        "B": [0.0, 1.0, 0.0, 1.0, 0.0],
        "C": ["foo1", "foo2", "foo3", "foo4", "foo5"],
        "D": bdate_range("1/1/2009", periods=5),
    }
    return index, data


def makeMixedDataFrame() -> DataFrame:
    """Return a DataFrame with float, string and datetime columns.

    Note: uses the data dict only, so the frame gets a default RangeIndex.
    """
    return DataFrame(getMixedTypeDict()[1])


def makePeriodFrame(nper=None) -> DataFrame:
    """Return a DataFrame of _K random columns on a business PeriodIndex."""
    data = getPeriodData(nper)
    return DataFrame(data)
  446. def makeCustomIndex(
  447. nentries,
  448. nlevels,
  449. prefix: str = "#",
  450. names: bool | str | list[str] | None = False,
  451. ndupe_l=None,
  452. idx_type=None,
  453. ) -> Index:
  454. """
  455. Create an index/multindex with given dimensions, levels, names, etc'
  456. nentries - number of entries in index
  457. nlevels - number of levels (> 1 produces multindex)
  458. prefix - a string prefix for labels
  459. names - (Optional), bool or list of strings. if True will use default
  460. names, if false will use no names, if a list is given, the name of
  461. each level in the index will be taken from the list.
  462. ndupe_l - (Optional), list of ints, the number of rows for which the
  463. label will repeated at the corresponding level, you can specify just
  464. the first few, the rest will use the default ndupe_l of 1.
  465. len(ndupe_l) <= nlevels.
  466. idx_type - "i"/"f"/"s"/"dt"/"p"/"td".
  467. If idx_type is not None, `idx_nlevels` must be 1.
  468. "i"/"f" creates an integer/float index,
  469. "s" creates a string
  470. "dt" create a datetime index.
  471. "td" create a datetime index.
  472. if unspecified, string labels will be generated.
  473. """
  474. if ndupe_l is None:
  475. ndupe_l = [1] * nlevels
  476. assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels
  477. assert names is None or names is False or names is True or len(names) is nlevels
  478. assert idx_type is None or (
  479. idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1
  480. )
  481. if names is True:
  482. # build default names
  483. names = [prefix + str(i) for i in range(nlevels)]
  484. if names is False:
  485. # pass None to index constructor for no name
  486. names = None
  487. # make singleton case uniform
  488. if isinstance(names, str) and nlevels == 1:
  489. names = [names]
  490. # specific 1D index type requested?
  491. idx_func_dict: dict[str, Callable[..., Index]] = {
  492. "i": makeIntIndex,
  493. "f": makeFloatIndex,
  494. "s": makeStringIndex,
  495. "dt": makeDateIndex,
  496. "td": makeTimedeltaIndex,
  497. "p": makePeriodIndex,
  498. }
  499. idx_func = idx_func_dict.get(idx_type)
  500. if idx_func:
  501. idx = idx_func(nentries)
  502. # but we need to fill in the name
  503. if names:
  504. idx.name = names[0]
  505. return idx
  506. elif idx_type is not None:
  507. raise ValueError(
  508. f"{repr(idx_type)} is not a legal value for `idx_type`, "
  509. "use 'i'/'f'/'s'/'dt'/'p'/'td'."
  510. )
  511. if len(ndupe_l) < nlevels:
  512. ndupe_l.extend([1] * (nlevels - len(ndupe_l)))
  513. assert len(ndupe_l) == nlevels
  514. assert all(x > 0 for x in ndupe_l)
  515. list_of_lists = []
  516. for i in range(nlevels):
  517. def keyfunc(x):
  518. numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_")
  519. return [int(num) for num in numeric_tuple]
  520. # build a list of lists to create the index from
  521. div_factor = nentries // ndupe_l[i] + 1
  522. # Deprecated since version 3.9: collections.Counter now supports []. See PEP 585
  523. # and Generic Alias Type.
  524. cnt: Counter[str] = collections.Counter()
  525. for j in range(div_factor):
  526. label = f"{prefix}_l{i}_g{j}"
  527. cnt[label] = ndupe_l[i]
  528. # cute Counter trick
  529. result = sorted(cnt.elements(), key=keyfunc)[:nentries]
  530. list_of_lists.append(result)
  531. tuples = list(zip(*list_of_lists))
  532. # convert tuples to index
  533. if nentries == 1:
  534. # we have a single level of tuples, i.e. a regular Index
  535. name = None if names is None else names[0]
  536. index = Index(tuples[0], name=name)
  537. elif nlevels == 1:
  538. name = None if names is None else names[0]
  539. index = Index((x[0] for x in tuples), name=name)
  540. else:
  541. index = MultiIndex.from_tuples(tuples, names=names)
  542. return index
def makeCustomDataframe(
    nrows,
    ncols,
    c_idx_names: bool | list[str] = True,
    r_idx_names: bool | list[str] = True,
    c_idx_nlevels: int = 1,
    r_idx_nlevels: int = 1,
    data_gen_f=None,
    c_ndupe_l=None,
    r_ndupe_l=None,
    dtype=None,
    c_idx_type=None,
    r_idx_type=None,
) -> DataFrame:
    """
    Create a DataFrame using supplied parameters.

    Parameters
    ----------
    nrows, ncols - number of data rows/cols
    c_idx_names, r_idx_names - False/True/list of strings, yields No names ,
        default names or uses the provided names for the levels of the
        corresponding index. You can provide a single string when
        c_idx_nlevels ==1.
    c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex
    r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex
    data_gen_f - a function f(row,col) which return the data value
        at that position, the default generator used yields values of the form
        "RxCy" based on position.
    c_ndupe_l, r_ndupe_l - list of integers, determines the number
        of duplicates for each label at a given level of the corresponding
        index. The default `None` value produces a multiplicity of 1 across
        all levels, i.e. a unique index. Will accept a partial list of length
        N < idx_nlevels, for just the first N levels. If ndupe doesn't divide
        nrows/ncol, the last label might have lower multiplicity.
    dtype - passed to the DataFrame constructor as is, in case you wish to
        have more control in conjunction with a custom `data_gen_f`
    r_idx_type, c_idx_type - "i"/"f"/"s"/"dt"/"p"/"td".
        If idx_type is not None, `idx_nlevels` must be 1.
        "i"/"f" creates an integer/float index,
        "s" creates a string index
        "dt" create a datetime index.
        "td" create a timedelta index.
        if unspecified, string labels will be generated.

    Examples
    --------
    # 5 row, 3 columns, default names on both, single index on both axis
    >> makeCustomDataframe(5,3)

    # make the data a random int between 1 and 100
    >> mkdf(5,3,data_gen_f=lambda r,c:randint(1,100))

    # 2-level multiindex on rows with each label duplicated
    # twice on first level, default names on both axis, single
    # index on both axis
    >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2])

    # DatetimeIndex on row, index with unicode labels on columns
    # no names on either axis
    # NOTE(review): c_idx_type="u" is not accepted by the assertions
    # below — confirm whether this example is outdated.
    >> a=makeCustomDataframe(5,3,c_idx_names=False,r_idx_names=False,
                             r_idx_type="dt",c_idx_type="u")

    # 4-level multindex on rows with names provided, 2-level multindex
    # on columns with default labels and default names.
    >> a=makeCustomDataframe(5,3,r_idx_nlevels=4,
                             r_idx_names=["FEE","FIH","FOH","FUM"],
                             c_idx_nlevels=2)

    >> a=mkdf(5,3,r_idx_nlevels=2,c_idx_nlevels=4)
    """
    assert c_idx_nlevels > 0
    assert r_idx_nlevels > 0
    assert r_idx_type is None or (
        r_idx_type in ("i", "f", "s", "dt", "p", "td") and r_idx_nlevels == 1
    )
    assert c_idx_type is None or (
        c_idx_type in ("i", "f", "s", "dt", "p", "td") and c_idx_nlevels == 1
    )

    # delegate index construction to makeCustomIndex for both axes
    columns = makeCustomIndex(
        ncols,
        nlevels=c_idx_nlevels,
        prefix="C",
        names=c_idx_names,
        ndupe_l=c_ndupe_l,
        idx_type=c_idx_type,
    )
    index = makeCustomIndex(
        nrows,
        nlevels=r_idx_nlevels,
        prefix="R",
        names=r_idx_names,
        ndupe_l=r_ndupe_l,
        idx_type=r_idx_type,
    )

    # by default, generate data based on location
    if data_gen_f is None:
        data_gen_f = lambda r, c: f"R{r}C{c}"

    data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)]

    return DataFrame(data, index, columns, dtype=dtype)
  636. def _create_missing_idx(nrows, ncols, density: float, random_state=None):
  637. if random_state is None:
  638. random_state = np.random
  639. else:
  640. random_state = np.random.RandomState(random_state)
  641. # below is cribbed from scipy.sparse
  642. size = round((1 - density) * nrows * ncols)
  643. # generate a few more to ensure unique values
  644. min_rows = 5
  645. fac = 1.02
  646. extra_size = min(size + min_rows, fac * size)
  647. def _gen_unique_rand(rng, _extra_size):
  648. ind = rng.rand(int(_extra_size))
  649. return np.unique(np.floor(ind * nrows * ncols))[:size]
  650. ind = _gen_unique_rand(random_state, extra_size)
  651. while ind.size < size:
  652. extra_size *= 1.05
  653. ind = _gen_unique_rand(random_state, extra_size)
  654. j = np.floor(ind * 1.0 / nrows).astype(int)
  655. i = (ind - j * nrows).astype(int)
  656. return i.tolist(), j.tolist()
def makeMissingDataframe(density: float = 0.9, random_state=None) -> DataFrame:
    """Return a random DataFrame with NaNs injected at random positions.

    NOTE(review): ``.iloc[i, j]`` with two lists assigns the i x j
    cross-product, so the actual missing fraction can exceed
    (1 - density) — confirm this is intended.
    """
    df = makeDataFrame()
    i, j = _create_missing_idx(*df.shape, density=density, random_state=random_state)
    df.iloc[i, j] = np.nan
    return df
class SubclassedSeries(Series):
    """Series subclass used to test that subclasses survive pandas operations."""

    # attributes propagated through __finalize__ in pandas operations
    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, those properties return a generic callable, and not
        # the actual class. In this case that is equivalent, but it is to
        # ensure we don't rely on the property returning a class
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)

    @property
    def _constructor_expanddim(self):
        # DataFrame-producing operations should yield the subclassed frame
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
class SubclassedDataFrame(DataFrame):
    """DataFrame subclass used to test that subclasses survive pandas operations."""

    # attribute propagated through __finalize__ in pandas operations
    _metadata = ["testattr"]

    @property
    def _constructor(self):
        # a generic callable rather than the class itself, so tests don't
        # rely on _constructor returning a class (GH#46018, GH#32638)
        return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)

    @property
    def _constructor_sliced(self):
        # Series-producing operations should yield the subclassed series
        return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
class SubclassedCategorical(Categorical):
    """Categorical subclass used to test that subclasses survive pandas operations."""

    @property
    def _constructor(self):
        return SubclassedCategorical
  687. def _make_skipna_wrapper(alternative, skipna_alternative=None):
  688. """
  689. Create a function for calling on an array.
  690. Parameters
  691. ----------
  692. alternative : function
  693. The function to be called on the array with no NaNs.
  694. Only used when 'skipna_alternative' is None.
  695. skipna_alternative : function
  696. The function to be called on the original array
  697. Returns
  698. -------
  699. function
  700. """
  701. if skipna_alternative:
  702. def skipna_wrapper(x):
  703. return skipna_alternative(x.values)
  704. else:
  705. def skipna_wrapper(x):
  706. nona = x.dropna()
  707. if len(nona) == 0:
  708. return np.nan
  709. return alternative(nona)
  710. return skipna_wrapper
  711. def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
  712. """
  713. Convert list of CSV rows to single CSV-formatted string for current OS.
  714. This method is used for creating expected value of to_csv() method.
  715. Parameters
  716. ----------
  717. rows_list : List[str]
  718. Each element represents the row of csv.
  719. Returns
  720. -------
  721. str
  722. Expected output of to_csv() in current OS.
  723. """
  724. sep = os.linesep
  725. return sep.join(rows_list) + sep
  726. def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
  727. """
  728. Helper function to mark pytest.raises that have an external error message.
  729. Parameters
  730. ----------
  731. expected_exception : Exception
  732. Expected error to raise.
  733. Returns
  734. -------
  735. Callable
  736. Regular `pytest.raises` function with `match` equal to `None`.
  737. """
  738. import pytest
  739. return pytest.raises(expected_exception, match=None)
# (callable, name) pairs from pandas' internal cython table — presumably
# mapping e.g. np.sum to the string "sum"; verify against pd.core.common.
cython_table = pd.core.common._cython_table.items()


def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Combine frame, functions from com._cython_table
    keys and expected result.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is a name of a NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value.

    Returns
    -------
    list
        List of three items (DataFrame, function, expected result)
    """
    results = []
    for func_name, expected in func_names_and_expected:
        results.append((ndframe, func_name, expected))
        # also emit every callable whose cython-table name equals func_name
        results += [
            (ndframe, func, expected)
            for func, name in cython_table
            if name == func_name
        ]
    return results
  765. def get_op_from_name(op_name: str) -> Callable:
  766. """
  767. The operator function for a given op name.
  768. Parameters
  769. ----------
  770. op_name : str
  771. The op name, in form of "add" or "__add__".
  772. Returns
  773. -------
  774. function
  775. A function performing the operation.
  776. """
  777. short_opname = op_name.strip("_")
  778. try:
  779. op = getattr(operator, short_opname)
  780. except AttributeError:
  781. # Assume it is the reverse operator
  782. rop = getattr(operator, short_opname[1:])
  783. op = lambda x, y: rop(y, x)
  784. return op
  785. # -----------------------------------------------------------------------------
  786. # Indexing test helpers
  787. def getitem(x):
  788. return x
  789. def setitem(x):
  790. return x
  791. def loc(x):
  792. return x.loc
  793. def iloc(x):
  794. return x.iloc
  795. def at(x):
  796. return x.at
  797. def iat(x):
  798. return x.iat
  799. # -----------------------------------------------------------------------------
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    Recursively unwraps pandas containers (Index/Series/ExtensionArray/
    single-block DataFrame) down to underlying buffers and compares those.

    Parameters
    ----------
    left, right : array-like
        numpy arrays or pandas objects backed by them.

    Returns
    -------
    bool

    Raises
    ------
    NotImplementedError
        If ``left`` is a type this helper does not know how to unwrap.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        # Base case: two plain ndarrays.
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        # RangeIndex is computed lazily; it owns no sharable buffer.
        return False
    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)
    if isinstance(left, (Index, Series)):
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)
    if isinstance(left, pd.core.arrays.IntervalArray):
        # Interval data lives in two backing arrays; sharing either counts.
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if isinstance(left, ExtensionArray) and left.dtype == "string[pyarrow]":
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        left = cast("ArrowExtensionArray", left)
        if isinstance(right, ExtensionArray) and right.dtype == "string[pyarrow]":
            right = cast("ArrowExtensionArray", right)
            left_pa_data = left._data
            right_pa_data = right._data
            # Compare the data buffer (index 1) of the first chunk only.
            # NOTE(review): this appears to assume single-chunk arrays, and
            # relies on pyarrow Buffer `==` semantics — confirm both hold.
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1 == right_buf1
        # If only `left` is pyarrow-backed we fall through; typically this
        # ends at the NotImplementedError below.

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        # Single-block DataFrame: delegate to its lone backing array.
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))
# Explicit public API of this testing module; keep sorted case-insensitively
# when adding entries.
__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "all_timeseries_index_generator",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_is_valid_plot_return_object",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "at",
    "BOOL_DTYPES",
    "box_expected",
    "BYTES_DTYPES",
    "can_set_locale",
    "close",
    "COMPLEX_DTYPES",
    "convert_rows_list_to_csv_str",
    "DATETIME64_DTYPES",
    "decompress_file",
    "EMPTY_STRING_PATTERN",
    "ENDIAN",
    "ensure_clean",
    "ensure_safe_environment_variables",
    "equalContents",
    "external_error_raised",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "getCols",
    "get_cython_table_params",
    "get_dtype",
    "getitem",
    "get_locales",
    "getMixedTypeDict",
    "get_obj",
    "get_op_from_name",
    "getPeriodData",
    "getSeriesData",
    "getTimeSeriesData",
    "iat",
    "iloc",
    "index_subclass_makers_generator",
    "loc",
    "makeBoolIndex",
    "makeCategoricalIndex",
    "makeCustomDataframe",
    "makeCustomIndex",
    "makeDataFrame",
    "makeDateIndex",
    "makeFloatIndex",
    "makeFloatSeries",
    "makeIntervalIndex",
    "makeIntIndex",
    "makeMissingDataframe",
    "makeMixedDataFrame",
    "makeMultiIndex",
    "makeNumericIndex",
    "makeObjectSeries",
    "makePeriodFrame",
    "makePeriodIndex",
    "makePeriodSeries",
    "make_rand_series",
    "makeRangeIndex",
    "makeStringIndex",
    "makeStringSeries",
    "makeTimeDataFrame",
    "makeTimedeltaIndex",
    "makeTimeSeries",
    "makeUIntIndex",
    "maybe_produces_warning",
    "NARROW_NP_DTYPES",
    "network",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "raise_assert_detail",
    "rands",
    "reset_display_options",
    "raises_chained_assignment_error",
    "round_trip_localpath",
    "round_trip_pathlib",
    "round_trip_pickle",
    "setitem",
    "set_locale",
    "set_timezone",
    "shares_memory",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "SubclassedCategorical",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "TIMEDELTA64_DTYPES",
    "to_array",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "use_numexpr",
    "with_csv_dialect",
    "write_to_compressed",
]