test_matrix_io.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import os
  2. import numpy as np
  3. import tempfile
  4. from pytest import raises as assert_raises
  5. from numpy.testing import assert_equal, assert_
  6. from scipy.sparse import (csc_matrix, csr_matrix, bsr_matrix, dia_matrix,
  7. coo_matrix, save_npz, load_npz, dok_matrix)
  8. DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
  9. def _save_and_load(matrix):
  10. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  11. os.close(fd)
  12. try:
  13. save_npz(tmpfile, matrix)
  14. loaded_matrix = load_npz(tmpfile)
  15. finally:
  16. os.remove(tmpfile)
  17. return loaded_matrix
  18. def _check_save_and_load(dense_matrix):
  19. for matrix_class in [csc_matrix, csr_matrix, bsr_matrix, dia_matrix, coo_matrix]:
  20. matrix = matrix_class(dense_matrix)
  21. loaded_matrix = _save_and_load(matrix)
  22. assert_(type(loaded_matrix) is matrix_class)
  23. assert_(loaded_matrix.shape == dense_matrix.shape)
  24. assert_(loaded_matrix.dtype == dense_matrix.dtype)
  25. assert_equal(loaded_matrix.toarray(), dense_matrix)
  26. def test_save_and_load_random():
  27. N = 10
  28. np.random.seed(0)
  29. dense_matrix = np.random.random((N, N))
  30. dense_matrix[dense_matrix > 0.7] = 0
  31. _check_save_and_load(dense_matrix)
  32. def test_save_and_load_empty():
  33. dense_matrix = np.zeros((4,6))
  34. _check_save_and_load(dense_matrix)
  35. def test_save_and_load_one_entry():
  36. dense_matrix = np.zeros((4,6))
  37. dense_matrix[1,2] = 1
  38. _check_save_and_load(dense_matrix)
  39. def test_malicious_load():
  40. class Executor:
  41. def __reduce__(self):
  42. return (assert_, (False, 'unexpected code execution'))
  43. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  44. os.close(fd)
  45. try:
  46. np.savez(tmpfile, format=Executor())
  47. # Should raise a ValueError, not execute code
  48. assert_raises(ValueError, load_npz, tmpfile)
  49. finally:
  50. os.remove(tmpfile)
  51. def test_py23_compatibility():
  52. # Try loading files saved on Python 2 and Python 3. They are not
  53. # the same, since files saved with SciPy versions < 1.0.0 may
  54. # contain unicode.
  55. a = load_npz(os.path.join(DATA_DIR, 'csc_py2.npz'))
  56. b = load_npz(os.path.join(DATA_DIR, 'csc_py3.npz'))
  57. c = csc_matrix([[0]])
  58. assert_equal(a.toarray(), c.toarray())
  59. assert_equal(b.toarray(), c.toarray())
  60. def test_implemented_error():
  61. # Attempts to save an unsupported type and checks that an
  62. # NotImplementedError is raised.
  63. x = dok_matrix((2,3))
  64. x[0,1] = 1
  65. assert_raises(NotImplementedError, save_npz, 'x.npz', x)