serialization.py 48 KB


  1. import difflib
  2. import os
  3. import io
  4. import shutil
  5. import struct
  6. import sys
  7. import torch
  8. import tarfile
  9. import tempfile
  10. import warnings
  11. from contextlib import closing, contextmanager
  12. from ._utils import _import_dotted_name
  13. from torch._sources import get_source_lines_and_file
  14. from torch.types import Storage
  15. from torch.storage import _get_dtype_from_pickle_storage_type
  16. from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO
  17. from typing_extensions import TypeAlias # Python 3.10+
  18. import copyreg
  19. import pickle
  20. import pathlib
  21. import torch._weights_only_unpickler as _weights_only_unpickler
  22. DEFAULT_PROTOCOL = 2
  23. LONG_SIZE = struct.Struct('=l').size
  24. INT_SIZE = struct.Struct('=i').size
  25. SHORT_SIZE = struct.Struct('=h').size
  26. MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
  27. PROTOCOL_VERSION = 1001
  28. STORAGE_KEY_SEPARATOR = ','
  29. FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
  30. MAP_LOCATION: TypeAlias = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]]
  31. __all__ = [
  32. 'SourceChangeWarning',
  33. 'mkdtemp',
  34. 'register_package',
  35. 'check_module_version_greater_or_equal',
  36. 'validate_cuda_device',
  37. 'location_tag',
  38. 'default_restore_location',
  39. 'normalize_storage_type',
  40. 'storage_to_tensor_type',
  41. 'save',
  42. 'load',
  43. 'StorageType',
  44. ]
  45. class SourceChangeWarning(Warning):
  46. pass
  47. @contextmanager
  48. def mkdtemp():
  49. path = tempfile.mkdtemp()
  50. yield path
  51. shutil.rmtree(path)
  52. _package_registry = []
  53. def _is_zipfile(f) -> bool:
  54. # This is a stricter implementation than zipfile.is_zipfile().
  55. # zipfile.is_zipfile() is True if the magic number appears anywhere in the
  56. # binary. Since we expect the files here to be generated by torch.save or
  57. # torch.jit.save, it's safe to only check the start bytes and avoid
  58. # collisions and assume the zip has only 1 file.
  59. # See bugs.python.org/issue28494.
  60. # Read the first 4 bytes of the file
  61. read_bytes = []
  62. start = f.tell()
  63. byte = f.read(1)
  64. while byte != b"":
  65. read_bytes.append(byte)
  66. if len(read_bytes) == 4:
  67. break
  68. byte = f.read(1)
  69. f.seek(start)
  70. local_header_magic_number = [b'P', b'K', b'\x03', b'\x04']
  71. return read_bytes == local_header_magic_number
  72. def register_package(priority, tagger, deserializer):
  73. queue_elem = (priority, tagger, deserializer)
  74. _package_registry.append(queue_elem)
  75. _package_registry.sort()
  76. def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
  77. '''
  78. Check if a module's version satisfies requirements
  79. Usually, a module's version string will be like 'x.y.z', which would be represented
  80. as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
  81. string does not match the given tuple's format up to the length of the tuple, then
  82. error and exit or emit a warning.
  83. Args:
  84. module: the module to check the version of
  85. req_version_tuple: tuple (usually of ints) representing the required version
  86. error_if_malformed: whether we should exit if module version string is malformed
  87. Returns:
  88. requirement_is_met: bool
  89. '''
  90. try:
  91. version_strs = module.__version__.split('.')
  92. # Cast module version fields to match the types of the required version
  93. module_version = tuple(
  94. type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
  95. )
  96. requirement_is_met = module_version >= req_version_tuple
  97. except Exception as e:
  98. message = (
  99. "'%s' module version string is malformed '%s' and cannot be compared"
  100. " with tuple %s"
  101. ) % (
  102. module.__name__, module.__version__, str(req_version_tuple)
  103. )
  104. if error_if_malformed:
  105. raise RuntimeError(message) from e
  106. else:
  107. warnings.warn(message + ', but continuing assuming that requirement is met')
  108. requirement_is_met = True
  109. return requirement_is_met
  110. def _cpu_tag(obj):
  111. if obj.device.type == 'cpu':
  112. return 'cpu'
  113. def _cuda_tag(obj):
  114. if obj.device.type == 'cuda':
  115. return 'cuda:' + str(obj.device.index)
  116. def _mps_tag(obj):
  117. if obj.device.type == 'mps':
  118. return 'mps'
  119. def _meta_tag(obj):
  120. if obj.device.type == 'meta':
  121. return 'meta'
  122. def _cpu_deserialize(obj, location):
  123. if location == 'cpu':
  124. return obj
  125. def validate_cuda_device(location):
  126. device = torch.cuda._utils._get_device_index(location, True)
  127. if not torch.cuda.is_available():
  128. raise RuntimeError('Attempting to deserialize object on a CUDA '
  129. 'device but torch.cuda.is_available() is False. '
  130. 'If you are running on a CPU-only machine, '
  131. 'please use torch.load with map_location=torch.device(\'cpu\') '
  132. 'to map your storages to the CPU.')
  133. device_count = torch.cuda.device_count()
  134. if device >= device_count:
  135. raise RuntimeError('Attempting to deserialize object on CUDA device '
  136. f'{device} but torch.cuda.device_count() is {device_count}. Please use '
  137. 'torch.load with map_location to map your storages '
  138. 'to an existing device.')
  139. return device
  140. def _cuda_deserialize(obj, location):
  141. if location.startswith('cuda'):
  142. device = validate_cuda_device(location)
  143. if getattr(obj, "_torch_load_uninitialized", False):
  144. with torch.cuda.device(device):
  145. return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
  146. else:
  147. return obj.cuda(device)
  148. def _mps_deserialize(obj, location):
  149. if location == 'mps':
  150. return obj.mps()
  151. def _meta_deserialize(obj, location):
  152. if location == 'meta':
  153. return torch.UntypedStorage(obj.nbytes(), device='meta')
  154. register_package(10, _cpu_tag, _cpu_deserialize)
  155. register_package(20, _cuda_tag, _cuda_deserialize)
  156. register_package(21, _mps_tag, _mps_deserialize)
  157. register_package(22, _meta_tag, _meta_deserialize)
  158. def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
  159. for _, tagger, _ in _package_registry:
  160. location = tagger(storage)
  161. if location:
  162. return location
  163. raise RuntimeError("don't know how to determine data location of "
  164. + torch.typename(storage))
  165. def default_restore_location(storage, location):
  166. for _, _, fn in _package_registry:
  167. result = fn(storage, location)
  168. if result is not None:
  169. return result
  170. raise RuntimeError("don't know how to restore data location of "
  171. + torch.typename(storage) + " (tagged with "
  172. + location + ")")
  173. def normalize_storage_type(storage_type):
  174. return getattr(torch, storage_type.__name__)
  175. def storage_to_tensor_type(storage):
  176. storage_type = type(storage)
  177. module = _import_dotted_name(storage_type.__module__)
  178. return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
  179. def _is_path(name_or_buffer):
  180. return isinstance(name_or_buffer, (str, pathlib.Path))
  181. class _opener:
  182. def __init__(self, file_like):
  183. self.file_like = file_like
  184. def __enter__(self):
  185. return self.file_like
  186. def __exit__(self, *args):
  187. pass
  188. class _open_file(_opener):
  189. def __init__(self, name, mode):
  190. super().__init__(open(name, mode))
  191. def __exit__(self, *args):
  192. self.file_like.close()
  193. class _open_buffer_reader(_opener):
  194. def __init__(self, buffer):
  195. super().__init__(buffer)
  196. _check_seekable(buffer)
  197. class _open_buffer_writer(_opener):
  198. def __exit__(self, *args):
  199. self.file_like.flush()
  200. def _open_file_like(name_or_buffer, mode):
  201. if _is_path(name_or_buffer):
  202. return _open_file(name_or_buffer, mode)
  203. else:
  204. if 'w' in mode:
  205. return _open_buffer_writer(name_or_buffer)
  206. elif 'r' in mode:
  207. return _open_buffer_reader(name_or_buffer)
  208. else:
  209. raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
  210. class _open_zipfile_reader(_opener):
  211. def __init__(self, name_or_buffer) -> None:
  212. super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
  213. class _open_zipfile_writer_file(_opener):
  214. def __init__(self, name) -> None:
  215. super().__init__(torch._C.PyTorchFileWriter(str(name)))
  216. def __exit__(self, *args) -> None:
  217. self.file_like.write_end_of_file()
  218. class _open_zipfile_writer_buffer(_opener):
  219. def __init__(self, buffer) -> None:
  220. if not callable(getattr(buffer, "write", None)):
  221. msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
  222. if not hasattr(buffer, "write"):
  223. raise AttributeError(msg)
  224. raise TypeError(msg)
  225. self.buffer = buffer
  226. super().__init__(torch._C.PyTorchFileWriter(buffer))
  227. def __exit__(self, *args) -> None:
  228. self.file_like.write_end_of_file()
  229. self.buffer.flush()
  230. def _open_zipfile_writer(name_or_buffer):
  231. container: Type[_opener]
  232. if _is_path(name_or_buffer):
  233. container = _open_zipfile_writer_file
  234. else:
  235. container = _open_zipfile_writer_buffer
  236. return container(name_or_buffer)
  237. def _is_compressed_file(f) -> bool:
  238. compress_modules = ['gzip']
  239. try:
  240. return f.__module__ in compress_modules
  241. except AttributeError:
  242. return False
  243. def _should_read_directly(f):
  244. """
  245. Checks if f is a file that should be read directly. It should be read
  246. directly if it is backed by a real file (has a fileno) and is not a
  247. a compressed file (e.g. gzip)
  248. """
  249. if _is_compressed_file(f):
  250. return False
  251. try:
  252. return f.fileno() >= 0
  253. except io.UnsupportedOperation:
  254. return False
  255. except AttributeError:
  256. return False
  257. def _check_seekable(f) -> bool:
  258. def raise_err_msg(patterns, e):
  259. for p in patterns:
  260. if p in str(e):
  261. msg = (str(e) + ". You can only torch.load from a file that is seekable."
  262. + " Please pre-load the data into a buffer like io.BytesIO and"
  263. + " try to load from it instead.")
  264. raise type(e)(msg)
  265. raise e
  266. try:
  267. f.seek(f.tell())
  268. return True
  269. except (io.UnsupportedOperation, AttributeError) as e:
  270. raise_err_msg(["seek", "tell"], e)
  271. return False
  272. def _check_dill_version(pickle_module) -> None:
  273. '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
  274. If dill version is lower than 0.3.1, a ValueError is raised.
  275. Args:
  276. pickle_module: module used for pickling metadata and objects
  277. '''
  278. if pickle_module is not None and pickle_module.__name__ == 'dill':
  279. required_dill_version = (0, 3, 1)
  280. if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
  281. raise ValueError((
  282. "'torch' supports dill >= %s, but you have dill %s."
  283. " Please upgrade dill or switch to 'pickle'"
  284. ) % (
  285. '.'.join([str(num) for num in required_dill_version]),
  286. pickle_module.__version__
  287. ))
  288. def _check_save_filelike(f):
  289. if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'):
  290. raise AttributeError((
  291. "expected 'f' to be string, path, or a file-like object with "
  292. "a 'write' attribute"))
  293. def save(
  294. obj: object,
  295. f: FILE_LIKE,
  296. pickle_module: Any = pickle,
  297. pickle_protocol: int = DEFAULT_PROTOCOL,
  298. _use_new_zipfile_serialization: bool = True
  299. ) -> None:
  300. # Reference: https://github.com/pytorch/pytorch/issues/54354
  301. # The first line of this docstring overrides the one Sphinx generates for the
  302. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  303. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  304. """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
  305. Saves an object to a disk file.
  306. See also: :ref:`saving-loading-tensors`
  307. Args:
  308. obj: saved object
  309. f: a file-like object (has to implement write and flush) or a string or
  310. os.PathLike object containing a file name
  311. pickle_module: module used for pickling metadata and objects
  312. pickle_protocol: can be specified to override the default protocol
  313. .. note::
  314. A common PyTorch convention is to save tensors using .pt file extension.
  315. .. note::
  316. PyTorch preserves storage sharing across serialization. See
  317. :ref:`preserve-storage-sharing` for more details.
  318. .. note::
  319. The 1.6 release of PyTorch switched ``torch.save`` to use a new
  320. zipfile-based file format. ``torch.load`` still retains the ability to
  321. load files in the old format. If for any reason you want ``torch.save``
  322. to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.
  323. Example:
  324. >>> # xdoctest: +SKIP("makes cwd dirty")
  325. >>> # Save to file
  326. >>> x = torch.tensor([0, 1, 2, 3, 4])
  327. >>> torch.save(x, 'tensor.pt')
  328. >>> # Save to io.BytesIO buffer
  329. >>> buffer = io.BytesIO()
  330. >>> torch.save(x, buffer)
  331. """
  332. torch._C._log_api_usage_once("torch.save")
  333. _check_dill_version(pickle_module)
  334. _check_save_filelike(f)
  335. if _use_new_zipfile_serialization:
  336. with _open_zipfile_writer(f) as opened_zipfile:
  337. _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  338. return
  339. else:
  340. with _open_file_like(f, 'wb') as opened_file:
  341. _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  342. def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
  343. import torch.nn as nn
  344. serialized_container_types = {}
  345. serialized_storages = {}
  346. # Since loading storages that view the same data with different dtypes is
  347. # not supported, we need to keep track of the dtype associated with each
  348. # storage data_ptr and throw an error if the dtype is ever different.
  349. # TODO: This feature could be added in the future
  350. storage_dtypes: Dict[int, torch.dtype] = {}
  351. def persistent_id(obj: Any) -> Optional[Tuple]:
  352. # FIXME: the docs say that persistent_id should only return a string
  353. # but torch store returns tuples. This works only in the binary protocol
  354. # see
  355. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  356. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  357. if isinstance(obj, type) and issubclass(obj, nn.Module):
  358. if obj in serialized_container_types:
  359. return None
  360. serialized_container_types[obj] = True
  361. source_file = source = None
  362. try:
  363. source_lines, _, source_file = get_source_lines_and_file(obj)
  364. source = ''.join(source_lines)
  365. except Exception: # saving the source is optional, so we can ignore any errors
  366. warnings.warn("Couldn't retrieve source code for container of "
  367. "type " + obj.__name__ + ". It won't be checked "
  368. "for correctness upon loading.")
  369. return ('module', obj, source_file, source)
  370. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  371. storage: torch.UntypedStorage
  372. if isinstance(obj, torch.storage.TypedStorage):
  373. # TODO: Once we decide to break serialization FC, this case
  374. # can be deleted
  375. storage = obj._untyped_storage
  376. storage_dtype = obj.dtype
  377. storage_type_str = obj._pickle_storage_type()
  378. storage_type = getattr(torch, storage_type_str)
  379. dtype = obj.dtype
  380. storage_numel = obj._size()
  381. elif isinstance(obj, torch.UntypedStorage):
  382. storage = obj
  383. storage_dtype = torch.uint8
  384. storage_type = normalize_storage_type(type(obj))
  385. dtype = torch.uint8
  386. storage_numel = storage.nbytes()
  387. else:
  388. raise TypeError(f'type not recognized: {type(obj)}')
  389. # If storage is allocated, ensure that any other saved storages
  390. # pointing to the same data all have the same dtype. If storage is
  391. # not allocated, don't perform this check
  392. if storage.data_ptr() != 0:
  393. if storage.data_ptr() in storage_dtypes:
  394. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  395. raise RuntimeError(
  396. 'Cannot save multiple tensors or storages that '
  397. 'view the same data as different types')
  398. else:
  399. storage_dtypes[storage.data_ptr()] = storage_dtype
  400. view_metadata: Optional[Tuple[str, int, int]]
  401. # Offset is always 0, but we keep it for backwards compatibility
  402. # with the old serialization format (which supported storage views)
  403. offset = 0
  404. storage_key = str(storage._cdata)
  405. location = location_tag(storage)
  406. # TODO: There's an issue here with FC. It might be impossible to
  407. # solve, but it's worth noting. Imagine we save a list `[storage,
  408. # tensor]`, where `tensor.storage()` is the same as `storage`, and
  409. # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
  410. # torch.float`. The storage will be serialized with element size
  411. # of 1, since we're choosing to serialize the first occurance of
  412. # a duplicate storage. Since this legacy serialization format saves
  413. # the numel of the storage, rather than nbytes directly, we'll be
  414. # effectively saving nbytes in this case. We'll be able to load it
  415. # and the tensor back up with no problems in _this_ and future
  416. # versions of pytorch, but in older versions, here's the problem:
  417. # the storage will be loaded up as a UntypedStorage, and then the
  418. # FloatTensor will loaded and the UntypedStorage will be assigned to
  419. # it. Since the storage dtype does not match the tensor dtype, this
  420. # will cause an error. If we reverse the list, like `[tensor,
  421. # storage]`, then we will save the `tensor.storage()` as a faked
  422. # `FloatStorage`, and the saved size will be the correct
  423. # dtype-specific numel count that old versions expect. `tensor`
  424. # will be able to load up properly in old versions, pointing to
  425. # a FloatStorage. However, `storage` is still being translated to
  426. # a UntypedStorage, and it will try to resolve to the same
  427. # FloatStorage that `tensor` contains. This will also cause an
  428. # error. It doesn't seem like there's any way around this.
  429. # Probably, we just cannot maintain FC for the legacy format if the
  430. # saved list contains both a tensor and a storage that point to the
  431. # same data. We should still be able to maintain FC for lists of
  432. # just tensors, as long as all views share the same dtype as the
  433. # tensor they are viewing.
  434. if storage_key not in serialized_storages:
  435. serialized_storages[storage_key] = (storage, dtype)
  436. is_view = storage._cdata != storage._cdata
  437. if is_view:
  438. view_metadata = (str(storage._cdata), offset, storage.nbytes())
  439. else:
  440. view_metadata = None
  441. res = ('storage',
  442. storage_type,
  443. storage_key,
  444. location,
  445. storage_numel,
  446. view_metadata)
  447. return res
  448. return None
  449. sys_info = dict(
  450. protocol_version=PROTOCOL_VERSION,
  451. little_endian=sys.byteorder == 'little',
  452. type_sizes=dict(
  453. short=SHORT_SIZE,
  454. int=INT_SIZE,
  455. long=LONG_SIZE,
  456. ),
  457. )
  458. pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
  459. pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
  460. pickle_module.dump(sys_info, f, protocol=pickle_protocol)
  461. pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
  462. pickler.persistent_id = persistent_id
  463. pickler.dump(obj)
  464. serialized_storage_keys = sorted(serialized_storages.keys())
  465. pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
  466. f.flush()
  467. for key in serialized_storage_keys:
  468. storage, dtype = serialized_storages[key]
  469. storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
  470. def _save(obj, zip_file, pickle_module, pickle_protocol):
  471. serialized_storages = {}
  472. id_map: Dict[int, str] = {}
  473. # Since loading storages that view the same data with different dtypes is
  474. # not supported, we need to keep track of the dtype associated with each
  475. # storage data_ptr and throw an error if the dtype is ever different.
  476. # TODO: This feature could be added in the future
  477. storage_dtypes: Dict[int, torch.dtype] = {}
  478. def persistent_id(obj):
  479. # FIXME: the docs say that persistent_id should only return a string
  480. # but torch store returns tuples. This works only in the binary protocol
  481. # see
  482. # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
  483. # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
  484. if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
  485. if isinstance(obj, torch.storage.TypedStorage):
  486. # TODO: Once we decide to break serialization FC, this case
  487. # can be deleted
  488. storage = obj._untyped_storage
  489. storage_dtype = obj.dtype
  490. storage_type_str = obj._pickle_storage_type()
  491. storage_type = getattr(torch, storage_type_str)
  492. storage_numel = obj._size()
  493. else:
  494. storage = obj
  495. storage_dtype = torch.uint8
  496. storage_type = normalize_storage_type(type(obj))
  497. storage_numel = storage.nbytes()
  498. # If storage is allocated, ensure that any other saved storages
  499. # pointing to the same data all have the same dtype. If storage is
  500. # not allocated, don't perform this check
  501. if storage.data_ptr() != 0:
  502. if storage.data_ptr() in storage_dtypes:
  503. if storage_dtype != storage_dtypes[storage.data_ptr()]:
  504. raise RuntimeError(
  505. 'Cannot save multiple tensors or storages that '
  506. 'view the same data as different types')
  507. else:
  508. storage_dtypes[storage.data_ptr()] = storage_dtype
  509. storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
  510. location = location_tag(storage)
  511. serialized_storages[storage_key] = storage
  512. return ('storage',
  513. storage_type,
  514. storage_key,
  515. location,
  516. storage_numel)
  517. return None
  518. # Write the pickle data for `obj`
  519. data_buf = io.BytesIO()
  520. pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
  521. pickler.persistent_id = persistent_id
  522. pickler.dump(obj)
  523. data_value = data_buf.getvalue()
  524. zip_file.write_record('data.pkl', data_value, len(data_value))
  525. # Write each tensor to a file named tensor/the_tensor_key in the zip archive
  526. for key in sorted(serialized_storages.keys()):
  527. name = f'data/{key}'
  528. storage = serialized_storages[key]
  529. # given that we copy things around anyway, we might use storage.cpu()
  530. # this means to that to get tensors serialized, you need to implement
  531. # .cpu() on the underlying Storage
  532. if storage.device.type != 'cpu':
  533. storage = storage.cpu()
  534. # Now that it is on the CPU we can directly copy it into the zip file
  535. num_bytes = storage.nbytes()
  536. zip_file.write_record(name, storage.data_ptr(), num_bytes)
  537. def load(
  538. f: FILE_LIKE,
  539. map_location: MAP_LOCATION = None,
  540. pickle_module: Any = None,
  541. *,
  542. weights_only: bool = False,
  543. **pickle_load_args: Any
  544. ) -> Any:
  545. # Reference: https://github.com/pytorch/pytorch/issues/54354
  546. # The first line of this docstring overrides the one Sphinx generates for the
  547. # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
  548. # the build environment (e.g. `<module 'pickle' from '/leaked/path').
  549. """load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)
  550. Loads an object saved with :func:`torch.save` from a file.
  551. :func:`torch.load` uses Python's unpickling facilities but treats storages,
  552. which underlie tensors, specially. They are first deserialized on the
  553. CPU and are then moved to the device they were saved from. If this fails
  554. (e.g. because the run time system doesn't have certain devices), an exception
  555. is raised. However, storages can be dynamically remapped to an alternative
  556. set of devices using the :attr:`map_location` argument.
  557. If :attr:`map_location` is a callable, it will be called once for each serialized
  558. storage with two arguments: storage and location. The storage argument
  559. will be the initial deserialization of the storage, residing on the CPU.
  560. Each serialized storage has a location tag associated with it which
  561. identifies the device it was saved from, and this tag is the second
  562. argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
  563. for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
  564. :attr:`map_location` should return either ``None`` or a storage. If
  565. :attr:`map_location` returns a storage, it will be used as the final deserialized
  566. object, already moved to the right device. Otherwise, :func:`torch.load` will
  567. fall back to the default behavior, as if :attr:`map_location` wasn't specified.
  568. If :attr:`map_location` is a :class:`torch.device` object or a string containing
  569. a device tag, it indicates the location where all tensors should be loaded.
  570. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
  571. appearing in the file (keys), to ones that specify where to put the
  572. storages (values).
  573. User extensions can register their own location tags and tagging and
  574. deserialization methods using :func:`torch.serialization.register_package`.
  575. Args:
  576. f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  577. or a string or os.PathLike object containing a file name
  578. map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
  579. locations
  580. pickle_module: module used for unpickling metadata and objects (has to
  581. match the :attr:`pickle_module` used to serialize file)
  582. weights_only: Indicates whether unpickler should be restricted to
  583. loading only tensors, primitive types and dictionaries
  584. pickle_load_args: (Python 3 only) optional keyword arguments passed over to
  585. :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
  586. :attr:`errors=...`.
  587. .. warning::
  588. :func:`torch.load()` unless `weights_only` parameter is set to `True`,
  589. uses ``pickle`` module implicitly, which is known to be insecure.
  590. It is possible to construct malicious pickle data which will execute arbitrary code
  591. during unpickling. Never load data that could have come from an untrusted
  592. source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
  593. .. note::
  594. When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
  595. will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
  596. and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
  597. .. note::
  598. By default, we decode byte strings as ``utf-8``. This is to avoid a common error
  599. case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
  600. when loading files saved by Python 2 in Python 3. If this default
  601. is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
  602. these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
  603. to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
  604. as byte arrays which can be decoded later with ``byte_array.decode(...)``.
  605. Example:
  606. >>> # xdoctest: +SKIP("undefined filepaths")
  607. >>> torch.load('tensors.pt')
  608. # Load all tensors onto the CPU
  609. >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
  610. # Load all tensors onto the CPU, using a function
  611. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
  612. # Load all tensors onto GPU 1
  613. >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
  614. # Map tensors from GPU 1 to GPU 0
  615. >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'})
  616. # Load tensor from io.BytesIO object
  617. >>> with open('tensor.pt', 'rb') as f:
  618. ... buffer = io.BytesIO(f.read())
  619. >>> torch.load(buffer)
  620. # Load a module with 'ascii' encoding for unpickling
  621. >>> torch.load('module.pt', encoding='ascii')
  622. """
  623. torch._C._log_api_usage_once("torch.load")
  624. UNSAFE_MESSAGE = (
  625. "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`"
  626. " will likely succeed, but it can result in arbitrary code execution."
  627. "Do it only if you get the file from a trusted source. WeightsUnpickler error: "
  628. )
  629. # Add ability to force safe only weight loads via environment variable
  630. if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
  631. weights_only = True
  632. if weights_only:
  633. if pickle_module is not None:
  634. raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
  635. else:
  636. if pickle_module is None:
  637. pickle_module = pickle
  638. _check_dill_version(pickle_module)
  639. if 'encoding' not in pickle_load_args.keys():
  640. pickle_load_args['encoding'] = 'utf-8'
  641. with _open_file_like(f, 'rb') as opened_file:
  642. if _is_zipfile(opened_file):
  643. # The zipfile reader is going to advance the current file position.
  644. # If we want to actually tail call to torch.jit.load, we need to
  645. # reset back to the original position.
  646. orig_position = opened_file.tell()
  647. with _open_zipfile_reader(opened_file) as opened_zipfile:
  648. if _is_torchscript_zip(opened_zipfile):
  649. warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
  650. " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
  651. " silence this warning)", UserWarning)
  652. opened_file.seek(orig_position)
  653. return torch.jit.load(opened_file, map_location=map_location)
  654. if weights_only:
  655. try:
  656. return _load(opened_zipfile, map_location, _weights_only_unpickler, **pickle_load_args)
  657. except RuntimeError as e:
  658. raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
  659. return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  660. if weights_only:
  661. try:
  662. return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
  663. except RuntimeError as e:
  664. raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
  665. return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  666. # Register pickling support for layout instances such as
  667. # torch.sparse_coo, etc
  668. def _get_layout(name):
  669. """Get layout extension object from its string representation.
  670. """
  671. cache = _get_layout.cache # type: ignore[attr-defined]
  672. if not cache:
  673. for v in torch.__dict__.values():
  674. if isinstance(v, torch.layout):
  675. cache[str(v)] = v
  676. return cache[name]
  677. # There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
  678. _get_layout.cache = {} # type: ignore[attr-defined]
  679. copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the pre-zipfile (legacy) formats.

    Two layouts are handled:

    * a tar archive with ``storages``, ``tensors`` and ``pickle`` members. This
      is attempted only when ``_should_read_directly(f)`` is true and ``f`` is
      positioned at offset 0;
    * a sequential pickle stream. It holds a magic number, a protocol version,
      sys info, the pickled object, and the list of storage keys. The raw
      bytes of each storage follow, in that key order.

    Args:
        f: seekable binary file-like object.
        map_location: anything accepted by ``_get_restore_location``.
        pickle_module: module providing ``load`` and ``Unpickler``.
        **pickle_load_args: forwarded to ``pickle_module.load`` and to the
            ``Unpickler`` constructor.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: if the tar attempt fails and ``f`` turns out to be a zip
            archive. Also raised for a wrong magic number, a wrong protocol
            version, or an unknown persistent-id type. It is raised as well on
            Python 3.8.0/3.8.1 when ``f`` lacks ``readinto``.
    """
    # Outer-scope cache for the stream format: persistent-id key -> storage.
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        """Unpickler that maps pickled ``*Storage`` class names to StorageType.

        Names that StorageType does not recognise (it raises KeyError) fall
        back to the normal class lookup.
        """

        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        """Warn with SourceChangeWarning if a saved class's source has changed.

        The saved ``original_source`` is compared with the class's current
        source. When ``container_type.dump_patches`` is set, a reverse unified
        diff is written to ``<ClassName>.patch`` in the current working
        directory. An existing patch file with different content is reported as
        a failure instead of being overwritten.
        """
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                # Diff from current -> original, so applying it reverts the local edits.
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            # A different patch already exists; don't clobber it.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        """Read the tar-based legacy format.

        The steps are:

        1. Read storages and storage views from the ``storages`` member.
        2. Rebuild tensors from the ``tensors`` member.
        3. Unpickle the ``pickle`` member, resolving integer persistent ids
           through this function's own ``deserialized_objects``.

        Raises ``tarfile.TarError`` (via ``tarfile.open``/``extract``) when ``f``
        is not a tar archive; the caller relies on that to fall back.
        """
        # Separate cache for the tar format; shadows the outer one on purpose.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuple ids are saved container classes: (class, source_file, source).
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            # Anything else is a key of an already-rebuilt storage or tensor.
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    # Each storage: a pickled (key, location, storage_type) header,
                    # then its data, consumed by _new_with_file from the same file.
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                # Views are slices of a root storage; offset/numel are in
                # elements of the root's dtype, converted here to bytes.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    # Little-endian binary header: int32 ndim (+4 padding bytes),
                    # ndim int64 sizes, ndim int64 strides, int64 storage offset.
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    # Despite the name, `numel` holds one size per dimension and is
                    # passed as the `size` argument of set_ below.
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.tensor([], dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    deserialized_objects = {}

    def persistent_load(saved_id):
        """Resolve persistent ids of the sequential stream format.

        ``('module', class, source_file, source)`` returns the class; its
        saved source is checked only if both source fields are truthy.
        ``('storage', storage_type, root_key, location, numel, view_metadata)``
        returns a TypedStorage, or a view of one when ``view_metadata`` is
        ``(view_key, offset, view_size)``. Root storages are allocated
        uninitialised here; their bytes are filled in after unpickling.
        """
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                obj = cast(Storage, torch.UntypedStorage(nbytes))
                # NOTE(review): presumably tells restore_location / device-move code
                # that the contents are not yet valid (they are read later); confirm.
                obj._torch_load_uninitialized = True
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=restore_location(obj, location),
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                # A cached storage with a null data pointer is replaced by a fresh
                # TypedStorage on the same device rather than reused.
                # NOTE(review): reason not visible here — confirm.
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                # offset and view_size are in elements of `dtype`; slice in bytes.
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    # Stream header: magic number, protocol version, sys info (read and ignored).
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Storage bytes follow in key order. When reading directly, an explicit
    # offset is passed and advanced to f.tell() after each storage.
    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        typed_storage = deserialized_objects[key]
        typed_storage._untyped_storage._set_from_file(
            f, offset, f_should_read_directly,
            torch._utils._element_size(typed_storage.dtype))
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
  878. def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
  879. # When using encoding='bytes' in Py3, some **internal** keys stored as
  880. # strings in Py2 are loaded as bytes. This function decodes them with
  881. # ascii encoding, one that Py3 uses by default.
  882. #
  883. # NOTE: This should only be used on internal keys (e.g., `typename` and
  884. # `location` in `persistent_load` below!
  885. if isinstance(bytes_str, bytes):
  886. return bytes_str.decode('ascii')
  887. return bytes_str
  888. def _get_restore_location(map_location):
  889. if map_location is None:
  890. restore_location = default_restore_location
  891. elif isinstance(map_location, dict):
  892. def restore_location(storage, location):
  893. location = map_location.get(location, location)
  894. return default_restore_location(storage, location)
  895. elif isinstance(map_location, str):
  896. def restore_location(storage, location):
  897. return default_restore_location(storage, map_location)
  898. elif isinstance(map_location, torch.device):
  899. def restore_location(storage, location):
  900. return default_restore_location(storage, str(map_location))
  901. else:
  902. def restore_location(storage, location):
  903. result = map_location(storage, location)
  904. if result is None:
  905. result = default_restore_location(storage, location)
  906. return result
  907. return restore_location
  908. class StorageType():
  909. def __init__(self, name):
  910. self.dtype = _get_dtype_from_pickle_storage_type(name)
  911. def __str__(self):
  912. return f'StorageType(dtype={self.dtype})'
  913. def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
  914. restore_location = _get_restore_location(map_location)
  915. loaded_storages = {}
  916. def load_tensor(dtype, numel, key, location):
  917. name = f'data/{key}'
  918. storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
  919. # TODO: Once we decide to break serialization FC, we can
  920. # stop wrapping with TypedStorage
  921. typed_storage = torch.storage.TypedStorage(
  922. wrap_storage=restore_location(storage, location),
  923. dtype=dtype,
  924. _internal=True)
  925. if typed_storage._data_ptr() != 0:
  926. loaded_storages[key] = typed_storage
  927. return typed_storage
  928. def persistent_load(saved_id):
  929. assert isinstance(saved_id, tuple)
  930. typename = _maybe_decode_ascii(saved_id[0])
  931. data = saved_id[1:]
  932. assert typename == 'storage', \
  933. f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
  934. storage_type, key, location, numel = data
  935. if storage_type is torch.UntypedStorage:
  936. dtype = torch.uint8
  937. else:
  938. dtype = storage_type.dtype
  939. if key in loaded_storages:
  940. typed_storage = loaded_storages[key]
  941. else:
  942. nbytes = numel * torch._utils._element_size(dtype)
  943. typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  944. return typed_storage
  945. load_module_mapping: Dict[str, str] = {
  946. # See https://github.com/pytorch/pytorch/pull/51633
  947. 'torch.tensor': 'torch._tensor'
  948. }
  949. # Need to subclass Unpickler instead of directly monkey-patching the find_class method
  950. # because it's marked readonly in pickle.
  951. # The type: ignore is because mypy can't statically determine the type of this class.
  952. class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
  953. # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
  954. # Lets us override the imports that pickle uses when unpickling an object.
  955. # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
  956. def find_class(self, mod_name, name):
  957. if type(name) is str and 'Storage' in name:
  958. try:
  959. return StorageType(name)
  960. except KeyError:
  961. pass
  962. mod_name = load_module_mapping.get(mod_name, mod_name)
  963. return super().find_class(mod_name, name)
  964. # Load the data (which may in turn use `persistent_load` to load tensors)
  965. data_file = io.BytesIO(zip_file.get_record(pickle_file))
  966. unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
  967. unpickler.persistent_load = persistent_load
  968. result = unpickler.load()
  969. torch._utils._validate_loaded_sparse_tensors()
  970. return result
  971. def _is_torchscript_zip(zip_file):
  972. return 'constants.pkl' in zip_file.get_all_records()