api.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. from __future__ import annotations
  2. import textwrap
  3. from typing import cast
  4. import numpy as np
  5. from pandas._libs import (
  6. NaT,
  7. lib,
  8. )
  9. from pandas._typing import Axis
  10. from pandas.errors import InvalidIndexError
  11. from pandas.core.dtypes.cast import find_common_type
  12. from pandas.core.algorithms import safe_sort
  13. from pandas.core.indexes.base import (
  14. Index,
  15. _new_Index,
  16. ensure_index,
  17. ensure_index_from_sequences,
  18. get_unanimous_names,
  19. )
  20. from pandas.core.indexes.category import CategoricalIndex
  21. from pandas.core.indexes.datetimes import DatetimeIndex
  22. from pandas.core.indexes.interval import IntervalIndex
  23. from pandas.core.indexes.multi import MultiIndex
  24. from pandas.core.indexes.period import PeriodIndex
  25. from pandas.core.indexes.range import RangeIndex
  26. from pandas.core.indexes.timedeltas import TimedeltaIndex
# Warning text explaining that a non-aligned, non-concatenation axis is being
# sorted and how to opt in or out via the ``sort`` argument. The name is not
# referenced anywhere else in the visible part of this module.
_sort_msg = textwrap.dedent(
    """\
Sorting because non-concatenation axis is not aligned. A future version
of pandas will change to not sort by default.
To accept the future behavior, pass 'sort=False'.
To retain the current behavior and silence the warning, pass 'sort=True'.
"""
)

# Names exported by this module: the Index classes and helpers imported above,
# plus the functions defined further down in this file.
__all__ = [
    "Index",
    "MultiIndex",
    "CategoricalIndex",
    "IntervalIndex",
    "RangeIndex",
    "InvalidIndexError",
    "TimedeltaIndex",
    "PeriodIndex",
    "DatetimeIndex",
    "_new_Index",
    "NaT",
    "ensure_index",
    "ensure_index_from_sequences",
    "get_objs_combined_axis",
    "union_indexes",
    "get_unanimous_names",
    "all_indexes_same",
    "default_index",
    "safe_sort_index",
]
  56. def get_objs_combined_axis(
  57. objs, intersect: bool = False, axis: Axis = 0, sort: bool = True, copy: bool = False
  58. ) -> Index:
  59. """
  60. Extract combined index: return intersection or union (depending on the
  61. value of "intersect") of indexes on given axis, or None if all objects
  62. lack indexes (e.g. they are numpy arrays).
  63. Parameters
  64. ----------
  65. objs : list
  66. Series or DataFrame objects, may be mix of the two.
  67. intersect : bool, default False
  68. If True, calculate the intersection between indexes. Otherwise,
  69. calculate the union.
  70. axis : {0 or 'index', 1 or 'outer'}, default 0
  71. The axis to extract indexes from.
  72. sort : bool, default True
  73. Whether the result index should come out sorted or not.
  74. copy : bool, default False
  75. If True, return a copy of the combined index.
  76. Returns
  77. -------
  78. Index
  79. """
  80. obs_idxes = [obj._get_axis(axis) for obj in objs]
  81. return _get_combined_index(obs_idxes, intersect=intersect, sort=sort, copy=copy)
  82. def _get_distinct_objs(objs: list[Index]) -> list[Index]:
  83. """
  84. Return a list with distinct elements of "objs" (different ids).
  85. Preserves order.
  86. """
  87. ids: set[int] = set()
  88. res = []
  89. for obj in objs:
  90. if id(obj) not in ids:
  91. ids.add(id(obj))
  92. res.append(obj)
  93. return res
  94. def _get_combined_index(
  95. indexes: list[Index],
  96. intersect: bool = False,
  97. sort: bool = False,
  98. copy: bool = False,
  99. ) -> Index:
  100. """
  101. Return the union or intersection of indexes.
  102. Parameters
  103. ----------
  104. indexes : list of Index or list objects
  105. When intersect=True, do not accept list of lists.
  106. intersect : bool, default False
  107. If True, calculate the intersection between indexes. Otherwise,
  108. calculate the union.
  109. sort : bool, default False
  110. Whether the result index should come out sorted or not.
  111. copy : bool, default False
  112. If True, return a copy of the combined index.
  113. Returns
  114. -------
  115. Index
  116. """
  117. # TODO: handle index names!
  118. indexes = _get_distinct_objs(indexes)
  119. if len(indexes) == 0:
  120. index = Index([])
  121. elif len(indexes) == 1:
  122. index = indexes[0]
  123. elif intersect:
  124. index = indexes[0]
  125. for other in indexes[1:]:
  126. index = index.intersection(other)
  127. else:
  128. index = union_indexes(indexes, sort=False)
  129. index = ensure_index(index)
  130. if sort:
  131. index = safe_sort_index(index)
  132. # GH 29879
  133. if copy:
  134. index = index.copy()
  135. return index
  136. def safe_sort_index(index: Index) -> Index:
  137. """
  138. Returns the sorted index
  139. We keep the dtypes and the name attributes.
  140. Parameters
  141. ----------
  142. index : an Index
  143. Returns
  144. -------
  145. Index
  146. """
  147. if index.is_monotonic_increasing:
  148. return index
  149. try:
  150. array_sorted = safe_sort(index)
  151. except TypeError:
  152. pass
  153. else:
  154. if isinstance(array_sorted, Index):
  155. return array_sorted
  156. array_sorted = cast(np.ndarray, array_sorted)
  157. if isinstance(index, MultiIndex):
  158. index = MultiIndex.from_tuples(array_sorted, names=index.names)
  159. else:
  160. index = Index(array_sorted, name=index.name, dtype=index.dtype)
  161. return index
  162. def union_indexes(indexes, sort: bool | None = True) -> Index:
  163. """
  164. Return the union of indexes.
  165. The behavior of sort and names is not consistent.
  166. Parameters
  167. ----------
  168. indexes : list of Index or list objects
  169. sort : bool, default True
  170. Whether the result index should come out sorted or not.
  171. Returns
  172. -------
  173. Index
  174. """
  175. if len(indexes) == 0:
  176. raise AssertionError("Must have at least 1 Index to union")
  177. if len(indexes) == 1:
  178. result = indexes[0]
  179. if isinstance(result, list):
  180. result = Index(sorted(result))
  181. return result
  182. indexes, kind = _sanitize_and_check(indexes)
  183. def _unique_indices(inds, dtype) -> Index:
  184. """
  185. Convert indexes to lists and concatenate them, removing duplicates.
  186. The final dtype is inferred.
  187. Parameters
  188. ----------
  189. inds : list of Index or list objects
  190. dtype : dtype to set for the resulting Index
  191. Returns
  192. -------
  193. Index
  194. """
  195. def conv(i):
  196. if isinstance(i, Index):
  197. i = i.tolist()
  198. return i
  199. return Index(
  200. lib.fast_unique_multiple_list([conv(i) for i in inds], sort=sort),
  201. dtype=dtype,
  202. )
  203. def _find_common_index_dtype(inds):
  204. """
  205. Finds a common type for the indexes to pass through to resulting index.
  206. Parameters
  207. ----------
  208. inds: list of Index or list objects
  209. Returns
  210. -------
  211. The common type or None if no indexes were given
  212. """
  213. dtypes = [idx.dtype for idx in indexes if isinstance(idx, Index)]
  214. if dtypes:
  215. dtype = find_common_type(dtypes)
  216. else:
  217. dtype = None
  218. return dtype
  219. if kind == "special":
  220. result = indexes[0]
  221. dtis = [x for x in indexes if isinstance(x, DatetimeIndex)]
  222. dti_tzs = [x for x in dtis if x.tz is not None]
  223. if len(dti_tzs) not in [0, len(dtis)]:
  224. # TODO: this behavior is not tested (so may not be desired),
  225. # but is kept in order to keep behavior the same when
  226. # deprecating union_many
  227. # test_frame_from_dict_with_mixed_indexes
  228. raise TypeError("Cannot join tz-naive with tz-aware DatetimeIndex")
  229. if len(dtis) == len(indexes):
  230. sort = True
  231. result = indexes[0]
  232. elif len(dtis) > 1:
  233. # If we have mixed timezones, our casting behavior may depend on
  234. # the order of indexes, which we don't want.
  235. sort = False
  236. # TODO: what about Categorical[dt64]?
  237. # test_frame_from_dict_with_mixed_indexes
  238. indexes = [x.astype(object, copy=False) for x in indexes]
  239. result = indexes[0]
  240. for other in indexes[1:]:
  241. result = result.union(other, sort=None if sort else False)
  242. return result
  243. elif kind == "array":
  244. dtype = _find_common_index_dtype(indexes)
  245. index = indexes[0]
  246. if not all(index.equals(other) for other in indexes[1:]):
  247. index = _unique_indices(indexes, dtype)
  248. name = get_unanimous_names(*indexes)[0]
  249. if name != index.name:
  250. index = index.rename(name)
  251. return index
  252. else: # kind='list'
  253. dtype = _find_common_index_dtype(indexes)
  254. return _unique_indices(indexes, dtype)
  255. def _sanitize_and_check(indexes):
  256. """
  257. Verify the type of indexes and convert lists to Index.
  258. Cases:
  259. - [list, list, ...]: Return ([list, list, ...], 'list')
  260. - [list, Index, ...]: Return _sanitize_and_check([Index, Index, ...])
  261. Lists are sorted and converted to Index.
  262. - [Index, Index, ...]: Return ([Index, Index, ...], TYPE)
  263. TYPE = 'special' if at least one special type, 'array' otherwise.
  264. Parameters
  265. ----------
  266. indexes : list of Index or list objects
  267. Returns
  268. -------
  269. sanitized_indexes : list of Index or list objects
  270. type : {'list', 'array', 'special'}
  271. """
  272. kinds = list({type(index) for index in indexes})
  273. if list in kinds:
  274. if len(kinds) > 1:
  275. indexes = [
  276. Index(list(x)) if not isinstance(x, Index) else x for x in indexes
  277. ]
  278. kinds.remove(list)
  279. else:
  280. return indexes, "list"
  281. if len(kinds) > 1 or Index not in kinds:
  282. return indexes, "special"
  283. else:
  284. return indexes, "array"
  285. def all_indexes_same(indexes) -> bool:
  286. """
  287. Determine if all indexes contain the same elements.
  288. Parameters
  289. ----------
  290. indexes : iterable of Index objects
  291. Returns
  292. -------
  293. bool
  294. True if all indexes contain the same elements, False otherwise.
  295. """
  296. itr = iter(indexes)
  297. first = next(itr)
  298. return all(first.equals(index) for index in itr)
  299. def default_index(n: int) -> RangeIndex:
  300. rng = range(0, n)
  301. return RangeIndex._simple_new(rng, name=None)