regression.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. import numpy as np
  4. import pandas as pd
  5. from seaborn._stats.base import Stat
  6. @dataclass
  7. class PolyFit(Stat):
  8. """
  9. Fit a polynomial of the given order and resample data onto predicted curve.
  10. """
  11. # This is a provisional class that is useful for building out functionality.
  12. # It may or may not change substantially in form or dissappear as we think
  13. # through the organization of the stats subpackage.
  14. order: int = 2
  15. gridsize: int = 100
  16. def _fit_predict(self, data):
  17. x = data["x"]
  18. y = data["y"]
  19. if x.nunique() <= self.order:
  20. # TODO warn?
  21. xx = yy = []
  22. else:
  23. p = np.polyfit(x, y, self.order)
  24. xx = np.linspace(x.min(), x.max(), self.gridsize)
  25. yy = np.polyval(p, xx)
  26. return pd.DataFrame(dict(x=xx, y=yy))
  27. # TODO we should have a way of identifying the method that will be applied
  28. # and then only define __call__ on a base-class of stats with this pattern
  29. def __call__(self, data, groupby, orient, scales):
  30. return (
  31. groupby
  32. .apply(data.dropna(subset=["x", "y"]), self._fit_predict)
  33. )
  34. @dataclass
  35. class OLSFit(Stat):
  36. ...