test_validation.py 676 B

123456789101112131415161718192021222324252627
  1. from typing import Callable
  2. import pytest
  3. from numpy import array_api as xp
  4. def p(func: Callable, *args, **kwargs):
  5. f_sig = ", ".join(
  6. [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()]
  7. )
  8. id_ = f"{func.__name__}({f_sig})"
  9. return pytest.param(func, args, kwargs, id=id_)
  10. @pytest.mark.parametrize(
  11. "func, args, kwargs",
  12. [
  13. p(xp.can_cast, 42, xp.int8),
  14. p(xp.can_cast, xp.int8, 42),
  15. p(xp.result_type, 42),
  16. ],
  17. )
  18. def test_raises_on_invalid_types(func, args, kwargs):
  19. """Function raises TypeError when passed invalidly-typed inputs"""
  20. with pytest.raises(TypeError):
  21. func(*args, **kwargs)