123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- """Test the runtime usage of `numpy.typing`."""
- from __future__ import annotations
- import sys
- from typing import (
- get_type_hints,
- Union,
- NamedTuple,
- get_args,
- get_origin,
- Any,
- )
- import pytest
- import numpy as np
- import numpy.typing as npt
- import numpy._typing as _npt
- class TypeTup(NamedTuple):
- typ: type
- args: tuple[type, ...]
- origin: None | type
# Reference entry for `npt.NDArray`.  On Python >= 3.9 its `__args__` and
# `np.ndarray` are used as the expected args/origin; on older interpreters
# the expectation is an empty args tuple and no origin.
NDArrayTup = (
    TypeTup(typ=npt.NDArray, args=npt.NDArray.__args__, origin=np.ndarray)
    if sys.version_info >= (3, 9)
    else TypeTup(typ=npt.NDArray, args=(), origin=None)
)

# Maps each name under test to its reference `TypeTup`.  The keys are also
# used as pytest ids and are compared against `npt.__all__` in `test_keys`.
TYPES = {
    "ArrayLike": TypeTup(
        typ=npt.ArrayLike,
        args=npt.ArrayLike.__args__,
        origin=Union,
    ),
    "DTypeLike": TypeTup(
        typ=npt.DTypeLike,
        args=npt.DTypeLike.__args__,
        origin=Union,
    ),
    "NBitBase": TypeTup(typ=npt.NBitBase, args=(), origin=None),
    "NDArray": NDArrayTup,
}
@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
def test_get_args(name: str, tup: TypeTup) -> None:
    """Check `typing.get_args` against the reference args of each type.

    Parameters
    ----------
    name : str
        Key of the current ``TYPES`` entry; unused in the body, but required
        because the parametrization unpacks ``TYPES.items()``.
    tup : TypeTup
        Reference data; ``tup.args`` is the expected result for ``tup.typ``.
    """
    assert get_args(tup.typ) == tup.args
@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
def test_get_origin(name: str, tup: TypeTup) -> None:
    """Check `typing.get_origin` against the reference origin of each type.

    Parameters
    ----------
    name : str
        Key of the current ``TYPES`` entry; unused in the body, but required
        because the parametrization unpacks ``TYPES.items()``.
    tup : TypeTup
        Reference data; ``tup.origin`` is the expected result for ``tup.typ``.
    """
    assert get_origin(tup.typ) == tup.origin
@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
def test_get_type_hints(name: str, tup: TypeTup) -> None:
    """Check `typing.get_type_hints` on a function annotated with the type.

    Parameters
    ----------
    name : str
        Key of the current ``TYPES`` entry; unused in the body, but required
        because the parametrization unpacks ``TYPES.items()``.
    tup : TypeTup
        Reference data; only ``tup.typ`` is used here.
    """
    # The annotations are assigned through `__annotations__` rather than
    # written in the signature, because `from __future__ import annotations`
    # would otherwise turn them into strings.
    def func(a):
        pass

    func.__annotations__ = {"a": tup.typ, "return": None}

    # A `None` return annotation is expected back as `NoneType`.
    expected = {"a": tup.typ, "return": type(None)}
    assert get_type_hints(func) == expected
@pytest.mark.parametrize("name,tup", TYPES.items(), ids=TYPES.keys())
def test_get_type_hints_str(name: str, tup: TypeTup) -> None:
    """Check `typing.get_type_hints` when the annotation is a string.

    The annotation is the text ``"npt.<name>"`` and the resolved hint is
    expected to equal the actual type object.

    Parameters
    ----------
    name : str
        Key of the current ``TYPES`` entry, interpolated into the
        string annotation.
    tup : TypeTup
        Reference data; ``tup.typ`` is the object the string should
        resolve to.
    """
    annotation = f"npt.{name}"

    # The annotations are assigned through `__annotations__` rather than
    # written in the signature, because `from __future__ import annotations`
    # would otherwise turn them into strings.
    def func(a):
        pass

    func.__annotations__ = {"a": annotation, "return": None}

    # A `None` return annotation is expected back as `NoneType`.
    expected = {"a": tup.typ, "return": type(None)}
    assert get_type_hints(func) == expected
def test_keys() -> None:
    """Ensure the names in ``TYPES`` match ``numpy.typing.__all__`` exactly."""
    # A dict keys view compares equal to a set holding the same elements.
    assert TYPES.keys() == set(npt.__all__)
# Protocol classes exercised by `TestRuntimeProtocol`.  Each value is a
# `(protocol, sample)` pair in which `sample` is expected to pass the
# `isinstance` check against `protocol`; the keys serve as pytest ids.
PROTOCOLS: dict[str, tuple[type[Any], object]] = {
    # Sample: a NumPy scalar.
    "_SupportsDType": (_npt._SupportsDType, np.int64(1)),
    # Sample: a NumPy array (a fresh one per entry).
    "_SupportsArray": (_npt._SupportsArray, np.arange(10)),
    "_SupportsArrayFunc": (_npt._SupportsArrayFunc, np.arange(10)),
    # Sample: a plain Python list.
    "_NestedSequence": (_npt._NestedSequence, [1]),
}
@pytest.mark.parametrize("cls,obj", PROTOCOLS.values(), ids=PROTOCOLS.keys())
class TestRuntimeProtocol:
    """Run ``isinstance``/``issubclass`` checks for every ``PROTOCOLS`` entry.

    The class-level parametrization supplies each method with ``cls`` (the
    protocol) and ``obj`` (the sample object paired with it).
    """

    def test_isinstance(self, cls: type[Any], obj: object) -> None:
        """The sample is an instance of the protocol; ``None`` is not."""
        assert isinstance(obj, cls)
        assert not isinstance(None, cls)

    def test_issubclass(self, cls: type[Any], obj: object) -> None:
        """The sample's type is a subclass of the protocol; ``NoneType`` is not."""
        # `_SupportsDType` is marked as an expected failure up front;
        # `pytest.xfail` raises, so the assertions below do not run for it.
        if cls is _npt._SupportsDType:
            pytest.xfail(
                "Protocols with non-method members don't support issubclass()"
            )

        sample_type, none_type = type(obj), type(None)
        assert issubclass(sample_type, cls)
        assert not issubclass(none_type, cls)
|