rules.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from __future__ import annotations
  2. import warnings
  3. from collections import UserString
  4. from numbers import Number
  5. from datetime import datetime
  6. import numpy as np
  7. import pandas as pd
  8. from typing import TYPE_CHECKING
  9. if TYPE_CHECKING:
  10. from typing import Literal
  11. from pandas import Series
  12. class VarType(UserString):
  13. """
  14. Prevent comparisons elsewhere in the library from using the wrong name.
  15. Errors are simple assertions because users should not be able to trigger
  16. them. If that changes, they should be more verbose.
  17. """
  18. # TODO VarType is an awfully overloaded name, but so is DataType ...
  19. # TODO adding unknown because we are using this in for scales, is that right?
  20. allowed = "numeric", "datetime", "categorical", "boolean", "unknown"
  21. def __init__(self, data):
  22. assert data in self.allowed, data
  23. super().__init__(data)
  24. def __eq__(self, other):
  25. assert other in self.allowed, other
  26. return self.data == other
  27. def variable_type(
  28. vector: Series,
  29. boolean_type: Literal["numeric", "categorical", "boolean"] = "numeric",
  30. strict_boolean: bool = False,
  31. ) -> VarType:
  32. """
  33. Determine whether a vector contains numeric, categorical, or datetime data.
  34. This function differs from the pandas typing API in a few ways:
  35. - Python sequences or object-typed PyData objects are considered numeric if
  36. all of their entries are numeric.
  37. - String or mixed-type data are considered categorical even if not
  38. explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.
  39. - There is some flexibility about how to treat binary / boolean data.
  40. Parameters
  41. ----------
  42. vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
  43. Input data to test.
  44. boolean_type : 'numeric', 'categorical', or 'boolean'
  45. Type to use for vectors containing only 0s and 1s (and NAs).
  46. strict_boolean : bool
  47. If True, only consider data to be boolean when the dtype is bool or Boolean.
  48. Returns
  49. -------
  50. var_type : 'numeric', 'categorical', or 'datetime'
  51. Name identifying the type of data in the vector.
  52. """
  53. # If a categorical dtype is set, infer categorical
  54. if isinstance(getattr(vector, 'dtype', None), pd.CategoricalDtype):
  55. return VarType("categorical")
  56. # Special-case all-na data, which is always "numeric"
  57. if pd.isna(vector).all():
  58. return VarType("numeric")
  59. # Now drop nulls to simplify further type inference
  60. vector = vector.dropna()
  61. # Special-case binary/boolean data, allow caller to determine
  62. # This triggers a numpy warning when vector has strings/objects
  63. # https://github.com/numpy/numpy/issues/6784
  64. # Because we reduce with .all(), we are agnostic about whether the
  65. # comparison returns a scalar or vector, so we will ignore the warning.
  66. # It triggers a separate DeprecationWarning when the vector has datetimes:
  67. # https://github.com/numpy/numpy/issues/13548
  68. # This is considered a bug by numpy and will likely go away.
  69. with warnings.catch_warnings():
  70. warnings.simplefilter(
  71. action='ignore',
  72. category=(FutureWarning, DeprecationWarning) # type: ignore # mypy bug?
  73. )
  74. if strict_boolean:
  75. if isinstance(vector.dtype, pd.core.dtypes.base.ExtensionDtype):
  76. boolean_dtypes = ["bool", "boolean"]
  77. else:
  78. boolean_dtypes = ["bool"]
  79. boolean_vector = vector.dtype in boolean_dtypes
  80. else:
  81. boolean_vector = bool(np.isin(vector, [0, 1]).all())
  82. if boolean_vector:
  83. return VarType(boolean_type)
  84. # Defer to positive pandas tests
  85. if pd.api.types.is_numeric_dtype(vector):
  86. return VarType("numeric")
  87. if pd.api.types.is_datetime64_dtype(vector):
  88. return VarType("datetime")
  89. # --- If we get to here, we need to check the entries
  90. # Check for a collection where everything is a number
  91. def all_numeric(x):
  92. for x_i in x:
  93. if not isinstance(x_i, Number):
  94. return False
  95. return True
  96. if all_numeric(vector):
  97. return VarType("numeric")
  98. # Check for a collection where everything is a datetime
  99. def all_datetime(x):
  100. for x_i in x:
  101. if not isinstance(x_i, (datetime, np.datetime64)):
  102. return False
  103. return True
  104. if all_datetime(vector):
  105. return VarType("datetime")
  106. # Otherwise, our final fallback is to consider things categorical
  107. return VarType("categorical")
  108. def categorical_order(vector: Series, order: list | None = None) -> list:
  109. """
  110. Return a list of unique data values using seaborn's ordering rules.
  111. Parameters
  112. ----------
  113. vector : Series
  114. Vector of "categorical" values
  115. order : list
  116. Desired order of category levels to override the order determined
  117. from the `data` object.
  118. Returns
  119. -------
  120. order : list
  121. Ordered list of category levels not including null values.
  122. """
  123. if order is not None:
  124. return order
  125. if vector.dtype.name == "category":
  126. order = list(vector.cat.categories)
  127. else:
  128. order = list(filter(pd.notnull, vector.unique()))
  129. if variable_type(pd.Series(order)) == "numeric":
  130. order.sort()
  131. return order