123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- """ Testing
- """
- import os
- import zlib
- from io import BytesIO
- from tempfile import mkstemp
- from contextlib import contextmanager
- import numpy as np
- from numpy.testing import assert_, assert_equal
- from pytest import raises as assert_raises
- from scipy.io.matlab._streams import (make_stream,
- GenericStream, ZlibInputStream,
- _read_into, _read_string, BLOCK_SIZE)
@contextmanager
def setup_test_file():
    """Provide three binary streams that hold the same 8-byte payload.

    Yields
    ------
    fs : file object
        A temporary file created with ``mkstemp``, reopened in ``'rb'`` mode.
    gs, cs : BytesIO
        Two independent in-memory streams with the same content as `fs`.

    Notes
    -----
    The temporary file is deleted when the context exits, including when
    the body of the ``with`` statement raises.
    """
    val = b'a\x00string'
    fd, fname = mkstemp()
    try:
        with os.fdopen(fd, 'wb') as fs:
            fs.write(val)
        with open(fname, 'rb') as fs:
            gs = BytesIO(val)
            cs = BytesIO(val)
            yield fs, gs, cs
    finally:
        # An exception raised by the caller is re-raised at the ``yield``;
        # without ``finally`` the unlink would be skipped and the temporary
        # file leaked. By this point the read handle is already closed.
        os.unlink(fname)
def test_make_stream():
    """``make_stream`` on an in-memory stream returns a ``GenericStream``."""
    with setup_test_file() as (_file_stream, bytes_stream, _other_stream):
        # test stream initialization
        wrapped = make_stream(bytes_stream)
        assert_(isinstance(wrapped, GenericStream))
def test_tell_seek():
    """Check ``seek`` return values and the positions ``tell`` reports.

    Every stream from ``setup_test_file`` is wrapped with ``make_stream``
    and driven through the same sequence of seeks; each ``seek`` must
    return 0 and ``tell`` must then report the expected position.
    """
    # (arguments passed to seek, position tell() must report afterwards)
    moves = (
        ((0,), 0),
        ((5,), 5),
        ((2, 1), 7),
        ((-2, 2), 6),
    )
    with setup_test_file() as streams:
        for raw in streams:
            wrapped = make_stream(raw)
            for seek_args, expected_pos in moves:
                assert_equal(wrapped.seek(*seek_args), 0)
                assert_equal(wrapped.tell(), expected_pos)
def test_read():
    """Exercise ``read``, ``_read_into`` and ``_read_string`` on each stream.

    For every stream from ``setup_test_file``: a full read and a 4-byte
    read must return the expected bytes, and each helper must return two
    consecutive 4-byte chunks and then raise ``OSError`` when asked for
    2 more bytes.
    """
    with setup_test_file() as streams:
        for raw in streams:
            wrapped = make_stream(raw)

            wrapped.seek(0)
            assert_equal(wrapped.read(-1), b'a\x00string')

            wrapped.seek(0)
            assert_equal(wrapped.read(4), b'a\x00st')

            # read into, then read alloc: both helpers get the same checks
            for reader in (_read_into, _read_string):
                wrapped.seek(0)
                assert_equal(reader(wrapped, 4), b'a\x00st')
                assert_equal(reader(wrapped, 4), b'ring')
                assert_raises(OSError, reader, wrapped, 2)
- class TestZlibInputStream:
- def _get_data(self, size):
- data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data)
- stream = BytesIO(compressed_data)
- return stream, len(compressed_data), data
- def test_read(self):
- SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1,
- BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1]
- READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1,
- BLOCK_SIZE, BLOCK_SIZE+1]
- def check(size, read_size):
- compressed_stream, compressed_data_len, data = self._get_data(size)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- data2 = b''
- so_far = 0
- while True:
- block = stream.read(min(read_size,
- size - so_far))
- if not block:
- break
- so_far += len(block)
- data2 += block
- assert_equal(data, data2)
- for size in SIZES:
- for read_size in READ_SIZES:
- check(size, read_size)
- def test_read_max_length(self):
- size = 1234
- data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data)
- compressed_stream = BytesIO(compressed_data + b"abbacaca")
- stream = ZlibInputStream(compressed_stream, len(compressed_data))
- stream.read(len(data))
- assert_equal(compressed_stream.tell(), len(compressed_data))
- assert_raises(OSError, stream.read, 1)
- def test_read_bad_checksum(self):
- data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data)
- # break checksum
- compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
- compressed_stream = BytesIO(compressed_data)
- stream = ZlibInputStream(compressed_stream, len(compressed_data))
- assert_raises(zlib.error, stream.read, len(data))
- def test_seek(self):
- compressed_stream, compressed_data_len, data = self._get_data(1024)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- stream.seek(123)
- p = 123
- assert_equal(stream.tell(), p)
- d1 = stream.read(11)
- assert_equal(d1, data[p:p+11])
- stream.seek(321, 1)
- p = 123+11+321
- assert_equal(stream.tell(), p)
- d2 = stream.read(21)
- assert_equal(d2, data[p:p+21])
- stream.seek(641, 0)
- p = 641
- assert_equal(stream.tell(), p)
- d3 = stream.read(11)
- assert_equal(d3, data[p:p+11])
- assert_raises(OSError, stream.seek, 10, 2)
- assert_raises(OSError, stream.seek, -1, 1)
- assert_raises(ValueError, stream.seek, 1, 123)
- stream.seek(10000, 1)
- assert_raises(OSError, stream.read, 12)
- def test_seek_bad_checksum(self):
- data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data)
- # break checksum
- compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
- compressed_stream = BytesIO(compressed_data)
- stream = ZlibInputStream(compressed_stream, len(compressed_data))
- assert_raises(zlib.error, stream.seek, len(data))
- def test_all_data_read(self):
- compressed_stream, compressed_data_len, data = self._get_data(1024)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- assert_(not stream.all_data_read())
- stream.seek(512)
- assert_(not stream.all_data_read())
- stream.seek(1024)
- assert_(stream.all_data_read())
- def test_all_data_read_overlap(self):
- COMPRESSION_LEVEL = 6
- data = np.arange(33707000).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
- compressed_data_len = len(compressed_data)
- # check that part of the checksum overlaps
- assert_(compressed_data_len == BLOCK_SIZE + 2)
- compressed_stream = BytesIO(compressed_data)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- assert_(not stream.all_data_read())
- stream.seek(len(data))
- assert_(stream.all_data_read())
- def test_all_data_read_bad_checksum(self):
- COMPRESSION_LEVEL = 6
- data = np.arange(33707000).astype(np.uint8).tobytes()
- compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
- compressed_data_len = len(compressed_data)
- # check that part of the checksum overlaps
- assert_(compressed_data_len == BLOCK_SIZE + 2)
- # break checksum
- compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
- compressed_stream = BytesIO(compressed_data)
- stream = ZlibInputStream(compressed_stream, compressed_data_len)
- assert_(not stream.all_data_read())
- stream.seek(len(data))
- assert_raises(zlib.error, stream.all_data_read)
|