123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
- from __future__ import annotations
- from typing import TYPE_CHECKING, List, Optional, Tuple, Union
- if TYPE_CHECKING:
- from ._typing import (
- Array,
- Device,
- Dtype,
- NestedSequence,
- SupportsBufferProtocol,
- )
- from collections.abc import Sequence
- from ._dtypes import _all_dtypes
- import numpy as np
def _check_valid_dtype(dtype):
    """
    Raise ``ValueError`` unless *dtype* is ``None`` or one of the supported
    dtype objects.

    Only spelling dtypes as the dtype objects themselves is supported, so
    membership is tested by identity (``is``) rather than with ``in``/``==``:
    the dtype objects define equality with the sorts of values that must be
    rejected here.
    """
    allowed = (None,) + _all_dtypes
    if any(dtype is candidate for candidate in allowed):
        return
    raise ValueError("dtype must be one of the supported dtypes")
def asarray(
    obj: Union[
        Array,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: Optional[Union[bool, np._CopyMode]] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.

    See its docstring for more information.

    Parameters
    ----------
    obj
        An existing ``Array``, a Python scalar, a (nested) sequence of
        scalars, or an object supporting the buffer protocol.
    dtype
        Output dtype; must be ``None`` or one of the supported dtype objects.
    device
        Only ``"cpu"`` (or ``None``) is accepted.
    copy
        ``True`` / ``np._CopyMode.ALWAYS`` always produces a fresh copy, for
        every kind of *obj*. ``None`` lets the data be reused: an ``Array``
        input with a matching (or unspecified) dtype is returned as-is.
        ``False`` / ``np._CopyMode.IF_NEEDED`` is not implemented.

    Returns
    -------
    Array
        The converted array (possibly *obj* itself, see *copy*).

    Raises
    ------
    ValueError
        If *dtype* or *device* is not supported.
    NotImplementedError
        If *copy* is ``False`` or ``np._CopyMode.IF_NEEDED``.
    OverflowError
        If *obj* is a Python ``int`` outside ``[-(2**63), 2**64 - 1]`` and no
        *dtype* was given.
    """
    # _array_object imports in this file are inside the functions to avoid
    # circular imports
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device {device!r}")
    if copy in (False, np._CopyMode.IF_NEEDED):
        # Note: copy=False is not yet implemented in np.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    if isinstance(obj, Array):
        if dtype is not None and obj.dtype != dtype:
            # A dtype change always needs new memory.
            copy = True
        if copy in (True, np._CopyMode.ALWAYS):
            return Array._new(np.array(obj._array, copy=True, dtype=dtype))
        return obj
    if dtype is None and isinstance(obj, int) and (obj >= 2 ** 64 or obj < -(2 ** 63)):
        # The representable range is int64 min (-(2**63)) to uint64 max
        # (2**64 - 1), so 2**64 itself is already out of bounds. NumPy would
        # convert such values to an object array; give a better error message
        # instead. TODO: This won't handle large integers in lists.
        raise OverflowError("Integer out of bounds for array dtypes")
    if copy in (True, np._CopyMode.ALWAYS):
        # np.asarray may return memory shared with ndarray / buffer-protocol
        # inputs; an explicit copy was requested, so force one.
        return Array._new(np.array(obj, copy=True, dtype=dtype))
    res = np.asarray(obj, dtype=dtype)
    return Array._new(res)
def arange(
    start: Union[int, float],
    /,
    stop: Optional[Union[int, float]] = None,
    step: Union[int, float] = 1,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.

    See its docstring for more information. *start*, *stop* and *step* are
    forwarded to NumPy unchanged. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    values = np.arange(start, stop=stop, step=step, dtype=dtype)
    return Array._new(values)
def empty(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.

    See its docstring for more information. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    buffer = np.empty(shape, dtype=dtype)
    return Array._new(buffer)
def empty_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.

    See its docstring for more information. The underlying ndarray of *x* is
    used as the template. *dtype* is validated with ``_check_valid_dtype``,
    and *device* must be ``"cpu"`` or ``None`` (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    template = x._array
    return Array._new(np.empty_like(template, dtype=dtype))
def eye(
    n_rows: int,
    n_cols: Optional[int] = None,
    /,
    *,
    k: int = 0,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.

    See its docstring for more information. *n_cols* is passed to NumPy as
    ``M`` and *k* as the diagonal offset. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    matrix = np.eye(n_rows, M=n_cols, k=k, dtype=dtype)
    return Array._new(matrix)
def from_dlpack(x: object, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.from_dlpack <numpy.from_dlpack>`.

    See its docstring for more information. *x* is handed to NumPy unchanged,
    and the resulting ndarray is wrapped in an ``Array``.
    """
    from ._array_object import Array

    imported = np.from_dlpack(x)
    return Array._new(imported)
def full(
    shape: Union[int, Tuple[int, ...]],
    fill_value: Union[int, float],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.full <numpy.full>`.

    See its docstring for more information. A 0-d ``Array`` fill value is
    unwrapped to its underlying ndarray before being handed to NumPy.

    Raises ``ValueError`` for an unsupported *dtype* or *device*. Raises
    ``TypeError`` if the resulting dtype is not one of the supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    is_scalar_array = isinstance(fill_value, Array) and fill_value.ndim == 0
    value = fill_value._array if is_scalar_array else fill_value
    res = np.full(shape, value, dtype=dtype)
    if res.dtype in _all_dtypes:
        return Array._new(res)
    # NumPy coerced the fill value to a dtype outside the supported set.
    raise TypeError("Invalid input to full")
def full_like(
    x: Array,
    /,
    fill_value: Union[int, float],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.

    See its docstring for more information. The underlying ndarray of *x* is
    used as the template. As in ``full``, a 0-d ``Array`` fill value is
    unwrapped to its underlying ndarray before being handed to NumPy.

    Raises
    ------
    ValueError
        If *dtype* or *device* is not supported.
    TypeError
        If the resulting dtype is not one of the supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ["cpu", None]:
        raise ValueError(f"Unsupported device {device!r}")
    if isinstance(fill_value, Array) and fill_value.ndim == 0:
        # Match ``full``: give NumPy the plain 0-d ndarray rather than the
        # wrapper object, so the fill value is coerced the same way in both.
        fill_value = fill_value._array
    res = np.full_like(x._array, fill_value, dtype=dtype)
    if res.dtype not in _all_dtypes:
        # This will happen if the fill value is not something that NumPy
        # coerces to one of the acceptable dtypes.
        raise TypeError("Invalid input to full_like")
    return Array._new(res)
def linspace(
    start: Union[int, float],
    stop: Union[int, float],
    /,
    num: int,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    endpoint: bool = True,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.

    See its docstring for more information. *start*, *stop*, *num* and
    *endpoint* are forwarded to NumPy unchanged. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    samples = np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
    return Array._new(samples)
def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
    """
    Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.

    See its docstring for more information. Unlike ``np.meshgrid``, every
    input must share the same dtype; mixed dtypes raise ``ValueError``.
    Each output grid is returned wrapped in an ``Array``.
    """
    from ._array_object import Array

    distinct_dtypes = set()
    for arr in arrays:
        distinct_dtypes.add(arr.dtype)
    if len(distinct_dtypes) > 1:
        raise ValueError("meshgrid inputs must all have the same dtype")
    grids = np.meshgrid(*(arr._array for arr in arrays), indexing=indexing)
    return list(map(Array._new, grids))
def ones(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.

    See its docstring for more information. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    filled = np.ones(shape, dtype=dtype)
    return Array._new(filled)
def ones_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.

    See its docstring for more information. The underlying ndarray of *x* is
    used as the template. *dtype* is validated with ``_check_valid_dtype``,
    and *device* must be ``"cpu"`` or ``None`` (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    template = x._array
    return Array._new(np.ones_like(template, dtype=dtype))
def tril(x: Array, /, *, k: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.

    See its docstring for more information. Unlike ``np.tril``, *x* must have
    at least two dimensions; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    if x.ndim >= 2:
        return Array._new(np.tril(x._array, k=k))
    # Note: Unlike np.tril, x must be at least 2-D
    raise ValueError("x must be at least 2-dimensional for tril")
def triu(x: Array, /, *, k: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.

    See its docstring for more information. Unlike ``np.triu``, *x* must have
    at least two dimensions; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    if x.ndim >= 2:
        return Array._new(np.triu(x._array, k=k))
    # Note: Unlike np.triu, x must be at least 2-D
    raise ValueError("x must be at least 2-dimensional for triu")
def zeros(
    shape: Union[int, Tuple[int, ...]],
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.

    See its docstring for more information. *dtype* is validated with
    ``_check_valid_dtype``, and *device* must be ``"cpu"`` or ``None``
    (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    filled = np.zeros(shape, dtype=dtype)
    return Array._new(filled)
def zeros_like(
    x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.

    See its docstring for more information. The underlying ndarray of *x* is
    used as the template. *dtype* is validated with ``_check_valid_dtype``,
    and *device* must be ``"cpu"`` or ``None`` (``ValueError`` otherwise).
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ("cpu", None):
        raise ValueError(f"Unsupported device {device!r}")
    template = x._array
    return Array._new(np.zeros_like(template, dtype=dtype))
|