array.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import datetime as dt
  2. from typing import (
  3. Any,
  4. Optional,
  5. Sequence,
  6. Tuple,
  7. Union,
  8. cast,
  9. )
  10. import numpy as np
  11. from pandas._typing import (
  12. Dtype,
  13. PositionalIndexer,
  14. )
  15. from pandas.core.dtypes.dtypes import register_extension_dtype
  16. from pandas.api.extensions import (
  17. ExtensionArray,
  18. ExtensionDtype,
  19. )
  20. from pandas.api.types import pandas_dtype
  21. @register_extension_dtype
  22. class DateDtype(ExtensionDtype):
  23. @property
  24. def type(self):
  25. return dt.date
  26. @property
  27. def name(self):
  28. return "DateDtype"
  29. @classmethod
  30. def construct_from_string(cls, string: str):
  31. if not isinstance(string, str):
  32. raise TypeError(
  33. f"'construct_from_string' expects a string, got {type(string)}"
  34. )
  35. if string == cls.__name__:
  36. return cls()
  37. else:
  38. raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
  39. @classmethod
  40. def construct_array_type(cls):
  41. return DateArray
  42. @property
  43. def na_value(self):
  44. return dt.date.min
  45. def __repr__(self) -> str:
  46. return self.name
  47. class DateArray(ExtensionArray):
  48. def __init__(
  49. self,
  50. dates: Union[
  51. dt.date,
  52. Sequence[dt.date],
  53. Tuple[np.ndarray, np.ndarray, np.ndarray],
  54. np.ndarray,
  55. ],
  56. ) -> None:
  57. if isinstance(dates, dt.date):
  58. self._year = np.array([dates.year])
  59. self._month = np.array([dates.month])
  60. self._day = np.array([dates.year])
  61. return
  62. ldates = len(dates)
  63. if isinstance(dates, list):
  64. # pre-allocate the arrays since we know the size before hand
  65. self._year = np.zeros(ldates, dtype=np.uint16) # 65535 (0, 9999)
  66. self._month = np.zeros(ldates, dtype=np.uint8) # 255 (1, 31)
  67. self._day = np.zeros(ldates, dtype=np.uint8) # 255 (1, 12)
  68. # populate them
  69. for i, (y, m, d) in enumerate(
  70. map(lambda date: (date.year, date.month, date.day), dates)
  71. ):
  72. self._year[i] = y
  73. self._month[i] = m
  74. self._day[i] = d
  75. elif isinstance(dates, tuple):
  76. # only support triples
  77. if ldates != 3:
  78. raise ValueError("only triples are valid")
  79. # check if all elements have the same type
  80. if any(map(lambda x: not isinstance(x, np.ndarray), dates)):
  81. raise TypeError("invalid type")
  82. ly, lm, ld = (len(cast(np.ndarray, d)) for d in dates)
  83. if not ly == lm == ld:
  84. raise ValueError(
  85. f"tuple members must have the same length: {(ly, lm, ld)}"
  86. )
  87. self._year = dates[0].astype(np.uint16)
  88. self._month = dates[1].astype(np.uint8)
  89. self._day = dates[2].astype(np.uint8)
  90. elif isinstance(dates, np.ndarray) and dates.dtype == "U10":
  91. self._year = np.zeros(ldates, dtype=np.uint16) # 65535 (0, 9999)
  92. self._month = np.zeros(ldates, dtype=np.uint8) # 255 (1, 31)
  93. self._day = np.zeros(ldates, dtype=np.uint8) # 255 (1, 12)
  94. # error: "object_" object is not iterable
  95. obj = np.char.split(dates, sep="-")
  96. for (i,), (y, m, d) in np.ndenumerate(obj): # type: ignore[misc]
  97. self._year[i] = int(y)
  98. self._month[i] = int(m)
  99. self._day[i] = int(d)
  100. else:
  101. raise TypeError(f"{type(dates)} is not supported")
  102. @property
  103. def dtype(self) -> ExtensionDtype:
  104. return DateDtype()
  105. def astype(self, dtype, copy=True):
  106. dtype = pandas_dtype(dtype)
  107. if isinstance(dtype, DateDtype):
  108. data = self.copy() if copy else self
  109. else:
  110. data = self.to_numpy(dtype=dtype, copy=copy, na_value=dt.date.min)
  111. return data
  112. @property
  113. def nbytes(self) -> int:
  114. return self._year.nbytes + self._month.nbytes + self._day.nbytes
  115. def __len__(self) -> int:
  116. return len(self._year) # all 3 arrays are enforced to have the same length
  117. def __getitem__(self, item: PositionalIndexer):
  118. if isinstance(item, int):
  119. return dt.date(self._year[item], self._month[item], self._day[item])
  120. else:
  121. raise NotImplementedError("only ints are supported as indexes")
  122. def __setitem__(self, key: Union[int, slice, np.ndarray], value: Any):
  123. if not isinstance(key, int):
  124. raise NotImplementedError("only ints are supported as indexes")
  125. if not isinstance(value, dt.date):
  126. raise TypeError("you can only set datetime.date types")
  127. self._year[key] = value.year
  128. self._month[key] = value.month
  129. self._day[key] = value.day
  130. def __repr__(self) -> str:
  131. return f"DateArray{list(zip(self._year, self._month, self._day))}"
  132. def copy(self) -> "DateArray":
  133. return DateArray((self._year.copy(), self._month.copy(), self._day.copy()))
  134. def isna(self) -> np.ndarray:
  135. return np.logical_and(
  136. np.logical_and(
  137. self._year == dt.date.min.year, self._month == dt.date.min.month
  138. ),
  139. self._day == dt.date.min.day,
  140. )
  141. @classmethod
  142. def _from_sequence(cls, scalars, *, dtype: Optional[Dtype] = None, copy=False):
  143. if isinstance(scalars, dt.date):
  144. pass
  145. elif isinstance(scalars, DateArray):
  146. pass
  147. elif isinstance(scalars, np.ndarray):
  148. scalars = scalars.astype("U10") # 10 chars for yyyy-mm-dd
  149. return DateArray(scalars)