_creation_functions.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, List, Optional, Tuple, Union
  3. if TYPE_CHECKING:
  4. from ._typing import (
  5. Array,
  6. Device,
  7. Dtype,
  8. NestedSequence,
  9. SupportsBufferProtocol,
  10. )
  11. from collections.abc import Sequence
  12. from ._dtypes import _all_dtypes
  13. import numpy as np
  14. def _check_valid_dtype(dtype):
  15. # Note: Only spelling dtypes as the dtype objects is supported.
  16. # We use this instead of "dtype in _all_dtypes" because the dtype objects
  17. # define equality with the sorts of things we want to disallow.
  18. for d in (None,) + _all_dtypes:
  19. if dtype is d:
  20. return
  21. raise ValueError("dtype must be one of the supported dtypes")
  22. def asarray(
  23. obj: Union[
  24. Array,
  25. bool,
  26. int,
  27. float,
  28. NestedSequence[bool | int | float],
  29. SupportsBufferProtocol,
  30. ],
  31. /,
  32. *,
  33. dtype: Optional[Dtype] = None,
  34. device: Optional[Device] = None,
  35. copy: Optional[Union[bool, np._CopyMode]] = None,
  36. ) -> Array:
  37. """
  38. Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
  39. See its docstring for more information.
  40. """
  41. # _array_object imports in this file are inside the functions to avoid
  42. # circular imports
  43. from ._array_object import Array
  44. _check_valid_dtype(dtype)
  45. if device not in ["cpu", None]:
  46. raise ValueError(f"Unsupported device {device!r}")
  47. if copy in (False, np._CopyMode.IF_NEEDED):
  48. # Note: copy=False is not yet implemented in np.asarray
  49. raise NotImplementedError("copy=False is not yet implemented")
  50. if isinstance(obj, Array):
  51. if dtype is not None and obj.dtype != dtype:
  52. copy = True
  53. if copy in (True, np._CopyMode.ALWAYS):
  54. return Array._new(np.array(obj._array, copy=True, dtype=dtype))
  55. return obj
  56. if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
  57. # Give a better error message in this case. NumPy would convert this
  58. # to an object array. TODO: This won't handle large integers in lists.
  59. raise OverflowError("Integer out of bounds for array dtypes")
  60. res = np.asarray(obj, dtype=dtype)
  61. return Array._new(res)
  62. def arange(
  63. start: Union[int, float],
  64. /,
  65. stop: Optional[Union[int, float]] = None,
  66. step: Union[int, float] = 1,
  67. *,
  68. dtype: Optional[Dtype] = None,
  69. device: Optional[Device] = None,
  70. ) -> Array:
  71. """
  72. Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
  73. See its docstring for more information.
  74. """
  75. from ._array_object import Array
  76. _check_valid_dtype(dtype)
  77. if device not in ["cpu", None]:
  78. raise ValueError(f"Unsupported device {device!r}")
  79. return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
  80. def empty(
  81. shape: Union[int, Tuple[int, ...]],
  82. *,
  83. dtype: Optional[Dtype] = None,
  84. device: Optional[Device] = None,
  85. ) -> Array:
  86. """
  87. Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
  88. See its docstring for more information.
  89. """
  90. from ._array_object import Array
  91. _check_valid_dtype(dtype)
  92. if device not in ["cpu", None]:
  93. raise ValueError(f"Unsupported device {device!r}")
  94. return Array._new(np.empty(shape, dtype=dtype))
  95. def empty_like(
  96. x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
  97. ) -> Array:
  98. """
  99. Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
  100. See its docstring for more information.
  101. """
  102. from ._array_object import Array
  103. _check_valid_dtype(dtype)
  104. if device not in ["cpu", None]:
  105. raise ValueError(f"Unsupported device {device!r}")
  106. return Array._new(np.empty_like(x._array, dtype=dtype))
  107. def eye(
  108. n_rows: int,
  109. n_cols: Optional[int] = None,
  110. /,
  111. *,
  112. k: int = 0,
  113. dtype: Optional[Dtype] = None,
  114. device: Optional[Device] = None,
  115. ) -> Array:
  116. """
  117. Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
  118. See its docstring for more information.
  119. """
  120. from ._array_object import Array
  121. _check_valid_dtype(dtype)
  122. if device not in ["cpu", None]:
  123. raise ValueError(f"Unsupported device {device!r}")
  124. return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
  125. def from_dlpack(x: object, /) -> Array:
  126. from ._array_object import Array
  127. return Array._new(np.from_dlpack(x))
  128. def full(
  129. shape: Union[int, Tuple[int, ...]],
  130. fill_value: Union[int, float],
  131. *,
  132. dtype: Optional[Dtype] = None,
  133. device: Optional[Device] = None,
  134. ) -> Array:
  135. """
  136. Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
  137. See its docstring for more information.
  138. """
  139. from ._array_object import Array
  140. _check_valid_dtype(dtype)
  141. if device not in ["cpu", None]:
  142. raise ValueError(f"Unsupported device {device!r}")
  143. if isinstance(fill_value, Array) and fill_value.ndim == 0:
  144. fill_value = fill_value._array
  145. res = np.full(shape, fill_value, dtype=dtype)
  146. if res.dtype not in _all_dtypes:
  147. # This will happen if the fill value is not something that NumPy
  148. # coerces to one of the acceptable dtypes.
  149. raise TypeError("Invalid input to full")
  150. return Array._new(res)
  151. def full_like(
  152. x: Array,
  153. /,
  154. fill_value: Union[int, float],
  155. *,
  156. dtype: Optional[Dtype] = None,
  157. device: Optional[Device] = None,
  158. ) -> Array:
  159. """
  160. Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
  161. See its docstring for more information.
  162. """
  163. from ._array_object import Array
  164. _check_valid_dtype(dtype)
  165. if device not in ["cpu", None]:
  166. raise ValueError(f"Unsupported device {device!r}")
  167. res = np.full_like(x._array, fill_value, dtype=dtype)
  168. if res.dtype not in _all_dtypes:
  169. # This will happen if the fill value is not something that NumPy
  170. # coerces to one of the acceptable dtypes.
  171. raise TypeError("Invalid input to full_like")
  172. return Array._new(res)
  173. def linspace(
  174. start: Union[int, float],
  175. stop: Union[int, float],
  176. /,
  177. num: int,
  178. *,
  179. dtype: Optional[Dtype] = None,
  180. device: Optional[Device] = None,
  181. endpoint: bool = True,
  182. ) -> Array:
  183. """
  184. Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
  185. See its docstring for more information.
  186. """
  187. from ._array_object import Array
  188. _check_valid_dtype(dtype)
  189. if device not in ["cpu", None]:
  190. raise ValueError(f"Unsupported device {device!r}")
  191. return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
  192. def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
  193. """
  194. Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
  195. See its docstring for more information.
  196. """
  197. from ._array_object import Array
  198. # Note: unlike np.meshgrid, only inputs with all the same dtype are
  199. # allowed
  200. if len({a.dtype for a in arrays}) > 1:
  201. raise ValueError("meshgrid inputs must all have the same dtype")
  202. return [
  203. Array._new(array)
  204. for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
  205. ]
  206. def ones(
  207. shape: Union[int, Tuple[int, ...]],
  208. *,
  209. dtype: Optional[Dtype] = None,
  210. device: Optional[Device] = None,
  211. ) -> Array:
  212. """
  213. Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
  214. See its docstring for more information.
  215. """
  216. from ._array_object import Array
  217. _check_valid_dtype(dtype)
  218. if device not in ["cpu", None]:
  219. raise ValueError(f"Unsupported device {device!r}")
  220. return Array._new(np.ones(shape, dtype=dtype))
  221. def ones_like(
  222. x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
  223. ) -> Array:
  224. """
  225. Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
  226. See its docstring for more information.
  227. """
  228. from ._array_object import Array
  229. _check_valid_dtype(dtype)
  230. if device not in ["cpu", None]:
  231. raise ValueError(f"Unsupported device {device!r}")
  232. return Array._new(np.ones_like(x._array, dtype=dtype))
  233. def tril(x: Array, /, *, k: int = 0) -> Array:
  234. """
  235. Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
  236. See its docstring for more information.
  237. """
  238. from ._array_object import Array
  239. if x.ndim < 2:
  240. # Note: Unlike np.tril, x must be at least 2-D
  241. raise ValueError("x must be at least 2-dimensional for tril")
  242. return Array._new(np.tril(x._array, k=k))
  243. def triu(x: Array, /, *, k: int = 0) -> Array:
  244. """
  245. Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
  246. See its docstring for more information.
  247. """
  248. from ._array_object import Array
  249. if x.ndim < 2:
  250. # Note: Unlike np.triu, x must be at least 2-D
  251. raise ValueError("x must be at least 2-dimensional for triu")
  252. return Array._new(np.triu(x._array, k=k))
  253. def zeros(
  254. shape: Union[int, Tuple[int, ...]],
  255. *,
  256. dtype: Optional[Dtype] = None,
  257. device: Optional[Device] = None,
  258. ) -> Array:
  259. """
  260. Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
  261. See its docstring for more information.
  262. """
  263. from ._array_object import Array
  264. _check_valid_dtype(dtype)
  265. if device not in ["cpu", None]:
  266. raise ValueError(f"Unsupported device {device!r}")
  267. return Array._new(np.zeros(shape, dtype=dtype))
  268. def zeros_like(
  269. x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
  270. ) -> Array:
  271. """
  272. Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
  273. See its docstring for more information.
  274. """
  275. from ._array_object import Array
  276. _check_valid_dtype(dtype)
  277. if device not in ["cpu", None]:
  278. raise ValueError(f"Unsupported device {device!r}")
  279. return Array._new(np.zeros_like(x._array, dtype=dtype))