sample.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. """
  2. Module containing utilities for NDFrame.sample() and .GroupBy.sample()
  3. """
  4. from __future__ import annotations
  5. from typing import TYPE_CHECKING
  6. import numpy as np
  7. from pandas._libs import lib
  8. from pandas._typing import AxisInt
  9. from pandas.core.dtypes.generic import (
  10. ABCDataFrame,
  11. ABCSeries,
  12. )
  13. if TYPE_CHECKING:
  14. from pandas.core.generic import NDFrame
  15. def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray:
  16. """
  17. Process and validate the `weights` argument to `NDFrame.sample` and
  18. `.GroupBy.sample`.
  19. Returns `weights` as an ndarray[np.float64], validated except for normalizing
  20. weights (because that must be done groupwise in groupby sampling).
  21. """
  22. # If a series, align with frame
  23. if isinstance(weights, ABCSeries):
  24. weights = weights.reindex(obj.axes[axis])
  25. # Strings acceptable if a dataframe and axis = 0
  26. if isinstance(weights, str):
  27. if isinstance(obj, ABCDataFrame):
  28. if axis == 0:
  29. try:
  30. weights = obj[weights]
  31. except KeyError as err:
  32. raise KeyError(
  33. "String passed to weights not a valid column"
  34. ) from err
  35. else:
  36. raise ValueError(
  37. "Strings can only be passed to "
  38. "weights when sampling from rows on "
  39. "a DataFrame"
  40. )
  41. else:
  42. raise ValueError(
  43. "Strings cannot be passed as weights when sampling from a Series."
  44. )
  45. if isinstance(obj, ABCSeries):
  46. func = obj._constructor
  47. else:
  48. func = obj._constructor_sliced
  49. weights = func(weights, dtype="float64")._values
  50. if len(weights) != obj.shape[axis]:
  51. raise ValueError("Weights and axis to be sampled must be of same length")
  52. if lib.has_infs(weights):
  53. raise ValueError("weight vector may not include `inf` values")
  54. if (weights < 0).any():
  55. raise ValueError("weight vector many not include negative values")
  56. missing = np.isnan(weights)
  57. if missing.any():
  58. # Don't modify weights in place
  59. weights = weights.copy()
  60. weights[missing] = 0
  61. return weights
  62. def process_sampling_size(
  63. n: int | None, frac: float | None, replace: bool
  64. ) -> int | None:
  65. """
  66. Process and validate the `n` and `frac` arguments to `NDFrame.sample` and
  67. `.GroupBy.sample`.
  68. Returns None if `frac` should be used (variable sampling sizes), otherwise returns
  69. the constant sampling size.
  70. """
  71. # If no frac or n, default to n=1.
  72. if n is None and frac is None:
  73. n = 1
  74. elif n is not None and frac is not None:
  75. raise ValueError("Please enter a value for `frac` OR `n`, not both")
  76. elif n is not None:
  77. if n < 0:
  78. raise ValueError(
  79. "A negative number of rows requested. Please provide `n` >= 0."
  80. )
  81. if n % 1 != 0:
  82. raise ValueError("Only integers accepted as `n` values")
  83. else:
  84. assert frac is not None # for mypy
  85. if frac > 1 and not replace:
  86. raise ValueError(
  87. "Replace has to be set to `True` when "
  88. "upsampling the population `frac` > 1."
  89. )
  90. if frac < 0:
  91. raise ValueError(
  92. "A negative number of rows requested. Please provide `frac` >= 0."
  93. )
  94. return n
  95. def sample(
  96. obj_len: int,
  97. size: int,
  98. replace: bool,
  99. weights: np.ndarray | None,
  100. random_state: np.random.RandomState | np.random.Generator,
  101. ) -> np.ndarray:
  102. """
  103. Randomly sample `size` indices in `np.arange(obj_len)`
  104. Parameters
  105. ----------
  106. obj_len : int
  107. The length of the indices being considered
  108. size : int
  109. The number of values to choose
  110. replace : bool
  111. Allow or disallow sampling of the same row more than once.
  112. weights : np.ndarray[np.float64] or None
  113. If None, equal probability weighting, otherwise weights according
  114. to the vector normalized
  115. random_state: np.random.RandomState or np.random.Generator
  116. State used for the random sampling
  117. Returns
  118. -------
  119. np.ndarray[np.intp]
  120. """
  121. if weights is not None:
  122. weight_sum = weights.sum()
  123. if weight_sum != 0:
  124. weights = weights / weight_sum
  125. else:
  126. raise ValueError("Invalid weights: weights sum to zero")
  127. return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
  128. np.intp, copy=False
  129. )