_mmio.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996
  1. """
  2. Matrix Market I/O in Python.
  3. See http://math.nist.gov/MatrixMarket/formats.html
  4. for information about the Matrix Market format.
  5. """
  6. #
  7. # Author: Pearu Peterson <pearu@cens.ioc.ee>
  8. # Created: October, 2004
  9. #
  10. # References:
  11. # http://math.nist.gov/MatrixMarket/
  12. #
  13. import os
  14. import numpy as np
  15. from numpy import (asarray, real, imag, conj, zeros, ndarray, concatenate,
  16. ones, can_cast)
  17. from scipy.sparse import coo_matrix, isspmatrix
  18. __all__ = ['mminfo', 'mmread', 'mmwrite', 'MMFile']
  19. # -----------------------------------------------------------------------------
  20. def asstr(s):
  21. if isinstance(s, bytes):
  22. return s.decode('latin1')
  23. return str(s)
  24. def mminfo(source):
  25. """
  26. Return size and storage parameters from Matrix Market file-like 'source'.
  27. Parameters
  28. ----------
  29. source : str or file-like
  30. Matrix Market filename (extension .mtx) or open file-like object
  31. Returns
  32. -------
  33. rows : int
  34. Number of matrix rows.
  35. cols : int
  36. Number of matrix columns.
  37. entries : int
  38. Number of non-zero entries of a sparse matrix
  39. or rows*cols for a dense matrix.
  40. format : str
  41. Either 'coordinate' or 'array'.
  42. field : str
  43. Either 'real', 'complex', 'pattern', or 'integer'.
  44. symmetry : str
  45. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  46. Examples
  47. --------
  48. >>> from io import StringIO
  49. >>> from scipy.io import mminfo
  50. >>> text = '''%%MatrixMarket matrix coordinate real general
  51. ... 5 5 7
  52. ... 2 3 1.0
  53. ... 3 4 2.0
  54. ... 3 5 3.0
  55. ... 4 1 4.0
  56. ... 4 2 5.0
  57. ... 4 3 6.0
  58. ... 4 4 7.0
  59. ... '''
  60. ``mminfo(source)`` returns the number of rows, number of columns,
  61. format, field type and symmetry attribute of the source file.
  62. >>> mminfo(StringIO(text))
  63. (5, 5, 7, 'coordinate', 'real', 'general')
  64. """
  65. return MMFile.info(source)
  66. # -----------------------------------------------------------------------------
  67. def mmread(source):
  68. """
  69. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  70. Parameters
  71. ----------
  72. source : str or file-like
  73. Matrix Market filename (extensions .mtx, .mtz.gz)
  74. or open file-like object.
  75. Returns
  76. -------
  77. a : ndarray or coo_matrix
  78. Dense or sparse matrix depending on the matrix format in the
  79. Matrix Market file.
  80. Examples
  81. --------
  82. >>> from io import StringIO
  83. >>> from scipy.io import mmread
  84. >>> text = '''%%MatrixMarket matrix coordinate real general
  85. ... 5 5 7
  86. ... 2 3 1.0
  87. ... 3 4 2.0
  88. ... 3 5 3.0
  89. ... 4 1 4.0
  90. ... 4 2 5.0
  91. ... 4 3 6.0
  92. ... 4 4 7.0
  93. ... '''
  94. ``mmread(source)`` returns the data as sparse matrix in COO format.
  95. >>> m = mmread(StringIO(text))
  96. >>> m
  97. <5x5 sparse matrix of type '<class 'numpy.float64'>'
  98. with 7 stored elements in COOrdinate format>
  99. >>> m.A
  100. array([[0., 0., 0., 0., 0.],
  101. [0., 0., 1., 0., 0.],
  102. [0., 0., 0., 2., 3.],
  103. [4., 5., 6., 7., 0.],
  104. [0., 0., 0., 0., 0.]])
  105. """
  106. return MMFile().read(source)
  107. # -----------------------------------------------------------------------------
  108. def mmwrite(target, a, comment='', field=None, precision=None, symmetry=None):
  109. r"""
  110. Writes the sparse or dense array `a` to Matrix Market file-like `target`.
  111. Parameters
  112. ----------
  113. target : str or file-like
  114. Matrix Market filename (extension .mtx) or open file-like object.
  115. a : array like
  116. Sparse or dense 2-D array.
  117. comment : str, optional
  118. Comments to be prepended to the Matrix Market file.
  119. field : None or str, optional
  120. Either 'real', 'complex', 'pattern', or 'integer'.
  121. precision : None or int, optional
  122. Number of digits to display for real or complex values.
  123. symmetry : None or str, optional
  124. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  125. If symmetry is None the symmetry type of 'a' is determined by its
  126. values.
  127. Returns
  128. -------
  129. None
  130. Examples
  131. --------
  132. >>> from io import BytesIO
  133. >>> import numpy as np
  134. >>> from scipy.sparse import coo_matrix
  135. >>> from scipy.io import mmwrite
  136. Write a small NumPy array to a matrix market file. The file will be
  137. written in the ``'array'`` format.
  138. >>> a = np.array([[1.0, 0, 0, 0], [0, 2.5, 0, 6.25]])
  139. >>> target = BytesIO()
  140. >>> mmwrite(target, a)
  141. >>> print(target.getvalue().decode('latin1'))
  142. %%MatrixMarket matrix array real general
  143. %
  144. 2 4
  145. 1.0000000000000000e+00
  146. 0.0000000000000000e+00
  147. 0.0000000000000000e+00
  148. 2.5000000000000000e+00
  149. 0.0000000000000000e+00
  150. 0.0000000000000000e+00
  151. 0.0000000000000000e+00
  152. 6.2500000000000000e+00
  153. Add a comment to the output file, and set the precision to 3.
  154. >>> target = BytesIO()
  155. >>> mmwrite(target, a, comment='\n Some test data.\n', precision=3)
  156. >>> print(target.getvalue().decode('latin1'))
  157. %%MatrixMarket matrix array real general
  158. %
  159. % Some test data.
  160. %
  161. 2 4
  162. 1.000e+00
  163. 0.000e+00
  164. 0.000e+00
  165. 2.500e+00
  166. 0.000e+00
  167. 0.000e+00
  168. 0.000e+00
  169. 6.250e+00
  170. Convert to a sparse matrix before calling ``mmwrite``. This will
  171. result in the output format being ``'coordinate'`` rather than
  172. ``'array'``.
  173. >>> target = BytesIO()
  174. >>> mmwrite(target, coo_matrix(a), precision=3)
  175. >>> print(target.getvalue().decode('latin1'))
  176. %%MatrixMarket matrix coordinate real general
  177. %
  178. 2 4 3
  179. 1 1 1.00e+00
  180. 2 2 2.50e+00
  181. 2 4 6.25e+00
  182. Write a complex Hermitian array to a matrix market file. Note that
  183. only six values are actually written to the file; the other values
  184. are implied by the symmetry.
  185. >>> z = np.array([[3, 1+2j, 4-3j], [1-2j, 1, -5j], [4+3j, 5j, 2.5]])
  186. >>> z
  187. array([[ 3. +0.j, 1. +2.j, 4. -3.j],
  188. [ 1. -2.j, 1. +0.j, -0. -5.j],
  189. [ 4. +3.j, 0. +5.j, 2.5+0.j]])
  190. >>> target = BytesIO()
  191. >>> mmwrite(target, z, precision=2)
  192. >>> print(target.getvalue().decode('latin1'))
  193. %%MatrixMarket matrix array complex hermitian
  194. %
  195. 3 3
  196. 3.00e+00 0.00e+00
  197. 1.00e+00 -2.00e+00
  198. 4.00e+00 3.00e+00
  199. 1.00e+00 0.00e+00
  200. 0.00e+00 5.00e+00
  201. 2.50e+00 0.00e+00
  202. """
  203. MMFile().write(target, a, comment, field, precision, symmetry)
  204. ###############################################################################
  205. class MMFile:
  206. __slots__ = ('_rows',
  207. '_cols',
  208. '_entries',
  209. '_format',
  210. '_field',
  211. '_symmetry')
  212. @property
  213. def rows(self):
  214. return self._rows
  215. @property
  216. def cols(self):
  217. return self._cols
  218. @property
  219. def entries(self):
  220. return self._entries
  221. @property
  222. def format(self):
  223. return self._format
  224. @property
  225. def field(self):
  226. return self._field
  227. @property
  228. def symmetry(self):
  229. return self._symmetry
  230. @property
  231. def has_symmetry(self):
  232. return self._symmetry in (self.SYMMETRY_SYMMETRIC,
  233. self.SYMMETRY_SKEW_SYMMETRIC,
  234. self.SYMMETRY_HERMITIAN)
  235. # format values
  236. FORMAT_COORDINATE = 'coordinate'
  237. FORMAT_ARRAY = 'array'
  238. FORMAT_VALUES = (FORMAT_COORDINATE, FORMAT_ARRAY)
  239. @classmethod
  240. def _validate_format(self, format):
  241. if format not in self.FORMAT_VALUES:
  242. raise ValueError('unknown format type %s, must be one of %s' %
  243. (format, self.FORMAT_VALUES))
  244. # field values
  245. FIELD_INTEGER = 'integer'
  246. FIELD_UNSIGNED = 'unsigned-integer'
  247. FIELD_REAL = 'real'
  248. FIELD_COMPLEX = 'complex'
  249. FIELD_PATTERN = 'pattern'
  250. FIELD_VALUES = (FIELD_INTEGER, FIELD_UNSIGNED, FIELD_REAL, FIELD_COMPLEX,
  251. FIELD_PATTERN)
  252. @classmethod
  253. def _validate_field(self, field):
  254. if field not in self.FIELD_VALUES:
  255. raise ValueError('unknown field type %s, must be one of %s' %
  256. (field, self.FIELD_VALUES))
  257. # symmetry values
  258. SYMMETRY_GENERAL = 'general'
  259. SYMMETRY_SYMMETRIC = 'symmetric'
  260. SYMMETRY_SKEW_SYMMETRIC = 'skew-symmetric'
  261. SYMMETRY_HERMITIAN = 'hermitian'
  262. SYMMETRY_VALUES = (SYMMETRY_GENERAL, SYMMETRY_SYMMETRIC,
  263. SYMMETRY_SKEW_SYMMETRIC, SYMMETRY_HERMITIAN)
  264. @classmethod
  265. def _validate_symmetry(self, symmetry):
  266. if symmetry not in self.SYMMETRY_VALUES:
  267. raise ValueError('unknown symmetry type %s, must be one of %s' %
  268. (symmetry, self.SYMMETRY_VALUES))
  269. DTYPES_BY_FIELD = {FIELD_INTEGER: 'intp',
  270. FIELD_UNSIGNED: 'uint64',
  271. FIELD_REAL: 'd',
  272. FIELD_COMPLEX: 'D',
  273. FIELD_PATTERN: 'd'}
  274. # -------------------------------------------------------------------------
  275. @staticmethod
  276. def reader():
  277. pass
  278. # -------------------------------------------------------------------------
  279. @staticmethod
  280. def writer():
  281. pass
  282. # -------------------------------------------------------------------------
  283. @classmethod
  284. def info(self, source):
  285. """
  286. Return size, storage parameters from Matrix Market file-like 'source'.
  287. Parameters
  288. ----------
  289. source : str or file-like
  290. Matrix Market filename (extension .mtx) or open file-like object
  291. Returns
  292. -------
  293. rows : int
  294. Number of matrix rows.
  295. cols : int
  296. Number of matrix columns.
  297. entries : int
  298. Number of non-zero entries of a sparse matrix
  299. or rows*cols for a dense matrix.
  300. format : str
  301. Either 'coordinate' or 'array'.
  302. field : str
  303. Either 'real', 'complex', 'pattern', or 'integer'.
  304. symmetry : str
  305. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  306. """
  307. stream, close_it = self._open(source)
  308. try:
  309. # read and validate header line
  310. line = stream.readline()
  311. mmid, matrix, format, field, symmetry = \
  312. [asstr(part.strip()) for part in line.split()]
  313. if not mmid.startswith('%%MatrixMarket'):
  314. raise ValueError('source is not in Matrix Market format')
  315. if not matrix.lower() == 'matrix':
  316. raise ValueError("Problem reading file header: " + line)
  317. # http://math.nist.gov/MatrixMarket/formats.html
  318. if format.lower() == 'array':
  319. format = self.FORMAT_ARRAY
  320. elif format.lower() == 'coordinate':
  321. format = self.FORMAT_COORDINATE
  322. # skip comments
  323. # line.startswith('%')
  324. while line and line[0] in ['%', 37]:
  325. line = stream.readline()
  326. # skip empty lines
  327. while not line.strip():
  328. line = stream.readline()
  329. split_line = line.split()
  330. if format == self.FORMAT_ARRAY:
  331. if not len(split_line) == 2:
  332. raise ValueError("Header line not of length 2: " +
  333. line.decode('ascii'))
  334. rows, cols = map(int, split_line)
  335. entries = rows * cols
  336. else:
  337. if not len(split_line) == 3:
  338. raise ValueError("Header line not of length 3: " +
  339. line.decode('ascii'))
  340. rows, cols, entries = map(int, split_line)
  341. return (rows, cols, entries, format, field.lower(),
  342. symmetry.lower())
  343. finally:
  344. if close_it:
  345. stream.close()
  346. # -------------------------------------------------------------------------
  347. @staticmethod
  348. def _open(filespec, mode='rb'):
  349. """ Return an open file stream for reading based on source.
  350. If source is a file name, open it (after trying to find it with mtx and
  351. gzipped mtx extensions). Otherwise, just return source.
  352. Parameters
  353. ----------
  354. filespec : str or file-like
  355. String giving file name or file-like object
  356. mode : str, optional
  357. Mode with which to open file, if `filespec` is a file name.
  358. Returns
  359. -------
  360. fobj : file-like
  361. Open file-like object.
  362. close_it : bool
  363. True if the calling function should close this file when done,
  364. false otherwise.
  365. """
  366. # If 'filespec' is path-like (str, pathlib.Path, os.DirEntry, other class
  367. # implementing a '__fspath__' method), try to convert it to str. If this
  368. # fails by throwing a 'TypeError', assume it's an open file handle and
  369. # return it as-is.
  370. try:
  371. filespec = os.fspath(filespec)
  372. except TypeError:
  373. return filespec, False
  374. # 'filespec' is definitely a str now
  375. # open for reading
  376. if mode[0] == 'r':
  377. # determine filename plus extension
  378. if not os.path.isfile(filespec):
  379. if os.path.isfile(filespec+'.mtx'):
  380. filespec = filespec + '.mtx'
  381. elif os.path.isfile(filespec+'.mtx.gz'):
  382. filespec = filespec + '.mtx.gz'
  383. elif os.path.isfile(filespec+'.mtx.bz2'):
  384. filespec = filespec + '.mtx.bz2'
  385. # open filename
  386. if filespec.endswith('.gz'):
  387. import gzip
  388. stream = gzip.open(filespec, mode)
  389. elif filespec.endswith('.bz2'):
  390. import bz2
  391. stream = bz2.BZ2File(filespec, 'rb')
  392. else:
  393. stream = open(filespec, mode)
  394. # open for writing
  395. else:
  396. if filespec[-4:] != '.mtx':
  397. filespec = filespec + '.mtx'
  398. stream = open(filespec, mode)
  399. return stream, True
  400. # -------------------------------------------------------------------------
  401. @staticmethod
  402. def _get_symmetry(a):
  403. m, n = a.shape
  404. if m != n:
  405. return MMFile.SYMMETRY_GENERAL
  406. issymm = True
  407. isskew = True
  408. isherm = a.dtype.char in 'FD'
  409. # sparse input
  410. if isspmatrix(a):
  411. # check if number of nonzero entries of lower and upper triangle
  412. # matrix are equal
  413. a = a.tocoo()
  414. (row, col) = a.nonzero()
  415. if (row < col).sum() != (row > col).sum():
  416. return MMFile.SYMMETRY_GENERAL
  417. # define iterator over symmetric pair entries
  418. a = a.todok()
  419. def symm_iterator():
  420. for ((i, j), aij) in a.items():
  421. if i > j:
  422. aji = a[j, i]
  423. yield (aij, aji, False)
  424. elif i == j:
  425. yield (aij, aij, True)
  426. # non-sparse input
  427. else:
  428. # define iterator over symmetric pair entries
  429. def symm_iterator():
  430. for j in range(n):
  431. for i in range(j, n):
  432. aij, aji = a[i][j], a[j][i]
  433. yield (aij, aji, i == j)
  434. # check for symmetry
  435. # yields aij, aji, is_diagonal
  436. for (aij, aji, is_diagonal) in symm_iterator():
  437. if isskew and is_diagonal and aij != 0:
  438. isskew = False
  439. else:
  440. if issymm and aij != aji:
  441. issymm = False
  442. with np.errstate(over="ignore"):
  443. # This can give a warning for uint dtypes, so silence that
  444. if isskew and aij != -aji:
  445. isskew = False
  446. if isherm and aij != conj(aji):
  447. isherm = False
  448. if not (issymm or isskew or isherm):
  449. break
  450. # return symmetry value
  451. if issymm:
  452. return MMFile.SYMMETRY_SYMMETRIC
  453. if isskew:
  454. return MMFile.SYMMETRY_SKEW_SYMMETRIC
  455. if isherm:
  456. return MMFile.SYMMETRY_HERMITIAN
  457. return MMFile.SYMMETRY_GENERAL
  458. # -------------------------------------------------------------------------
  459. @staticmethod
  460. def _field_template(field, precision):
  461. return {MMFile.FIELD_REAL: '%%.%ie\n' % precision,
  462. MMFile.FIELD_INTEGER: '%i\n',
  463. MMFile.FIELD_UNSIGNED: '%u\n',
  464. MMFile.FIELD_COMPLEX: '%%.%ie %%.%ie\n' %
  465. (precision, precision)
  466. }.get(field, None)
  467. # -------------------------------------------------------------------------
  468. def __init__(self, **kwargs):
  469. self._init_attrs(**kwargs)
  470. # -------------------------------------------------------------------------
  471. def read(self, source):
  472. """
  473. Reads the contents of a Matrix Market file-like 'source' into a matrix.
  474. Parameters
  475. ----------
  476. source : str or file-like
  477. Matrix Market filename (extensions .mtx, .mtz.gz)
  478. or open file object.
  479. Returns
  480. -------
  481. a : ndarray or coo_matrix
  482. Dense or sparse matrix depending on the matrix format in the
  483. Matrix Market file.
  484. """
  485. stream, close_it = self._open(source)
  486. try:
  487. self._parse_header(stream)
  488. return self._parse_body(stream)
  489. finally:
  490. if close_it:
  491. stream.close()
  492. # -------------------------------------------------------------------------
  493. def write(self, target, a, comment='', field=None, precision=None,
  494. symmetry=None):
  495. """
  496. Writes sparse or dense array `a` to Matrix Market file-like `target`.
  497. Parameters
  498. ----------
  499. target : str or file-like
  500. Matrix Market filename (extension .mtx) or open file-like object.
  501. a : array like
  502. Sparse or dense 2-D array.
  503. comment : str, optional
  504. Comments to be prepended to the Matrix Market file.
  505. field : None or str, optional
  506. Either 'real', 'complex', 'pattern', or 'integer'.
  507. precision : None or int, optional
  508. Number of digits to display for real or complex values.
  509. symmetry : None or str, optional
  510. Either 'general', 'symmetric', 'skew-symmetric', or 'hermitian'.
  511. If symmetry is None the symmetry type of 'a' is determined by its
  512. values.
  513. """
  514. stream, close_it = self._open(target, 'wb')
  515. try:
  516. self._write(stream, a, comment, field, precision, symmetry)
  517. finally:
  518. if close_it:
  519. stream.close()
  520. else:
  521. stream.flush()
  522. # -------------------------------------------------------------------------
  523. def _init_attrs(self, **kwargs):
  524. """
  525. Initialize each attributes with the corresponding keyword arg value
  526. or a default of None
  527. """
  528. attrs = self.__class__.__slots__
  529. public_attrs = [attr[1:] for attr in attrs]
  530. invalid_keys = set(kwargs.keys()) - set(public_attrs)
  531. if invalid_keys:
  532. raise ValueError('''found %s invalid keyword arguments, please only
  533. use %s''' % (tuple(invalid_keys),
  534. public_attrs))
  535. for attr in attrs:
  536. setattr(self, attr, kwargs.get(attr[1:], None))
  537. # -------------------------------------------------------------------------
  538. def _parse_header(self, stream):
  539. rows, cols, entries, format, field, symmetry = \
  540. self.__class__.info(stream)
  541. self._init_attrs(rows=rows, cols=cols, entries=entries, format=format,
  542. field=field, symmetry=symmetry)
  543. # -------------------------------------------------------------------------
  544. def _parse_body(self, stream):
  545. rows, cols, entries, format, field, symm = (self.rows, self.cols,
  546. self.entries, self.format,
  547. self.field, self.symmetry)
  548. try:
  549. from scipy.sparse import coo_matrix
  550. except ImportError:
  551. coo_matrix = None
  552. dtype = self.DTYPES_BY_FIELD.get(field, None)
  553. has_symmetry = self.has_symmetry
  554. is_integer = field == self.FIELD_INTEGER
  555. is_unsigned_integer = field == self.FIELD_UNSIGNED
  556. is_complex = field == self.FIELD_COMPLEX
  557. is_skew = symm == self.SYMMETRY_SKEW_SYMMETRIC
  558. is_herm = symm == self.SYMMETRY_HERMITIAN
  559. is_pattern = field == self.FIELD_PATTERN
  560. if format == self.FORMAT_ARRAY:
  561. a = zeros((rows, cols), dtype=dtype)
  562. line = 1
  563. i, j = 0, 0
  564. if is_skew:
  565. a[i, j] = 0
  566. if i < rows - 1:
  567. i += 1
  568. while line:
  569. line = stream.readline()
  570. # line.startswith('%')
  571. if not line or line[0] in ['%', 37] or not line.strip():
  572. continue
  573. if is_integer:
  574. aij = int(line)
  575. elif is_unsigned_integer:
  576. aij = int(line)
  577. elif is_complex:
  578. aij = complex(*map(float, line.split()))
  579. else:
  580. aij = float(line)
  581. a[i, j] = aij
  582. if has_symmetry and i != j:
  583. if is_skew:
  584. a[j, i] = -aij
  585. elif is_herm:
  586. a[j, i] = conj(aij)
  587. else:
  588. a[j, i] = aij
  589. if i < rows-1:
  590. i = i + 1
  591. else:
  592. j = j + 1
  593. if not has_symmetry:
  594. i = 0
  595. else:
  596. i = j
  597. if is_skew:
  598. a[i, j] = 0
  599. if i < rows-1:
  600. i += 1
  601. if is_skew:
  602. if not (i in [0, j] and j == cols - 1):
  603. raise ValueError("Parse error, did not read all lines.")
  604. else:
  605. if not (i in [0, j] and j == cols):
  606. raise ValueError("Parse error, did not read all lines.")
  607. elif format == self.FORMAT_COORDINATE and coo_matrix is None:
  608. # Read sparse matrix to dense when coo_matrix is not available.
  609. a = zeros((rows, cols), dtype=dtype)
  610. line = 1
  611. k = 0
  612. while line:
  613. line = stream.readline()
  614. # line.startswith('%')
  615. if not line or line[0] in ['%', 37] or not line.strip():
  616. continue
  617. l = line.split()
  618. i, j = map(int, l[:2])
  619. i, j = i-1, j-1
  620. if is_integer:
  621. aij = int(l[2])
  622. elif is_unsigned_integer:
  623. aij = int(l[2])
  624. elif is_complex:
  625. aij = complex(*map(float, l[2:]))
  626. else:
  627. aij = float(l[2])
  628. a[i, j] = aij
  629. if has_symmetry and i != j:
  630. if is_skew:
  631. a[j, i] = -aij
  632. elif is_herm:
  633. a[j, i] = conj(aij)
  634. else:
  635. a[j, i] = aij
  636. k = k + 1
  637. if not k == entries:
  638. ValueError("Did not read all entries")
  639. elif format == self.FORMAT_COORDINATE:
  640. # Read sparse COOrdinate format
  641. if entries == 0:
  642. # empty matrix
  643. return coo_matrix((rows, cols), dtype=dtype)
  644. I = zeros(entries, dtype='intc')
  645. J = zeros(entries, dtype='intc')
  646. if is_pattern:
  647. V = ones(entries, dtype='int8')
  648. elif is_integer:
  649. V = zeros(entries, dtype='intp')
  650. elif is_unsigned_integer:
  651. V = zeros(entries, dtype='uint64')
  652. elif is_complex:
  653. V = zeros(entries, dtype='complex')
  654. else:
  655. V = zeros(entries, dtype='float')
  656. entry_number = 0
  657. for line in stream:
  658. # line.startswith('%')
  659. if not line or line[0] in ['%', 37] or not line.strip():
  660. continue
  661. if entry_number+1 > entries:
  662. raise ValueError("'entries' in header is smaller than "
  663. "number of entries")
  664. l = line.split()
  665. I[entry_number], J[entry_number] = map(int, l[:2])
  666. if not is_pattern:
  667. if is_integer:
  668. V[entry_number] = int(l[2])
  669. elif is_unsigned_integer:
  670. V[entry_number] = int(l[2])
  671. elif is_complex:
  672. V[entry_number] = complex(*map(float, l[2:]))
  673. else:
  674. V[entry_number] = float(l[2])
  675. entry_number += 1
  676. if entry_number < entries:
  677. raise ValueError("'entries' in header is larger than "
  678. "number of entries")
  679. I -= 1 # adjust indices (base 1 -> base 0)
  680. J -= 1
  681. if has_symmetry:
  682. mask = (I != J) # off diagonal mask
  683. od_I = I[mask]
  684. od_J = J[mask]
  685. od_V = V[mask]
  686. I = concatenate((I, od_J))
  687. J = concatenate((J, od_I))
  688. if is_skew:
  689. od_V *= -1
  690. elif is_herm:
  691. od_V = od_V.conjugate()
  692. V = concatenate((V, od_V))
  693. a = coo_matrix((V, (I, J)), shape=(rows, cols), dtype=dtype)
  694. else:
  695. raise NotImplementedError(format)
  696. return a
  697. # ------------------------------------------------------------------------
  698. def _write(self, stream, a, comment='', field=None, precision=None,
  699. symmetry=None):
  700. if isinstance(a, list) or isinstance(a, ndarray) or \
  701. isinstance(a, tuple) or hasattr(a, '__array__'):
  702. rep = self.FORMAT_ARRAY
  703. a = asarray(a)
  704. if len(a.shape) != 2:
  705. raise ValueError('Expected 2 dimensional array')
  706. rows, cols = a.shape
  707. if field is not None:
  708. if field == self.FIELD_INTEGER:
  709. if not can_cast(a.dtype, 'intp'):
  710. raise OverflowError("mmwrite does not support integer "
  711. "dtypes larger than native 'intp'.")
  712. a = a.astype('intp')
  713. elif field == self.FIELD_REAL:
  714. if a.dtype.char not in 'fd':
  715. a = a.astype('d')
  716. elif field == self.FIELD_COMPLEX:
  717. if a.dtype.char not in 'FD':
  718. a = a.astype('D')
  719. else:
  720. if not isspmatrix(a):
  721. raise ValueError('unknown matrix type: %s' % type(a))
  722. rep = 'coordinate'
  723. rows, cols = a.shape
  724. typecode = a.dtype.char
  725. if precision is None:
  726. if typecode in 'fF':
  727. precision = 8
  728. else:
  729. precision = 16
  730. if field is None:
  731. kind = a.dtype.kind
  732. if kind == 'i':
  733. if not can_cast(a.dtype, 'intp'):
  734. raise OverflowError("mmwrite does not support integer "
  735. "dtypes larger than native 'intp'.")
  736. field = 'integer'
  737. elif kind == 'f':
  738. field = 'real'
  739. elif kind == 'c':
  740. field = 'complex'
  741. elif kind == 'u':
  742. field = 'unsigned-integer'
  743. else:
  744. raise TypeError('unexpected dtype kind ' + kind)
  745. if symmetry is None:
  746. symmetry = self._get_symmetry(a)
  747. # validate rep, field, and symmetry
  748. self.__class__._validate_format(rep)
  749. self.__class__._validate_field(field)
  750. self.__class__._validate_symmetry(symmetry)
  751. # write initial header line
  752. data = f'%%MatrixMarket matrix {rep} {field} {symmetry}\n'
  753. stream.write(data.encode('latin1'))
  754. # write comments
  755. for line in comment.split('\n'):
  756. data = '%%%s\n' % (line)
  757. stream.write(data.encode('latin1'))
  758. template = self._field_template(field, precision)
  759. # write dense format
  760. if rep == self.FORMAT_ARRAY:
  761. # write shape spec
  762. data = '%i %i\n' % (rows, cols)
  763. stream.write(data.encode('latin1'))
  764. if field in (self.FIELD_INTEGER, self.FIELD_REAL,
  765. self.FIELD_UNSIGNED):
  766. if symmetry == self.SYMMETRY_GENERAL:
  767. for j in range(cols):
  768. for i in range(rows):
  769. data = template % a[i, j]
  770. stream.write(data.encode('latin1'))
  771. elif symmetry == self.SYMMETRY_SKEW_SYMMETRIC:
  772. for j in range(cols):
  773. for i in range(j + 1, rows):
  774. data = template % a[i, j]
  775. stream.write(data.encode('latin1'))
  776. else:
  777. for j in range(cols):
  778. for i in range(j, rows):
  779. data = template % a[i, j]
  780. stream.write(data.encode('latin1'))
  781. elif field == self.FIELD_COMPLEX:
  782. if symmetry == self.SYMMETRY_GENERAL:
  783. for j in range(cols):
  784. for i in range(rows):
  785. aij = a[i, j]
  786. data = template % (real(aij), imag(aij))
  787. stream.write(data.encode('latin1'))
  788. else:
  789. for j in range(cols):
  790. for i in range(j, rows):
  791. aij = a[i, j]
  792. data = template % (real(aij), imag(aij))
  793. stream.write(data.encode('latin1'))
  794. elif field == self.FIELD_PATTERN:
  795. raise ValueError('pattern type inconsisted with dense format')
  796. else:
  797. raise TypeError('Unknown field type %s' % field)
  798. # write sparse format
  799. else:
  800. coo = a.tocoo() # convert to COOrdinate format
  801. # if symmetry format used, remove values above main diagonal
  802. if symmetry != self.SYMMETRY_GENERAL:
  803. lower_triangle_mask = coo.row >= coo.col
  804. coo = coo_matrix((coo.data[lower_triangle_mask],
  805. (coo.row[lower_triangle_mask],
  806. coo.col[lower_triangle_mask])),
  807. shape=coo.shape)
  808. # write shape spec
  809. data = '%i %i %i\n' % (rows, cols, coo.nnz)
  810. stream.write(data.encode('latin1'))
  811. template = self._field_template(field, precision-1)
  812. if field == self.FIELD_PATTERN:
  813. for r, c in zip(coo.row+1, coo.col+1):
  814. data = "%i %i\n" % (r, c)
  815. stream.write(data.encode('latin1'))
  816. elif field in (self.FIELD_INTEGER, self.FIELD_REAL,
  817. self.FIELD_UNSIGNED):
  818. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  819. data = ("%i %i " % (r, c)) + (template % d)
  820. stream.write(data.encode('latin1'))
  821. elif field == self.FIELD_COMPLEX:
  822. for r, c, d in zip(coo.row+1, coo.col+1, coo.data):
  823. data = ("%i %i " % (r, c)) + (template % (d.real, d.imag))
  824. stream.write(data.encode('latin1'))
  825. else:
  826. raise TypeError('Unknown field type %s' % field)
  827. def _is_fromfile_compatible(stream):
  828. """
  829. Check whether `stream` is compatible with numpy.fromfile.
  830. Passing a gzipped file object to ``fromfile/fromstring`` doesn't work with
  831. Python 3.
  832. """
  833. bad_cls = []
  834. try:
  835. import gzip
  836. bad_cls.append(gzip.GzipFile)
  837. except ImportError:
  838. pass
  839. try:
  840. import bz2
  841. bad_cls.append(bz2.BZ2File)
  842. except ImportError:
  843. pass
  844. bad_cls = tuple(bad_cls)
  845. return not isinstance(stream, bad_cls)