test_streams.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. """ Testing
  2. """
  3. import os
  4. import zlib
  5. from io import BytesIO
  6. from tempfile import mkstemp
  7. from contextlib import contextmanager
  8. import numpy as np
  9. from numpy.testing import assert_, assert_equal
  10. from pytest import raises as assert_raises
  11. from scipy.io.matlab._streams import (make_stream,
  12. GenericStream, ZlibInputStream,
  13. _read_into, _read_string, BLOCK_SIZE)
  14. @contextmanager
  15. def setup_test_file():
  16. val = b'a\x00string'
  17. fd, fname = mkstemp()
  18. with os.fdopen(fd, 'wb') as fs:
  19. fs.write(val)
  20. with open(fname, 'rb') as fs:
  21. gs = BytesIO(val)
  22. cs = BytesIO(val)
  23. yield fs, gs, cs
  24. os.unlink(fname)
  25. def test_make_stream():
  26. with setup_test_file() as (fs, gs, cs):
  27. # test stream initialization
  28. assert_(isinstance(make_stream(gs), GenericStream))
  29. def test_tell_seek():
  30. with setup_test_file() as (fs, gs, cs):
  31. for s in (fs, gs, cs):
  32. st = make_stream(s)
  33. res = st.seek(0)
  34. assert_equal(res, 0)
  35. assert_equal(st.tell(), 0)
  36. res = st.seek(5)
  37. assert_equal(res, 0)
  38. assert_equal(st.tell(), 5)
  39. res = st.seek(2, 1)
  40. assert_equal(res, 0)
  41. assert_equal(st.tell(), 7)
  42. res = st.seek(-2, 2)
  43. assert_equal(res, 0)
  44. assert_equal(st.tell(), 6)
  45. def test_read():
  46. with setup_test_file() as (fs, gs, cs):
  47. for s in (fs, gs, cs):
  48. st = make_stream(s)
  49. st.seek(0)
  50. res = st.read(-1)
  51. assert_equal(res, b'a\x00string')
  52. st.seek(0)
  53. res = st.read(4)
  54. assert_equal(res, b'a\x00st')
  55. # read into
  56. st.seek(0)
  57. res = _read_into(st, 4)
  58. assert_equal(res, b'a\x00st')
  59. res = _read_into(st, 4)
  60. assert_equal(res, b'ring')
  61. assert_raises(OSError, _read_into, st, 2)
  62. # read alloc
  63. st.seek(0)
  64. res = _read_string(st, 4)
  65. assert_equal(res, b'a\x00st')
  66. res = _read_string(st, 4)
  67. assert_equal(res, b'ring')
  68. assert_raises(OSError, _read_string, st, 2)
  69. class TestZlibInputStream:
  70. def _get_data(self, size):
  71. data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
  72. compressed_data = zlib.compress(data)
  73. stream = BytesIO(compressed_data)
  74. return stream, len(compressed_data), data
  75. def test_read(self):
  76. SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1,
  77. BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1]
  78. READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1,
  79. BLOCK_SIZE, BLOCK_SIZE+1]
  80. def check(size, read_size):
  81. compressed_stream, compressed_data_len, data = self._get_data(size)
  82. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  83. data2 = b''
  84. so_far = 0
  85. while True:
  86. block = stream.read(min(read_size,
  87. size - so_far))
  88. if not block:
  89. break
  90. so_far += len(block)
  91. data2 += block
  92. assert_equal(data, data2)
  93. for size in SIZES:
  94. for read_size in READ_SIZES:
  95. check(size, read_size)
  96. def test_read_max_length(self):
  97. size = 1234
  98. data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
  99. compressed_data = zlib.compress(data)
  100. compressed_stream = BytesIO(compressed_data + b"abbacaca")
  101. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  102. stream.read(len(data))
  103. assert_equal(compressed_stream.tell(), len(compressed_data))
  104. assert_raises(OSError, stream.read, 1)
  105. def test_read_bad_checksum(self):
  106. data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
  107. compressed_data = zlib.compress(data)
  108. # break checksum
  109. compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
  110. compressed_stream = BytesIO(compressed_data)
  111. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  112. assert_raises(zlib.error, stream.read, len(data))
  113. def test_seek(self):
  114. compressed_stream, compressed_data_len, data = self._get_data(1024)
  115. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  116. stream.seek(123)
  117. p = 123
  118. assert_equal(stream.tell(), p)
  119. d1 = stream.read(11)
  120. assert_equal(d1, data[p:p+11])
  121. stream.seek(321, 1)
  122. p = 123+11+321
  123. assert_equal(stream.tell(), p)
  124. d2 = stream.read(21)
  125. assert_equal(d2, data[p:p+21])
  126. stream.seek(641, 0)
  127. p = 641
  128. assert_equal(stream.tell(), p)
  129. d3 = stream.read(11)
  130. assert_equal(d3, data[p:p+11])
  131. assert_raises(OSError, stream.seek, 10, 2)
  132. assert_raises(OSError, stream.seek, -1, 1)
  133. assert_raises(ValueError, stream.seek, 1, 123)
  134. stream.seek(10000, 1)
  135. assert_raises(OSError, stream.read, 12)
  136. def test_seek_bad_checksum(self):
  137. data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
  138. compressed_data = zlib.compress(data)
  139. # break checksum
  140. compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
  141. compressed_stream = BytesIO(compressed_data)
  142. stream = ZlibInputStream(compressed_stream, len(compressed_data))
  143. assert_raises(zlib.error, stream.seek, len(data))
  144. def test_all_data_read(self):
  145. compressed_stream, compressed_data_len, data = self._get_data(1024)
  146. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  147. assert_(not stream.all_data_read())
  148. stream.seek(512)
  149. assert_(not stream.all_data_read())
  150. stream.seek(1024)
  151. assert_(stream.all_data_read())
  152. def test_all_data_read_overlap(self):
  153. COMPRESSION_LEVEL = 6
  154. data = np.arange(33707000).astype(np.uint8).tobytes()
  155. compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
  156. compressed_data_len = len(compressed_data)
  157. # check that part of the checksum overlaps
  158. assert_(compressed_data_len == BLOCK_SIZE + 2)
  159. compressed_stream = BytesIO(compressed_data)
  160. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  161. assert_(not stream.all_data_read())
  162. stream.seek(len(data))
  163. assert_(stream.all_data_read())
  164. def test_all_data_read_bad_checksum(self):
  165. COMPRESSION_LEVEL = 6
  166. data = np.arange(33707000).astype(np.uint8).tobytes()
  167. compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
  168. compressed_data_len = len(compressed_data)
  169. # check that part of the checksum overlaps
  170. assert_(compressed_data_len == BLOCK_SIZE + 2)
  171. # break checksum
  172. compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
  173. compressed_stream = BytesIO(compressed_data)
  174. stream = ZlibInputStream(compressed_stream, compressed_data_len)
  175. assert_(not stream.all_data_read())
  176. stream.seek(len(data))
  177. assert_raises(zlib.error, stream.all_data_read)