_directory_reader.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os.path
  2. from glob import glob
  3. from typing import cast
  4. import torch
  5. from torch.types import Storage
  6. # because get_storage_from_record returns a tensor!?
  7. class _HasStorage:
  8. def __init__(self, storage):
  9. self._storage = storage
  10. def storage(self):
  11. return self._storage
  12. class DirectoryReader:
  13. """
  14. Class to allow PackageImporter to operate on unzipped packages. Methods
  15. copy the behavior of the internal PyTorchFileReader class (which is used for
  16. accessing packages in all other cases).
  17. N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
  18. class due to ScriptObjects requiring an actual PyTorchFileReader instance.
  19. """
  20. def __init__(self, directory):
  21. self.directory = directory
  22. def get_record(self, name):
  23. filename = f"{self.directory}/{name}"
  24. with open(filename, "rb") as f:
  25. return f.read()
  26. def get_storage_from_record(self, name, numel, dtype):
  27. filename = f"{self.directory}/{name}"
  28. nbytes = torch._utils._element_size(dtype) * numel
  29. storage = cast(Storage, torch.UntypedStorage)
  30. return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
  31. def has_record(self, path):
  32. full_path = os.path.join(self.directory, path)
  33. return os.path.isfile(full_path)
  34. def get_all_records(
  35. self,
  36. ):
  37. files = []
  38. for filename in glob(f"{self.directory}/**", recursive=True):
  39. if not os.path.isdir(filename):
  40. files.append(filename[len(self.directory) + 1 :])
  41. return files