import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
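    """Pickle ``obj`` with a package-aware pickler, externalizing storages.

    Tensor storages and objects implementing ``__reduce_deploy__`` are not
    written into the pickle stream; they are collected into side tables and
    referenced by persistent id so the caller can transport them separately.
    Returns ``(pickle_bytes, storages, dtypes, zip_reader)``, where
    ``zip_reader`` comes from ``importer`` when it is a ``PackageImporter``
    and is ``None`` otherwise.
    """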
    serialized_storages = []
    serialized_dtypes = []

    # Only a PackageImporter carries a zip_reader we can hand back to the
    # caller; any other importer is treated as absent.
    importer = (
        importer if isinstance(importer, torch.package.PackageImporter) else None
    )
    importers: Importer
    if importer is not None:
        # Resolve modules from the package first, falling back to the
        # ambient Python environment.
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        # Called by the pickler for every object; returning a non-None value
        # stores that tag in the stream instead of pickling the object itself.
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._untyped_storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            # Deduplicate by object identity so a shared object is reduced
            # only once per save.
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )

            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
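    """Reconstruct an object previously flattened by ``_save_storages``.

    Persistent ids in ``obj_bytes`` are resolved against the
    ``serialized_storages``/``serialized_dtypes`` side tables and the cached
    ``__reduce_deploy__`` results for ``zip_reader``. The result is memoized
    in ``_deploy_objects`` under ``id``.
    """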
    def persistent_load(saved_id):
        # Inverse of persistent_id above: resolve the tags written by
        # _save_storages back into live objects.
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[assignment]
    result = _deploy_objects[id] = unpickler.load()
    return result


def _get_package(zip_reader):
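    """Return the ``PackageImporter`` for ``zip_reader``, creating it on first use."""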
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]


# Module-level caches:
#   _raw_packages: zip_reader -> PackageImporter (see _get_package)
#   _deploy_objects: load id -> unpickled result
#   _serialized_reduces: id(obj) -> persistent-id tuple for __reduce_deploy__
#   _loaded_reduces: reduce id -> reconstructed object
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}
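

if __name__ == "__main__":
    # Minimal round-trip sketch. Calling these internal helpers directly is
    # an illustration only, not a public API; outside torch::deploy the ids
    # and side tables would normally be managed by the interpreter. Passing
    # sys_importer means no package is involved, so zip_reader comes back
    # as None and loading falls back to the ambient environment.
    example = {"weights": torch.arange(4.0)}
    data, storages, dtypes, zip_reader = _save_storages(sys_importer, example)
    restored = _load_storages(0, zip_reader, data, storages, dtypes)
    print(restored)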