algorithms.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. """Algorithms to support fitting routines in seaborn plotting functions."""
  2. import numpy as np
  3. import warnings
  4. def bootstrap(*args, **kwargs):
  5. """Resample one or more arrays with replacement and store aggregate values.
  6. Positional arguments are a sequence of arrays to bootstrap along the first
  7. axis and pass to a summary function.
  8. Keyword arguments:
  9. n_boot : int, default=10000
  10. Number of iterations
  11. axis : int, default=None
  12. Will pass axis to ``func`` as a keyword argument.
  13. units : array, default=None
  14. Array of sampling unit IDs. When used the bootstrap resamples units
  15. and then observations within units instead of individual
  16. datapoints.
  17. func : string or callable, default="mean"
  18. Function to call on the args that are passed in. If string, uses as
  19. name of function in the numpy namespace. If nans are present in the
  20. data, will try to use nan-aware version of named function.
  21. seed : Generator | SeedSequence | RandomState | int | None
  22. Seed for the random number generator; useful if you want
  23. reproducible resamples.
  24. Returns
  25. -------
  26. boot_dist: array
  27. array of bootstrapped statistic values
  28. """
  29. # Ensure list of arrays are same length
  30. if len(np.unique(list(map(len, args)))) > 1:
  31. raise ValueError("All input arrays must have the same length")
  32. n = len(args[0])
  33. # Default keyword arguments
  34. n_boot = kwargs.get("n_boot", 10000)
  35. func = kwargs.get("func", "mean")
  36. axis = kwargs.get("axis", None)
  37. units = kwargs.get("units", None)
  38. random_seed = kwargs.get("random_seed", None)
  39. if random_seed is not None:
  40. msg = "`random_seed` has been renamed to `seed` and will be removed"
  41. warnings.warn(msg)
  42. seed = kwargs.get("seed", random_seed)
  43. if axis is None:
  44. func_kwargs = dict()
  45. else:
  46. func_kwargs = dict(axis=axis)
  47. # Initialize the resampler
  48. if isinstance(seed, np.random.RandomState):
  49. rng = seed
  50. else:
  51. rng = np.random.default_rng(seed)
  52. # Coerce to arrays
  53. args = list(map(np.asarray, args))
  54. if units is not None:
  55. units = np.asarray(units)
  56. if isinstance(func, str):
  57. # Allow named numpy functions
  58. f = getattr(np, func)
  59. # Try to use nan-aware version of function if necessary
  60. missing_data = np.isnan(np.sum(np.column_stack(args)))
  61. if missing_data and not func.startswith("nan"):
  62. nanf = getattr(np, f"nan{func}", None)
  63. if nanf is None:
  64. msg = f"Data contain nans but no nan-aware version of `{func}` found"
  65. warnings.warn(msg, UserWarning)
  66. else:
  67. f = nanf
  68. else:
  69. f = func
  70. # Handle numpy changes
  71. try:
  72. integers = rng.integers
  73. except AttributeError:
  74. integers = rng.randint
  75. # Do the bootstrap
  76. if units is not None:
  77. return _structured_bootstrap(args, n_boot, units, f,
  78. func_kwargs, integers)
  79. boot_dist = []
  80. for i in range(int(n_boot)):
  81. resampler = integers(0, n, n, dtype=np.intp) # intp is indexing dtype
  82. sample = [a.take(resampler, axis=0) for a in args]
  83. boot_dist.append(f(*sample, **func_kwargs))
  84. return np.array(boot_dist)
  85. def _structured_bootstrap(args, n_boot, units, func, func_kwargs, integers):
  86. """Resample units instead of datapoints."""
  87. unique_units = np.unique(units)
  88. n_units = len(unique_units)
  89. args = [[a[units == unit] for unit in unique_units] for a in args]
  90. boot_dist = []
  91. for i in range(int(n_boot)):
  92. resampler = integers(0, n_units, n_units, dtype=np.intp)
  93. sample = [[a[i] for i in resampler] for a in args]
  94. lengths = map(len, sample[0])
  95. resampler = [integers(0, n, n, dtype=np.intp) for n in lengths]
  96. sample = [[c.take(r, axis=0) for c, r in zip(a, resampler)] for a in sample]
  97. sample = list(map(np.concatenate, sample))
  98. boot_dist.append(func(*sample, **func_kwargs))
  99. return np.array(boot_dist)