123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- import datetime as dt
- from typing import (
- Any,
- Optional,
- Sequence,
- Tuple,
- Union,
- cast,
- )
- import numpy as np
- from pandas._typing import (
- Dtype,
- PositionalIndexer,
- )
- from pandas.core.dtypes.dtypes import register_extension_dtype
- from pandas.api.extensions import (
- ExtensionArray,
- ExtensionDtype,
- )
- from pandas.api.types import pandas_dtype
@register_extension_dtype
class DateDtype(ExtensionDtype):
    """Pandas extension dtype whose scalars are ``datetime.date`` objects.

    The dtype is registered with pandas under the alias ``"DateDtype"``.
    ``datetime.date.min`` is reported as its missing-value sentinel.
    """

    @property
    def name(self):
        """String alias used to look this dtype up (e.g. via ``astype``)."""
        return "DateDtype"

    def __repr__(self) -> str:
        return self.name

    @property
    def type(self):
        """Python scalar type stored in arrays of this dtype."""
        return dt.date

    @property
    def na_value(self):
        """Sentinel value that stands in for a missing date."""
        return dt.date.min

    @classmethod
    def construct_array_type(cls):
        """Return the array class paired with this dtype."""
        return DateArray

    @classmethod
    def construct_from_string(cls, string: str):
        """Instantiate the dtype from its exact string alias.

        Raises:
            TypeError: if ``string`` is not a ``str``, or is not exactly the
                class name.
        """
        if not isinstance(string, str):
            raise TypeError(
                f"'construct_from_string' expects a string, got {type(string)}"
            )
        if string != cls.__name__:
            raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
        return cls()
class DateArray(ExtensionArray):
    """Extension array of ``datetime.date`` values stored column-wise.

    Each date is split across three parallel NumPy arrays of equal length:
    ``_year`` (``uint16``), ``_month`` (``uint8``) and ``_day`` (``uint8``).
    ``datetime.date.min`` is the missing-value sentinel (see :meth:`isna`).
    """

    def __init__(
        self,
        dates: Union[
            dt.date,
            Sequence[dt.date],
            Tuple[np.ndarray, np.ndarray, np.ndarray],
            np.ndarray,
        ],
    ) -> None:
        """Build the array from one of several supported input forms.

        Parameters
        ----------
        dates
            * a single ``datetime.date``, giving an array of length 1;
            * a ``list`` of ``datetime.date`` objects;
            * a ``tuple`` of exactly three equal-length ``np.ndarray`` objects
              holding years, months and days;
            * a 1-D ``np.ndarray`` of dtype ``U10`` holding ``yyyy-mm-dd``
              strings.

        Raises
        ------
        ValueError
            If a tuple does not have exactly three members, or if its members
            differ in length.
        TypeError
            If a tuple member is not an ``np.ndarray``, or if ``dates`` is of
            an unsupported type.
        """
        if isinstance(dates, dt.date):
            # A single scalar becomes a length-1 array that uses the same
            # compact dtypes as every other construction path.
            self._year = np.array([dates.year], dtype=np.uint16)
            self._month = np.array([dates.month], dtype=np.uint8)
            self._day = np.array([dates.day], dtype=np.uint8)
        elif isinstance(dates, list):
            # The size is known up front, so pre-allocate and fill in place.
            self._allocate(len(dates))
            for i, date in enumerate(dates):
                self._year[i] = date.year
                self._month[i] = date.month
                self._day[i] = date.day
        elif isinstance(dates, tuple):
            self._init_from_triple(dates)
        elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
            self._allocate(len(dates))
            # Split each "yyyy-mm-dd" string on "-" into its three parts.
            parts = np.char.split(dates, sep="-").tolist()
            for i, (y, m, d) in enumerate(parts):
                self._year[i] = int(y)
                self._month[i] = int(m)
                self._day[i] = int(d)
        else:
            # Dispatch on type before calling len(), so that unsupported
            # inputs always get this message.
            raise TypeError(f"{type(dates)} is not supported")

    def _allocate(self, size: int) -> None:
        """Pre-allocate zeroed year/month/day storage for ``size`` dates."""
        self._year = np.zeros(size, dtype=np.uint16)  # 65535 (0, 9999)
        self._month = np.zeros(size, dtype=np.uint8)  # 255 (1, 12)
        self._day = np.zeros(size, dtype=np.uint8)  # 255 (1, 31)

    def _init_from_triple(self, dates: tuple) -> None:
        """Validate a (years, months, days) tuple of ndarrays and store it."""
        if len(dates) != 3:
            raise ValueError("only triples are valid")
        if not all(isinstance(x, np.ndarray) for x in dates):
            raise TypeError("invalid type")
        ly, lm, ld = (len(x) for x in dates)
        if not ly == lm == ld:
            raise ValueError(
                f"tuple members must have the same length: {(ly, lm, ld)}"
            )
        self._year = dates[0].astype(np.uint16)
        self._month = dates[1].astype(np.uint8)
        self._day = dates[2].astype(np.uint8)

    @property
    def dtype(self) -> ExtensionDtype:
        """The ``DateDtype`` instance describing this array."""
        return DateDtype()

    def astype(self, dtype, copy=True):
        """Cast to ``dtype``.

        A ``DateDtype`` target returns a copy of the array (or ``self`` when
        ``copy`` is false). Any other target is delegated to ``to_numpy``,
        with ``datetime.date.min`` as the missing-value fill.
        """
        dtype = pandas_dtype(dtype)
        if isinstance(dtype, DateDtype):
            data = self.copy() if copy else self
        else:
            data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)
        return data

    @property
    def nbytes(self) -> int:
        """Total bytes used by the three component arrays."""
        return self._year.nbytes + self._month.nbytes + self._day.nbytes

    def __len__(self) -> int:
        return len(self._year)  # all 3 arrays are enforced to have the same length

    def __getitem__(self, item: PositionalIndexer):
        """Return the ``datetime.date`` at integer position ``item``.

        Python ints and NumPy integer scalars are accepted. Any other indexer
        raises ``NotImplementedError``.
        """
        if isinstance(item, (int, np.integer)):
            return dt.date(self._year[item], self._month[item], self._day[item])
        raise NotImplementedError("only ints are supported as indexes")

    def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any):
        """Store the ``datetime.date`` ``value`` at integer position ``key``.

        Raises ``NotImplementedError`` for non-integer keys and ``TypeError``
        for non-date values.
        """
        if not isinstance(key, (int, np.integer)):
            raise NotImplementedError("only ints are supported as indexes")
        if not isinstance(value, dt.date):
            raise TypeError("you can only set datetime.date types")
        self._year[key] = value.year
        self._month[key] = value.month
        self._day[key] = value.day

    def __repr__(self) -> str:
        return f"DateArray{list(zip(self._year, self._month, self._day))}"

    def copy(self) -> "DateArray":
        """Return a new array backed by copies of the component arrays."""
        return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))

    def isna(self) -> np.ndarray:
        """Return a boolean mask that is True where the date is ``date.min``."""
        sentinel = dt.date.min
        return (
            (self._year == sentinel.year)
            & (self._month == sentinel.month)
            & (self._day == sentinel.day)
        )

    @classmethod
    def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
        """Construct a ``DateArray`` from ``scalars``.

        A ``DateArray`` input is returned as-is, or copied when ``copy`` is
        true. An ``np.ndarray`` input is cast to ``"U10"`` strings
        (``yyyy-mm-dd``) first. Everything else goes straight to the
        constructor. ``dtype`` is accepted for interface compatibility but is
        not used.
        """
        if isinstance(scalars, DateArray):
            return scalars.copy() if copy else scalars
        if isinstance(scalars, np.ndarray):
            scalars = scalars.astype("U10")  # 10 chars for yyyy-mm-dd
        return DateArray(scalars)
|