_deploy.py

import io

import torch
from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer
from torch.package._package_pickler import create_pickler
from torch.package._package_unpickler import PackageUnpickler
from torch.serialization import _maybe_decode_ascii


def _save_storages(importer, obj):
    """Pickle ``obj``, collecting its storages and dtypes out-of-band so the
    caller can hand them (together with the pickle bytes) to ``_load_storages``."""
    serialized_storages = []
    serialized_dtypes = []

    importer = importer if isinstance(importer, torch.package.PackageImporter) else None
    importers: Importer
    if importer is not None:
        importers = OrderedImporter(importer, sys_importer)
    else:
        importers = sys_importer

    def persistent_id(obj):
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                storage = obj._untyped_storage
                dtype = obj.dtype
            else:
                storage = obj
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)

        if hasattr(obj, "__reduce_deploy__"):
            if _serialized_reduces.get(id(obj)) is None:
                _serialized_reduces[id(obj)] = (
                    "reduce_deploy",
                    id(obj),
                    *obj.__reduce_deploy__(importers),
                )
            return _serialized_reduces[id(obj)]

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, importers)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    return (
        data_value,
        serialized_storages,
        serialized_dtypes,
        importer.zip_reader if importer else None,
    )


def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes):
    """Rebuild an object pickled by ``_save_storages`` from its pickle bytes and
    the out-of-band storages/dtypes, caching the result in ``_deploy_objects``."""

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == "storage":
            # TODO: Once we decide to break serialization FC, we can
            # stop wrapping with TypedStorage
            storage = serialized_storages[data[0]]
            dtype = serialized_dtypes[data[0]]
            return torch.storage.TypedStorage(
                wrap_storage=storage.untyped(), dtype=dtype
            )

        if typename == "reduce_deploy":
            reduce_id, func, args = data
            if reduce_id not in _loaded_reduces:
                _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args)
            return _loaded_reduces[reduce_id]

        return None

    importer: Importer
    if zip_reader is not None:
        importer = OrderedImporter(_get_package(zip_reader), sys_importer)
    else:
        importer = sys_importer

    unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes))
    unpickler.persistent_load = persistent_load  # type: ignore[assignment]
    result = _deploy_objects[id] = unpickler.load()
    return result


def _get_package(zip_reader):
    # Cache one PackageImporter per zip_reader so repeated loads reuse it.
    if zip_reader not in _raw_packages:
        _raw_packages[zip_reader] = PackageImporter(zip_reader)
    return _raw_packages[zip_reader]


# Module-level caches shared across calls.
_raw_packages: dict = {}
_deploy_objects: dict = {}
_serialized_reduces: dict = {}
_loaded_reduces: dict = {}
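

# Minimal usage sketch (an illustration, not part of the original module):
# these helpers are normally driven by the torch::deploy interpreter, which
# moves the pickle bytes and the raw storages between Python interpreters.
# The guarded block below only shows how the two halves fit together for an
# ordinary tensor within a single process; the object id ``0`` is arbitrary.
if __name__ == "__main__":
    example = torch.randn(4)
    data, storages, dtypes, zip_reader = _save_storages(sys_importer, example)
    restored = _load_storages(0, zip_reader, data, storages, dtypes)
    assert torch.equal(example, restored)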