12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import os.path
- from glob import glob
- from typing import cast
- import torch
- from torch.types import Storage
- # because get_storage_from_record returns a tensor!?
- class _HasStorage:
- def __init__(self, storage):
- self._storage = storage
- def storage(self):
- return self._storage
- class DirectoryReader:
- """
- Class to allow PackageImporter to operate on unzipped packages. Methods
- copy the behavior of the internal PyTorchFileReader class (which is used for
- accessing packages in all other cases).
- N.B.: ScriptObjects are not depickleable or accessible via this DirectoryReader
- class due to ScriptObjects requiring an actual PyTorchFileReader instance.
- """
- def __init__(self, directory):
- self.directory = directory
- def get_record(self, name):
- filename = f"{self.directory}/{name}"
- with open(filename, "rb") as f:
- return f.read()
- def get_storage_from_record(self, name, numel, dtype):
- filename = f"{self.directory}/{name}"
- nbytes = torch._utils._element_size(dtype) * numel
- storage = cast(Storage, torch.UntypedStorage)
- return _HasStorage(storage.from_file(filename=filename, nbytes=nbytes))
- def has_record(self, path):
- full_path = os.path.join(self.directory, path)
- return os.path.isfile(full_path)
- def get_all_records(
- self,
- ):
- files = []
- for filename in glob(f"{self.directory}/**", recursive=True):
- if not os.path.isdir(filename):
- files.append(filename[len(self.directory) + 1 :])
- return files
|