_matrix_io.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import numpy as np
  2. import scipy.sparse
__all__ = ['save_npz', 'load_npz']

# Keyword arguments forwarded to ``np.load`` by ``load_npz``.
# Make loading safe vs. malicious input: with pickling disabled, ``np.load``
# refuses to unpickle object arrays stored in the file.
PICKLE_KWARGS = dict(allow_pickle=False)
  6. def save_npz(file, matrix, compressed=True):
  7. """ Save a sparse matrix to a file using ``.npz`` format.
  8. Parameters
  9. ----------
  10. file : str or file-like object
  11. Either the file name (string) or an open file (file-like object)
  12. where the data will be saved. If file is a string, the ``.npz``
  13. extension will be appended to the file name if it is not already
  14. there.
  15. matrix: spmatrix (format: ``csc``, ``csr``, ``bsr``, ``dia`` or coo``)
  16. The sparse matrix to save.
  17. compressed : bool, optional
  18. Allow compressing the file. Default: True
  19. See Also
  20. --------
  21. scipy.sparse.load_npz: Load a sparse matrix from a file using ``.npz`` format.
  22. numpy.savez: Save several arrays into a ``.npz`` archive.
  23. numpy.savez_compressed : Save several arrays into a compressed ``.npz`` archive.
  24. Examples
  25. --------
  26. Store sparse matrix to disk, and load it again:
  27. >>> import numpy as np
  28. >>> import scipy.sparse
  29. >>> sparse_matrix = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
  30. >>> sparse_matrix
  31. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  32. with 2 stored elements in Compressed Sparse Column format>
  33. >>> sparse_matrix.toarray()
  34. array([[0, 0, 3],
  35. [4, 0, 0]], dtype=int64)
  36. >>> scipy.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  37. >>> sparse_matrix = scipy.sparse.load_npz('/tmp/sparse_matrix.npz')
  38. >>> sparse_matrix
  39. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  40. with 2 stored elements in Compressed Sparse Column format>
  41. >>> sparse_matrix.toarray()
  42. array([[0, 0, 3],
  43. [4, 0, 0]], dtype=int64)
  44. """
  45. arrays_dict = {}
  46. if matrix.format in ('csc', 'csr', 'bsr'):
  47. arrays_dict.update(indices=matrix.indices, indptr=matrix.indptr)
  48. elif matrix.format == 'dia':
  49. arrays_dict.update(offsets=matrix.offsets)
  50. elif matrix.format == 'coo':
  51. arrays_dict.update(row=matrix.row, col=matrix.col)
  52. else:
  53. raise NotImplementedError('Save is not implemented for sparse matrix of format {}.'.format(matrix.format))
  54. arrays_dict.update(
  55. format=matrix.format.encode('ascii'),
  56. shape=matrix.shape,
  57. data=matrix.data
  58. )
  59. if compressed:
  60. np.savez_compressed(file, **arrays_dict)
  61. else:
  62. np.savez(file, **arrays_dict)
  63. def load_npz(file):
  64. """ Load a sparse matrix from a file using ``.npz`` format.
  65. Parameters
  66. ----------
  67. file : str or file-like object
  68. Either the file name (string) or an open file (file-like object)
  69. where the data will be loaded.
  70. Returns
  71. -------
  72. result : csc_matrix, csr_matrix, bsr_matrix, dia_matrix or coo_matrix
  73. A sparse matrix containing the loaded data.
  74. Raises
  75. ------
  76. OSError
  77. If the input file does not exist or cannot be read.
  78. See Also
  79. --------
  80. scipy.sparse.save_npz: Save a sparse matrix to a file using ``.npz`` format.
  81. numpy.load: Load several arrays from a ``.npz`` archive.
  82. Examples
  83. --------
  84. Store sparse matrix to disk, and load it again:
  85. >>> import numpy as np
  86. >>> import scipy.sparse
  87. >>> sparse_matrix = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
  88. >>> sparse_matrix
  89. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  90. with 2 stored elements in Compressed Sparse Column format>
  91. >>> sparse_matrix.toarray()
  92. array([[0, 0, 3],
  93. [4, 0, 0]], dtype=int64)
  94. >>> scipy.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  95. >>> sparse_matrix = scipy.sparse.load_npz('/tmp/sparse_matrix.npz')
  96. >>> sparse_matrix
  97. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  98. with 2 stored elements in Compressed Sparse Column format>
  99. >>> sparse_matrix.toarray()
  100. array([[0, 0, 3],
  101. [4, 0, 0]], dtype=int64)
  102. """
  103. with np.load(file, **PICKLE_KWARGS) as loaded:
  104. try:
  105. matrix_format = loaded['format']
  106. except KeyError as e:
  107. raise ValueError('The file {} does not contain a sparse matrix.'.format(file)) from e
  108. matrix_format = matrix_format.item()
  109. if not isinstance(matrix_format, str):
  110. # Play safe with Python 2 vs 3 backward compatibility;
  111. # files saved with SciPy < 1.0.0 may contain unicode or bytes.
  112. matrix_format = matrix_format.decode('ascii')
  113. try:
  114. cls = getattr(scipy.sparse, '{}_matrix'.format(matrix_format))
  115. except AttributeError as e:
  116. raise ValueError('Unknown matrix format "{}"'.format(matrix_format)) from e
  117. if matrix_format in ('csc', 'csr', 'bsr'):
  118. return cls((loaded['data'], loaded['indices'], loaded['indptr']), shape=loaded['shape'])
  119. elif matrix_format == 'dia':
  120. return cls((loaded['data'], loaded['offsets']), shape=loaded['shape'])
  121. elif matrix_format == 'coo':
  122. return cls((loaded['data'], (loaded['row'], loaded['col'])), shape=loaded['shape'])
  123. else:
  124. raise NotImplementedError('Load is not implemented for '
  125. 'sparse matrix of format {}.'.format(matrix_format))