# test_dlpack.py -- tests for NumPy's DLPack interchange support.

import sys

import pytest

import numpy as np
from numpy.testing import assert_array_equal, IS_PYPY


  5. class TestDLPack:
  6. @pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
  7. def test_dunder_dlpack_refcount(self):
  8. x = np.arange(5)
  9. y = x.__dlpack__()
  10. assert sys.getrefcount(x) == 3
  11. del y
  12. assert sys.getrefcount(x) == 2
  13. def test_dunder_dlpack_stream(self):
  14. x = np.arange(5)
  15. x.__dlpack__(stream=None)
  16. with pytest.raises(RuntimeError):
  17. x.__dlpack__(stream=1)
  18. def test_strides_not_multiple_of_itemsize(self):
  19. dt = np.dtype([('int', np.int32), ('char', np.int8)])
  20. y = np.zeros((5,), dtype=dt)
  21. z = y['int']
  22. with pytest.raises(BufferError):
  23. np.from_dlpack(z)
  24. @pytest.mark.skipif(IS_PYPY, reason="PyPy can't get refcounts.")
  25. def test_from_dlpack_refcount(self):
  26. x = np.arange(5)
  27. y = np.from_dlpack(x)
  28. assert sys.getrefcount(x) == 3
  29. del y
  30. assert sys.getrefcount(x) == 2
  31. @pytest.mark.parametrize("dtype", [
  32. np.int8, np.int16, np.int32, np.int64,
  33. np.uint8, np.uint16, np.uint32, np.uint64,
  34. np.float16, np.float32, np.float64,
  35. np.complex64, np.complex128
  36. ])
  37. def test_dtype_passthrough(self, dtype):
  38. x = np.arange(5, dtype=dtype)
  39. y = np.from_dlpack(x)
  40. assert y.dtype == x.dtype
  41. assert_array_equal(x, y)
  42. def test_invalid_dtype(self):
  43. x = np.asarray(np.datetime64('2021-05-27'))
  44. with pytest.raises(BufferError):
  45. np.from_dlpack(x)
  46. def test_invalid_byte_swapping(self):
  47. dt = np.dtype('=i8').newbyteorder()
  48. x = np.arange(5, dtype=dt)
  49. with pytest.raises(BufferError):
  50. np.from_dlpack(x)
  51. def test_non_contiguous(self):
  52. x = np.arange(25).reshape((5, 5))
  53. y1 = x[0]
  54. assert_array_equal(y1, np.from_dlpack(y1))
  55. y2 = x[:, 0]
  56. assert_array_equal(y2, np.from_dlpack(y2))
  57. y3 = x[1, :]
  58. assert_array_equal(y3, np.from_dlpack(y3))
  59. y4 = x[1]
  60. assert_array_equal(y4, np.from_dlpack(y4))
  61. y5 = np.diagonal(x).copy()
  62. assert_array_equal(y5, np.from_dlpack(y5))
  63. @pytest.mark.parametrize("ndim", range(33))
  64. def test_higher_dims(self, ndim):
  65. shape = (1,) * ndim
  66. x = np.zeros(shape, dtype=np.float64)
  67. assert shape == np.from_dlpack(x).shape
  68. def test_dlpack_device(self):
  69. x = np.arange(5)
  70. assert x.__dlpack_device__() == (1, 0)
  71. y = np.from_dlpack(x)
  72. assert y.__dlpack_device__() == (1, 0)
  73. z = y[::2]
  74. assert z.__dlpack_device__() == (1, 0)
  75. def dlpack_deleter_exception(self):
  76. x = np.arange(5)
  77. _ = x.__dlpack__()
  78. raise RuntimeError
  79. def test_dlpack_destructor_exception(self):
  80. with pytest.raises(RuntimeError):
  81. self.dlpack_deleter_exception()
  82. def test_readonly(self):
  83. x = np.arange(5)
  84. x.flags.writeable = False
  85. with pytest.raises(BufferError):
  86. x.__dlpack__()
  87. def test_ndim0(self):
  88. x = np.array(1.0)
  89. y = np.from_dlpack(x)
  90. assert_array_equal(x, y)
  91. def test_size1dims_arrays(self):
  92. x = np.ndarray(dtype='f8', shape=(10, 5, 1), strides=(8, 80, 4),
  93. buffer=np.ones(1000, dtype=np.uint8), order='F')
  94. y = np.from_dlpack(x)
  95. assert_array_equal(x, y)