_mio5.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892
  1. ''' Classes for read / write of matlab (TM) 5 files
  2. The matfile specification last found here:
  3. https://www.mathworks.com/access/helpdesk/help/pdf_doc/matlab/matfile_format.pdf
  4. (as of December 5 2008)
  5. '''
  6. '''
  7. =================================
  8. Note on functions and mat files
  9. =================================
  10. The document above does not give any hints as to the storage of matlab
  11. function handles, or anonymous function handles. I had, therefore, to
  12. guess the format of matlab arrays of ``mxFUNCTION_CLASS`` and
  13. ``mxOPAQUE_CLASS`` by looking at example mat files.
  14. ``mxFUNCTION_CLASS`` stores all types of matlab functions. It seems to
  15. contain a struct matrix with a set pattern of fields. For anonymous
functions, a sub-field of one of these fields seems to contain the
  17. well-named ``mxOPAQUE_CLASS``. This seems to contain:
  18. * array flags as for any matlab matrix
  19. * 3 int8 strings
  20. * a matrix
  21. It seems that whenever the mat file contains a ``mxOPAQUE_CLASS``
  22. instance, there is also an un-named matrix (name == '') at the end of
  23. the mat file. I'll call this the ``__function_workspace__`` matrix.
  24. When I saved two anonymous functions in a mat file, or appended another
  25. anonymous function to the mat file, there was still only one
  26. ``__function_workspace__`` un-named matrix at the end, but larger than
  27. that for a mat file with a single anonymous function, suggesting that
  28. the workspaces for the two functions had been merged.
  29. The ``__function_workspace__`` matrix appears to be of double class
  30. (``mxCLASS_DOUBLE``), but stored as uint8, the memory for which is in
  31. the format of a mini .mat file, without the first 124 bytes of the file
  32. header (the description and the subsystem_offset), but with the version
  33. U2 bytes, and the S2 endian test bytes. There follow 4 zero bytes,
  34. presumably for 8 byte padding, and then a series of ``miMATRIX``
  35. entries, as in a standard mat file. The ``miMATRIX`` entries appear to
  36. be series of un-named (name == '') matrices, and may also contain arrays
  37. of this same mini-mat format.
  38. I guess that:
  39. * saving an anonymous function back to a mat file will need the
  40. associated ``__function_workspace__`` matrix saved as well for the
  41. anonymous function to work correctly.
  42. * appending to a mat file that has a ``__function_workspace__`` would
  43. involve first pulling off this workspace, appending, checking whether
  44. there were any more anonymous functions appended, and then somehow
  45. merging the relevant workspaces, and saving at the end of the mat
  46. file.
  47. The mat files I was playing with are in ``tests/data``:
  48. * sqr.mat
  49. * parabola.mat
  50. * some_functions.mat
  51. See ``tests/test_mio.py:test_mio_funcs.py`` for the debugging
  52. script I was working with.
  53. '''
  54. # Small fragments of current code adapted from matfile.py by Heiko
  55. # Henkelmann; parts of the code for simplify_cells=True adapted from
  56. # http://blog.nephics.com/2019/08/28/better-loadmat-for-scipy/.
  57. import os
  58. import time
  59. import sys
  60. import zlib
  61. from io import BytesIO
  62. import warnings
  63. import numpy as np
  64. import scipy.sparse
  65. from ._byteordercodes import native_code, swapped_code
  66. from ._miobase import (MatFileReader, docfiller, matdims, read_dtype,
  67. arr_to_chars, arr_dtype_number, MatWriteError,
  68. MatReadError, MatReadWarning)
  69. # Reader object for matlab 5 format variables
  70. from ._mio5_utils import VarReader5
  71. # Constants and helper objects
  72. from ._mio5_params import (MatlabObject, MatlabFunction, MDTYPES, NP_TO_MTYPES,
  73. NP_TO_MXTYPES, miCOMPRESSED, miMATRIX, miINT8,
  74. miUTF8, miUINT32, mxCELL_CLASS, mxSTRUCT_CLASS,
  75. mxOBJECT_CLASS, mxCHAR_CLASS, mxSPARSE_CLASS,
  76. mxDOUBLE_CLASS, mclass_info, mat_struct)
  77. from ._streams import ZlibInputStream
  78. def _has_struct(elem):
  79. """Determine if elem is an array and if first array item is a struct."""
  80. return (isinstance(elem, np.ndarray) and (elem.size > 0) and
  81. isinstance(elem[0], mat_struct))
  82. def _inspect_cell_array(ndarray):
  83. """Construct lists from cell arrays (loaded as numpy ndarrays), recursing
  84. into items if they contain mat_struct objects."""
  85. elem_list = []
  86. for sub_elem in ndarray:
  87. if isinstance(sub_elem, mat_struct):
  88. elem_list.append(_matstruct_to_dict(sub_elem))
  89. elif _has_struct(sub_elem):
  90. elem_list.append(_inspect_cell_array(sub_elem))
  91. else:
  92. elem_list.append(sub_elem)
  93. return elem_list
  94. def _matstruct_to_dict(matobj):
  95. """Construct nested dicts from mat_struct objects."""
  96. d = {}
  97. for f in matobj._fieldnames:
  98. elem = matobj.__dict__[f]
  99. if isinstance(elem, mat_struct):
  100. d[f] = _matstruct_to_dict(elem)
  101. elif _has_struct(elem):
  102. d[f] = _inspect_cell_array(elem)
  103. else:
  104. d[f] = elem
  105. return d
  106. def _simplify_cells(d):
  107. """Convert mat objects in dict to nested dicts."""
  108. for key in d:
  109. if isinstance(d[key], mat_struct):
  110. d[key] = _matstruct_to_dict(d[key])
  111. elif _has_struct(d[key]):
  112. d[key] = _inspect_cell_array(d[key])
  113. return d
  114. class MatFile5Reader(MatFileReader):
  115. ''' Reader for Mat 5 mat files
  116. Adds the following attribute to base class
  117. uint16_codec - char codec to use for uint16 char arrays
  118. (defaults to system default codec)
  119. Uses variable reader that has the following stardard interface (see
  120. abstract class in ``miobase``::
  121. __init__(self, file_reader)
  122. read_header(self)
  123. array_from_header(self)
  124. and added interface::
  125. set_stream(self, stream)
  126. read_full_tag(self)
  127. '''
  128. @docfiller
  129. def __init__(self,
  130. mat_stream,
  131. byte_order=None,
  132. mat_dtype=False,
  133. squeeze_me=False,
  134. chars_as_strings=True,
  135. matlab_compatible=False,
  136. struct_as_record=True,
  137. verify_compressed_data_integrity=True,
  138. uint16_codec=None,
  139. simplify_cells=False):
  140. '''Initializer for matlab 5 file format reader
  141. %(matstream_arg)s
  142. %(load_args)s
  143. %(struct_arg)s
  144. uint16_codec : {None, string}
  145. Set codec to use for uint16 char arrays (e.g., 'utf-8').
  146. Use system default codec if None
  147. '''
  148. super().__init__(
  149. mat_stream,
  150. byte_order,
  151. mat_dtype,
  152. squeeze_me,
  153. chars_as_strings,
  154. matlab_compatible,
  155. struct_as_record,
  156. verify_compressed_data_integrity,
  157. simplify_cells)
  158. # Set uint16 codec
  159. if not uint16_codec:
  160. uint16_codec = sys.getdefaultencoding()
  161. self.uint16_codec = uint16_codec
  162. # placeholders for readers - see initialize_read method
  163. self._file_reader = None
  164. self._matrix_reader = None
  165. def guess_byte_order(self):
  166. ''' Guess byte order.
  167. Sets stream pointer to 0'''
  168. self.mat_stream.seek(126)
  169. mi = self.mat_stream.read(2)
  170. self.mat_stream.seek(0)
  171. return mi == b'IM' and '<' or '>'
  172. def read_file_header(self):
  173. ''' Read in mat 5 file header '''
  174. hdict = {}
  175. hdr_dtype = MDTYPES[self.byte_order]['dtypes']['file_header']
  176. hdr = read_dtype(self.mat_stream, hdr_dtype)
  177. hdict['__header__'] = hdr['description'].item().strip(b' \t\n\000')
  178. v_major = hdr['version'] >> 8
  179. v_minor = hdr['version'] & 0xFF
  180. hdict['__version__'] = '%d.%d' % (v_major, v_minor)
  181. return hdict
  182. def initialize_read(self):
  183. ''' Run when beginning read of variables
  184. Sets up readers from parameters in `self`
  185. '''
  186. # reader for top level stream. We need this extra top-level
  187. # reader because we use the matrix_reader object to contain
  188. # compressed matrices (so they have their own stream)
  189. self._file_reader = VarReader5(self)
  190. # reader for matrix streams
  191. self._matrix_reader = VarReader5(self)
  192. def read_var_header(self):
  193. ''' Read header, return header, next position
  194. Header has to define at least .name and .is_global
  195. Parameters
  196. ----------
  197. None
  198. Returns
  199. -------
  200. header : object
  201. object that can be passed to self.read_var_array, and that
  202. has attributes .name and .is_global
  203. next_position : int
  204. position in stream of next variable
  205. '''
  206. mdtype, byte_count = self._file_reader.read_full_tag()
  207. if not byte_count > 0:
  208. raise ValueError("Did not read any bytes")
  209. next_pos = self.mat_stream.tell() + byte_count
  210. if mdtype == miCOMPRESSED:
  211. # Make new stream from compressed data
  212. stream = ZlibInputStream(self.mat_stream, byte_count)
  213. self._matrix_reader.set_stream(stream)
  214. check_stream_limit = self.verify_compressed_data_integrity
  215. mdtype, byte_count = self._matrix_reader.read_full_tag()
  216. else:
  217. check_stream_limit = False
  218. self._matrix_reader.set_stream(self.mat_stream)
  219. if not mdtype == miMATRIX:
  220. raise TypeError('Expecting miMATRIX type here, got %d' % mdtype)
  221. header = self._matrix_reader.read_header(check_stream_limit)
  222. return header, next_pos
  223. def read_var_array(self, header, process=True):
  224. ''' Read array, given `header`
  225. Parameters
  226. ----------
  227. header : header object
  228. object with fields defining variable header
  229. process : {True, False} bool, optional
  230. If True, apply recursive post-processing during loading of
  231. array.
  232. Returns
  233. -------
  234. arr : array
  235. array with post-processing applied or not according to
  236. `process`.
  237. '''
  238. return self._matrix_reader.array_from_header(header, process)
  239. def get_variables(self, variable_names=None):
  240. ''' get variables from stream as dictionary
  241. variable_names - optional list of variable names to get
  242. If variable_names is None, then get all variables in file
  243. '''
  244. if isinstance(variable_names, str):
  245. variable_names = [variable_names]
  246. elif variable_names is not None:
  247. variable_names = list(variable_names)
  248. self.mat_stream.seek(0)
  249. # Here we pass all the parameters in self to the reading objects
  250. self.initialize_read()
  251. mdict = self.read_file_header()
  252. mdict['__globals__'] = []
  253. while not self.end_of_stream():
  254. hdr, next_position = self.read_var_header()
  255. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  256. if name in mdict:
  257. warnings.warn('Duplicate variable name "%s" in stream'
  258. ' - replacing previous with new\n'
  259. 'Consider mio5.varmats_from_mat to split '
  260. 'file into single variable files' % name,
  261. MatReadWarning, stacklevel=2)
  262. if name == '':
  263. # can only be a matlab 7 function workspace
  264. name = '__function_workspace__'
  265. # We want to keep this raw because mat_dtype processing
  266. # will break the format (uint8 as mxDOUBLE_CLASS)
  267. process = False
  268. else:
  269. process = True
  270. if variable_names is not None and name not in variable_names:
  271. self.mat_stream.seek(next_position)
  272. continue
  273. try:
  274. res = self.read_var_array(hdr, process)
  275. except MatReadError as err:
  276. warnings.warn(
  277. 'Unreadable variable "%s", because "%s"' %
  278. (name, err),
  279. Warning, stacklevel=2)
  280. res = "Read error: %s" % err
  281. self.mat_stream.seek(next_position)
  282. mdict[name] = res
  283. if hdr.is_global:
  284. mdict['__globals__'].append(name)
  285. if variable_names is not None:
  286. variable_names.remove(name)
  287. if len(variable_names) == 0:
  288. break
  289. if self.simplify_cells:
  290. return _simplify_cells(mdict)
  291. else:
  292. return mdict
  293. def list_variables(self):
  294. ''' list variables from stream '''
  295. self.mat_stream.seek(0)
  296. # Here we pass all the parameters in self to the reading objects
  297. self.initialize_read()
  298. self.read_file_header()
  299. vars = []
  300. while not self.end_of_stream():
  301. hdr, next_position = self.read_var_header()
  302. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  303. if name == '':
  304. # can only be a matlab 7 function workspace
  305. name = '__function_workspace__'
  306. shape = self._matrix_reader.shape_from_header(hdr)
  307. if hdr.is_logical:
  308. info = 'logical'
  309. else:
  310. info = mclass_info.get(hdr.mclass, 'unknown')
  311. vars.append((name, shape, info))
  312. self.mat_stream.seek(next_position)
  313. return vars
  314. def varmats_from_mat(file_obj):
  315. """ Pull variables out of mat 5 file as a sequence of mat file objects
  316. This can be useful with a difficult mat file, containing unreadable
  317. variables. This routine pulls the variables out in raw form and puts them,
  318. unread, back into a file stream for saving or reading. Another use is the
  319. pathological case where there is more than one variable of the same name in
  320. the file; this routine returns the duplicates, whereas the standard reader
  321. will overwrite duplicates in the returned dictionary.
  322. The file pointer in `file_obj` will be undefined. File pointers for the
  323. returned file-like objects are set at 0.
  324. Parameters
  325. ----------
  326. file_obj : file-like
  327. file object containing mat file
  328. Returns
  329. -------
  330. named_mats : list
  331. list contains tuples of (name, BytesIO) where BytesIO is a file-like
  332. object containing mat file contents as for a single variable. The
  333. BytesIO contains a string with the original header and a single var. If
  334. ``var_file_obj`` is an individual BytesIO instance, then save as a mat
  335. file with something like ``open('test.mat',
  336. 'wb').write(var_file_obj.read())``
  337. Examples
  338. --------
  339. >>> import scipy.io
  340. BytesIO is from the ``io`` module in Python 3, and is ``cStringIO`` for
  341. Python < 3.
  342. >>> mat_fileobj = BytesIO()
  343. >>> scipy.io.savemat(mat_fileobj, {'b': np.arange(10), 'a': 'a string'})
  344. >>> varmats = varmats_from_mat(mat_fileobj)
  345. >>> sorted([name for name, str_obj in varmats])
  346. ['a', 'b']
  347. """
  348. rdr = MatFile5Reader(file_obj)
  349. file_obj.seek(0)
  350. # Raw read of top-level file header
  351. hdr_len = MDTYPES[native_code]['dtypes']['file_header'].itemsize
  352. raw_hdr = file_obj.read(hdr_len)
  353. # Initialize variable reading
  354. file_obj.seek(0)
  355. rdr.initialize_read()
  356. rdr.read_file_header()
  357. next_position = file_obj.tell()
  358. named_mats = []
  359. while not rdr.end_of_stream():
  360. start_position = next_position
  361. hdr, next_position = rdr.read_var_header()
  362. name = 'None' if hdr.name is None else hdr.name.decode('latin1')
  363. # Read raw variable string
  364. file_obj.seek(start_position)
  365. byte_count = next_position - start_position
  366. var_str = file_obj.read(byte_count)
  367. # write to stringio object
  368. out_obj = BytesIO()
  369. out_obj.write(raw_hdr)
  370. out_obj.write(var_str)
  371. out_obj.seek(0)
  372. named_mats.append((name, out_obj))
  373. return named_mats
  374. class EmptyStructMarker:
  375. """ Class to indicate presence of empty matlab struct on output """
  376. def to_writeable(source):
  377. ''' Convert input object ``source`` to something we can write
  378. Parameters
  379. ----------
  380. source : object
  381. Returns
  382. -------
  383. arr : None or ndarray or EmptyStructMarker
  384. If `source` cannot be converted to something we can write to a matfile,
  385. return None. If `source` is equivalent to an empty dictionary, return
  386. ``EmptyStructMarker``. Otherwise return `source` converted to an
  387. ndarray with contents for writing to matfile.
  388. '''
  389. if isinstance(source, np.ndarray):
  390. return source
  391. if source is None:
  392. return None
  393. # Objects that implement mappings
  394. is_mapping = (hasattr(source, 'keys') and hasattr(source, 'values') and
  395. hasattr(source, 'items'))
  396. # Objects that don't implement mappings, but do have dicts
  397. if isinstance(source, np.generic):
  398. # NumPy scalars are never mappings (PyPy issue workaround)
  399. pass
  400. elif not is_mapping and hasattr(source, '__dict__'):
  401. source = dict((key, value) for key, value in source.__dict__.items()
  402. if not key.startswith('_'))
  403. is_mapping = True
  404. if is_mapping:
  405. dtype = []
  406. values = []
  407. for field, value in source.items():
  408. if (isinstance(field, str) and
  409. field[0] not in '_0123456789'):
  410. dtype.append((str(field), object))
  411. values.append(value)
  412. if dtype:
  413. return np.array([tuple(values)], dtype)
  414. else:
  415. return EmptyStructMarker
  416. # Next try and convert to an array
  417. narr = np.asanyarray(source)
  418. if narr.dtype.type in (object, np.object_) and \
  419. narr.shape == () and narr == source:
  420. # No interesting conversion possible
  421. return None
  422. return narr
  423. # Native byte ordered dtypes for convenience for writers
  424. NDT_FILE_HDR = MDTYPES[native_code]['dtypes']['file_header']
  425. NDT_TAG_FULL = MDTYPES[native_code]['dtypes']['tag_full']
  426. NDT_TAG_SMALL = MDTYPES[native_code]['dtypes']['tag_smalldata']
  427. NDT_ARRAY_FLAGS = MDTYPES[native_code]['dtypes']['array_flags']
  428. class VarWriter5:
  429. ''' Generic matlab matrix writing class '''
  430. mat_tag = np.zeros((), NDT_TAG_FULL)
  431. mat_tag['mdtype'] = miMATRIX
  432. def __init__(self, file_writer):
  433. self.file_stream = file_writer.file_stream
  434. self.unicode_strings = file_writer.unicode_strings
  435. self.long_field_names = file_writer.long_field_names
  436. self.oned_as = file_writer.oned_as
  437. # These are used for top level writes, and unset after
  438. self._var_name = None
  439. self._var_is_global = False
  440. def write_bytes(self, arr):
  441. self.file_stream.write(arr.tobytes(order='F'))
  442. def write_string(self, s):
  443. self.file_stream.write(s)
  444. def write_element(self, arr, mdtype=None):
  445. ''' write tag and data '''
  446. if mdtype is None:
  447. mdtype = NP_TO_MTYPES[arr.dtype.str[1:]]
  448. # Array needs to be in native byte order
  449. if arr.dtype.byteorder == swapped_code:
  450. arr = arr.byteswap().newbyteorder()
  451. byte_count = arr.size*arr.itemsize
  452. if byte_count <= 4:
  453. self.write_smalldata_element(arr, mdtype, byte_count)
  454. else:
  455. self.write_regular_element(arr, mdtype, byte_count)
  456. def write_smalldata_element(self, arr, mdtype, byte_count):
  457. # write tag with embedded data
  458. tag = np.zeros((), NDT_TAG_SMALL)
  459. tag['byte_count_mdtype'] = (byte_count << 16) + mdtype
  460. # if arr.tobytes is < 4, the element will be zero-padded as needed.
  461. tag['data'] = arr.tobytes(order='F')
  462. self.write_bytes(tag)
  463. def write_regular_element(self, arr, mdtype, byte_count):
  464. # write tag, data
  465. tag = np.zeros((), NDT_TAG_FULL)
  466. tag['mdtype'] = mdtype
  467. tag['byte_count'] = byte_count
  468. self.write_bytes(tag)
  469. self.write_bytes(arr)
  470. # pad to next 64-bit boundary
  471. bc_mod_8 = byte_count % 8
  472. if bc_mod_8:
  473. self.file_stream.write(b'\x00' * (8-bc_mod_8))
  474. def write_header(self,
  475. shape,
  476. mclass,
  477. is_complex=False,
  478. is_logical=False,
  479. nzmax=0):
  480. ''' Write header for given data options
  481. shape : sequence
  482. array shape
  483. mclass - mat5 matrix class
  484. is_complex - True if matrix is complex
  485. is_logical - True if matrix is logical
  486. nzmax - max non zero elements for sparse arrays
  487. We get the name and the global flag from the object, and reset
  488. them to defaults after we've used them
  489. '''
  490. # get name and is_global from one-shot object store
  491. name = self._var_name
  492. is_global = self._var_is_global
  493. # initialize the top-level matrix tag, store position
  494. self._mat_tag_pos = self.file_stream.tell()
  495. self.write_bytes(self.mat_tag)
  496. # write array flags (complex, global, logical, class, nzmax)
  497. af = np.zeros((), NDT_ARRAY_FLAGS)
  498. af['data_type'] = miUINT32
  499. af['byte_count'] = 8
  500. flags = is_complex << 3 | is_global << 2 | is_logical << 1
  501. af['flags_class'] = mclass | flags << 8
  502. af['nzmax'] = nzmax
  503. self.write_bytes(af)
  504. # shape
  505. self.write_element(np.array(shape, dtype='i4'))
  506. # write name
  507. name = np.asarray(name)
  508. if name == '': # empty string zero-terminated
  509. self.write_smalldata_element(name, miINT8, 0)
  510. else:
  511. self.write_element(name, miINT8)
  512. # reset the one-shot store to defaults
  513. self._var_name = ''
  514. self._var_is_global = False
  515. def update_matrix_tag(self, start_pos):
  516. curr_pos = self.file_stream.tell()
  517. self.file_stream.seek(start_pos)
  518. byte_count = curr_pos - start_pos - 8
  519. if byte_count >= 2**32:
  520. raise MatWriteError("Matrix too large to save with Matlab "
  521. "5 format")
  522. self.mat_tag['byte_count'] = byte_count
  523. self.write_bytes(self.mat_tag)
  524. self.file_stream.seek(curr_pos)
  525. def write_top(self, arr, name, is_global):
  526. """ Write variable at top level of mat file
  527. Parameters
  528. ----------
  529. arr : array_like
  530. array-like object to create writer for
  531. name : str, optional
  532. name as it will appear in matlab workspace
  533. default is empty string
  534. is_global : {False, True}, optional
  535. whether variable will be global on load into matlab
  536. """
  537. # these are set before the top-level header write, and unset at
  538. # the end of the same write, because they do not apply for lower levels
  539. self._var_is_global = is_global
  540. self._var_name = name
  541. # write the header and data
  542. self.write(arr)
  543. def write(self, arr):
  544. ''' Write `arr` to stream at top and sub levels
  545. Parameters
  546. ----------
  547. arr : array_like
  548. array-like object to create writer for
  549. '''
  550. # store position, so we can update the matrix tag
  551. mat_tag_pos = self.file_stream.tell()
  552. # First check if these are sparse
  553. if scipy.sparse.issparse(arr):
  554. self.write_sparse(arr)
  555. self.update_matrix_tag(mat_tag_pos)
  556. return
  557. # Try to convert things that aren't arrays
  558. narr = to_writeable(arr)
  559. if narr is None:
  560. raise TypeError('Could not convert %s (type %s) to array'
  561. % (arr, type(arr)))
  562. if isinstance(narr, MatlabObject):
  563. self.write_object(narr)
  564. elif isinstance(narr, MatlabFunction):
  565. raise MatWriteError('Cannot write matlab functions')
  566. elif narr is EmptyStructMarker: # empty struct array
  567. self.write_empty_struct()
  568. elif narr.dtype.fields: # struct array
  569. self.write_struct(narr)
  570. elif narr.dtype.hasobject: # cell array
  571. self.write_cells(narr)
  572. elif narr.dtype.kind in ('U', 'S'):
  573. if self.unicode_strings:
  574. codec = 'UTF8'
  575. else:
  576. codec = 'ascii'
  577. self.write_char(narr, codec)
  578. else:
  579. self.write_numeric(narr)
  580. self.update_matrix_tag(mat_tag_pos)
  581. def write_numeric(self, arr):
  582. imagf = arr.dtype.kind == 'c'
  583. logif = arr.dtype.kind == 'b'
  584. try:
  585. mclass = NP_TO_MXTYPES[arr.dtype.str[1:]]
  586. except KeyError:
  587. # No matching matlab type, probably complex256 / float128 / float96
  588. # Cast data to complex128 / float64.
  589. if imagf:
  590. arr = arr.astype('c128')
  591. elif logif:
  592. arr = arr.astype('i1') # Should only contain 0/1
  593. else:
  594. arr = arr.astype('f8')
  595. mclass = mxDOUBLE_CLASS
  596. self.write_header(matdims(arr, self.oned_as),
  597. mclass,
  598. is_complex=imagf,
  599. is_logical=logif)
  600. if imagf:
  601. self.write_element(arr.real)
  602. self.write_element(arr.imag)
  603. else:
  604. self.write_element(arr)
  605. def write_char(self, arr, codec='ascii'):
  606. ''' Write string array `arr` with given `codec`
  607. '''
  608. if arr.size == 0 or np.all(arr == ''):
  609. # This an empty string array or a string array containing
  610. # only empty strings. Matlab cannot distinguish between a
  611. # string array that is empty, and a string array containing
  612. # only empty strings, because it stores strings as arrays of
  613. # char. There is no way of having an array of char that is
  614. # not empty, but contains an empty string. We have to
  615. # special-case the array-with-empty-strings because even
  616. # empty strings have zero padding, which would otherwise
  617. # appear in matlab as a string with a space.
  618. shape = (0,) * np.max([arr.ndim, 2])
  619. self.write_header(shape, mxCHAR_CLASS)
  620. self.write_smalldata_element(arr, miUTF8, 0)
  621. return
  622. # non-empty string.
  623. #
  624. # Convert to char array
  625. arr = arr_to_chars(arr)
  626. # We have to write the shape directly, because we are going
  627. # recode the characters, and the resulting stream of chars
  628. # may have a different length
  629. shape = arr.shape
  630. self.write_header(shape, mxCHAR_CLASS)
  631. if arr.dtype.kind == 'U' and arr.size:
  632. # Make one long string from all the characters. We need to
  633. # transpose here, because we're flattening the array, before
  634. # we write the bytes. The bytes have to be written in
  635. # Fortran order.
  636. n_chars = np.prod(shape)
  637. st_arr = np.ndarray(shape=(),
  638. dtype=arr_dtype_number(arr, n_chars),
  639. buffer=arr.T.copy()) # Fortran order
  640. # Recode with codec to give byte string
  641. st = st_arr.item().encode(codec)
  642. # Reconstruct as 1-D byte array
  643. arr = np.ndarray(shape=(len(st),),
  644. dtype='S1',
  645. buffer=st)
  646. self.write_element(arr, mdtype=miUTF8)
  647. def write_sparse(self, arr):
  648. ''' Sparse matrices are 2D
  649. '''
  650. A = arr.tocsc() # convert to sparse CSC format
  651. A.sort_indices() # MATLAB expects sorted row indices
  652. is_complex = (A.dtype.kind == 'c')
  653. is_logical = (A.dtype.kind == 'b')
  654. nz = A.nnz
  655. self.write_header(matdims(arr, self.oned_as),
  656. mxSPARSE_CLASS,
  657. is_complex=is_complex,
  658. is_logical=is_logical,
  659. # matlab won't load file with 0 nzmax
  660. nzmax=1 if nz == 0 else nz)
  661. self.write_element(A.indices.astype('i4'))
  662. self.write_element(A.indptr.astype('i4'))
  663. self.write_element(A.data.real)
  664. if is_complex:
  665. self.write_element(A.data.imag)
  666. def write_cells(self, arr):
  667. self.write_header(matdims(arr, self.oned_as),
  668. mxCELL_CLASS)
  669. # loop over data, column major
  670. A = np.atleast_2d(arr).flatten('F')
  671. for el in A:
  672. self.write(el)
  673. def write_empty_struct(self):
  674. self.write_header((1, 1), mxSTRUCT_CLASS)
  675. # max field name length set to 1 in an example matlab struct
  676. self.write_element(np.array(1, dtype=np.int32))
  677. # Field names element is empty
  678. self.write_element(np.array([], dtype=np.int8))
  679. def write_struct(self, arr):
  680. self.write_header(matdims(arr, self.oned_as),
  681. mxSTRUCT_CLASS)
  682. self._write_items(arr)
  683. def _write_items(self, arr):
  684. # write fieldnames
  685. fieldnames = [f[0] for f in arr.dtype.descr]
  686. length = max([len(fieldname) for fieldname in fieldnames])+1
  687. max_length = (self.long_field_names and 64) or 32
  688. if length > max_length:
  689. raise ValueError("Field names are restricted to %d characters" %
  690. (max_length-1))
  691. self.write_element(np.array([length], dtype='i4'))
  692. self.write_element(
  693. np.array(fieldnames, dtype='S%d' % (length)),
  694. mdtype=miINT8)
  695. A = np.atleast_2d(arr).flatten('F')
  696. for el in A:
  697. for f in fieldnames:
  698. self.write(el[f])
  699. def write_object(self, arr):
  700. '''Same as writing structs, except different mx class, and extra
  701. classname element after header
  702. '''
  703. self.write_header(matdims(arr, self.oned_as),
  704. mxOBJECT_CLASS)
  705. self.write_element(np.array(arr.classname, dtype='S'),
  706. mdtype=miINT8)
  707. self._write_items(arr)
  708. class MatFile5Writer:
  709. ''' Class for writing mat5 files '''
  710. @docfiller
  711. def __init__(self, file_stream,
  712. do_compression=False,
  713. unicode_strings=False,
  714. global_vars=None,
  715. long_field_names=False,
  716. oned_as='row'):
  717. ''' Initialize writer for matlab 5 format files
  718. Parameters
  719. ----------
  720. %(do_compression)s
  721. %(unicode_strings)s
  722. global_vars : None or sequence of strings, optional
  723. Names of variables to be marked as global for matlab
  724. %(long_fields)s
  725. %(oned_as)s
  726. '''
  727. self.file_stream = file_stream
  728. self.do_compression = do_compression
  729. self.unicode_strings = unicode_strings
  730. if global_vars:
  731. self.global_vars = global_vars
  732. else:
  733. self.global_vars = []
  734. self.long_field_names = long_field_names
  735. self.oned_as = oned_as
  736. self._matrix_writer = None
  737. def write_file_header(self):
  738. # write header
  739. hdr = np.zeros((), NDT_FILE_HDR)
  740. hdr['description'] = 'MATLAB 5.0 MAT-file Platform: %s, Created on: %s' \
  741. % (os.name,time.asctime())
  742. hdr['version'] = 0x0100
  743. hdr['endian_test'] = np.ndarray(shape=(),
  744. dtype='S2',
  745. buffer=np.uint16(0x4d49))
  746. self.file_stream.write(hdr.tobytes())
  747. def put_variables(self, mdict, write_header=None):
  748. ''' Write variables in `mdict` to stream
  749. Parameters
  750. ----------
  751. mdict : mapping
  752. mapping with method ``items`` returns name, contents pairs where
  753. ``name`` which will appear in the matlab workspace in file load, and
  754. ``contents`` is something writeable to a matlab file, such as a NumPy
  755. array.
  756. write_header : {None, True, False}, optional
  757. If True, then write the matlab file header before writing the
  758. variables. If None (the default) then write the file header
  759. if we are at position 0 in the stream. By setting False
  760. here, and setting the stream position to the end of the file,
  761. you can append variables to a matlab file
  762. '''
  763. # write header if requested, or None and start of file
  764. if write_header is None:
  765. write_header = self.file_stream.tell() == 0
  766. if write_header:
  767. self.write_file_header()
  768. self._matrix_writer = VarWriter5(self)
  769. for name, var in mdict.items():
  770. if name[0] == '_':
  771. continue
  772. is_global = name in self.global_vars
  773. if self.do_compression:
  774. stream = BytesIO()
  775. self._matrix_writer.file_stream = stream
  776. self._matrix_writer.write_top(var, name.encode('latin1'), is_global)
  777. out_str = zlib.compress(stream.getvalue())
  778. tag = np.empty((), NDT_TAG_FULL)
  779. tag['mdtype'] = miCOMPRESSED
  780. tag['byte_count'] = len(out_str)
  781. self.file_stream.write(tag.tobytes())
  782. self.file_stream.write(out_str)
  783. else: # not compressing
  784. self._matrix_writer.write_top(var, name.encode('latin1'), is_global)