123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524 |
- from abc import ABC, abstractmethod
- import queue
- import threading
- import collections
- from dataclasses import dataclass
- import os
- import dataclasses
- import io
- import pickle
- from typing import List, Union, Dict, cast
- import torch
- from torch import Tensor
- from torch.futures import Future
- from pathlib import Path
- from .metadata import (
- Metadata,
- MetadataIndex,
- )
- from .storage import (
- StorageReader,
- StorageWriter,
- WriteResult,
- )
- from .planner import (
- LoadItemType,
- LoadPlanner,
- LoadPlan,
- SavePlan,
- SavePlanner,
- ReadItem,
- WriteItem,
- WriteItemType,
- )
- from torch.distributed._shard._utils import narrow_tensor_by_index
# Public API of this module; everything else is an implementation detail.
__all__ = ["FileSystemWriter", "SlicedBufferedReader", "FileSystemReader"]
@dataclasses.dataclass
class _StorageInfo:
    """
    Location of one stored entry: a byte range inside a checkpoint file.
    """

    relative_path: str  # file name, relative to the checkpoint directory
    offset: int  # byte offset of the entry inside that file
    length: int  # number of bytes the entry occupies
@dataclasses.dataclass
class _StoragePrefix:
    """Per-plan storage data: the string prepended to each data file name."""

    prefix: str
# File extension given to every data file written by FileSystemWriter.
DEFAULT_SUFFIX: str = ".distcp"
def _trim(tensor: torch.Tensor) -> torch.Tensor:
    """
    Return a detached CPU tensor whose storage holds exactly its own elements.

    If the detached CPU tensor is a view into a larger (or smaller) storage,
    it is cloned so that serializing it does not drag the whole backing
    storage along; otherwise the detached tensor itself is returned.
    """
    trimmed = tensor.detach().cpu()
    storage_elems = trimmed._typed_storage()._size()
    return trimmed.clone() if storage_elems != trimmed.numel() else trimmed
def _result_from_write_item(
    item: WriteItem, size_in_bytes, storage_data
) -> WriteResult:
    """
    Build the WriteResult for ``item``, keyed by the item's index.

    ``size_in_bytes`` and ``storage_data`` are stored verbatim.
    """
    fields = dict(
        index=item.index,
        size_in_bytes=size_in_bytes,
        storage_data=storage_data,
    )
    return WriteResult(**fields)
class _TensorLoader(ABC):
    """
    Interface for objects that stage tensors before they are written out.

    Callers register items with ``add`` and then iterate ``values``;
    ``start_loading`` is an optional hook that does nothing by default.
    """

    @abstractmethod
    def add(self, size, obj):
        """Register ``obj`` (whose payload is ``size`` bytes) for loading."""

    def start_loading(self):
        """Optionally begin loading registered items; no-op by default."""

    @abstractmethod
    def values(self):
        """Produce the loaded items."""
class _SerialCpuLoader(_TensorLoader):
    """
    Loader that resolves and copies tensors to CPU one at a time, lazily,
    while ``values`` is being iterated (no overlap between copy and write).
    """

    def __init__(self, resolve_fun):
        # Maps a registered object to the tensor that should be written for it.
        self.resolve_fun = resolve_fun
        self.items = []

    def add(self, size, obj):
        self.items.append((size, obj))

    def start_loading(self):
        # Nothing to prefetch: all work happens lazily in ``values``.
        pass

    def values(self):
        """
        Yield ``(tensor, obj)`` pairs in registration order.

        Each tensor is detached, moved to CPU and, if it is a view that does
        not own a storage of exactly its own size, cloned (via ``_trim``).
        ``_trim`` uses the non-deprecated storage accessor, unlike the
        ``Tensor.storage()`` call this previously duplicated inline.
        """
        for _, obj in self.items:
            yield _trim(self.resolve_fun(obj)), obj
class _OverlappingCpuLoader(_TensorLoader):
    """
    Loader that overlaps device-to-host copies with consumption of the results.

    Registered items are sorted by their declared size and copied to CPU on
    ``stream`` (non-blocking for CUDA tensors). Roughly ``inflight_threshhold``
    bytes are kept staged at a time; the last item staged may push the total
    past the budget. Once the budget is reached, ``values`` synchronizes the
    stream, hands out the oldest staged tensors, and stages more.
    """

    def __init__(self, resolve_fun, stream=None, inflight_threshhold=1_000_000):
        # Maps a registered object to the tensor that should be written for it.
        self.resolve_fun = resolve_fun
        self.items = []
        self.inflight_threshhold = inflight_threshhold
        # Bytes currently staged in ``current_items``.
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        # Index of the next entry of ``self.items`` to stage.
        self.idx = 0
        self.started = False
        self.stream = stream or torch.cuda.current_stream()
        if self.stream != torch.cuda.current_stream():
            # Copies on a side stream must not begin before work already queued
            # on the current stream (which may still be producing the tensors).
            self.stream.wait_stream(torch.cuda.current_stream())

    @property
    def _done(self):
        # True once every registered item has been staged.
        return self.idx >= len(self.items)

    def _drain(self):
        """
        If the staged bytes reached the budget, wait for all pending copies and
        pop the oldest staged items until the total drops below the budget.
        Returns the popped ``(tensor, obj)`` pairs (possibly empty).
        """
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self):
        """Stage more items on ``self.stream`` until the budget is reached or none remain."""
        with torch.cuda.stream(self.stream):
            while (
                not self._done
                and self.in_flight_data < self.inflight_threshhold
            ):
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.is_cuda:
                    # Asynchronous on ``self.stream``; ``_drain``/``_finish``
                    # synchronize before the CPU tensor is handed out.
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    # Forces minimal, contiguous storage for views; uses the
                    # same (non-deprecated) storage check as ``_trim``.
                    tensor = _trim(tensor)
                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self):
        """Wait for the remaining copies and return everything still staged."""
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size, obj):
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self):
        """Sort items by declared size (ascending) and stage the first batch; idempotent."""
        if self.started:
            return
        self.started = True
        self.items.sort(key=lambda x: x[0])
        self._refill()

    def values(self):
        """Yield ``(cpu_tensor, obj)`` pairs in ascending order of declared size."""
        self.start_loading()
        while not self._done:
            drained = self._drain()
            # Stage the next batch before yielding so copies overlap with the
            # consumer's work on the drained tensors.
            self._refill()
            yield from drained
        yield from self._finish()
def _item_size(item: WriteItem) -> int:
    """
    Return the byte size of the tensor described by ``item``.

    Computed as the product of ``item.tensor_data.size`` times the element
    size of ``item.tensor_data.properties.dtype``; ``item.tensor_data`` must
    not be None.
    """
    tensor_data = item.tensor_data
    assert tensor_data is not None
    # can't use math.prod as PT needs to support older python
    numel = 1
    for dim in tensor_data.size:
        numel *= dim
    return numel * torch._utils._element_size(tensor_data.properties.dtype)
def _split_by_size_and_type(
    bins, items: List[WriteItem]
) -> List[List[WriteItem]]:
    """
    Distribute ``items`` over ``bins`` buckets (one bucket per output file).

    With a single bin, ``[items]`` is returned unchanged. Otherwise BYTE_IO
    items are dealt round-robin, and tensor items are placed largest first
    into the bucket with the smallest total tensor byte size so far. Ties go
    to the lowest bucket index. Byte items do not count towards bucket sizes.

    Raises:
        ValueError: if ``bins < 1`` while there are items to distribute.
    """
    import heapq  # local import: only this helper needs it

    if bins == 1:
        return [items]
    if bins < 1 and items:
        raise ValueError(
            f"cannot distribute {len(items)} item(s) over {bins} bins"
        )

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    # Compute every size once; sorted() is stable, so equal-size items keep
    # their original relative order (as list.sort(reverse=True) did).
    sized = sorted(
        ((_item_size(wi), wi) for wi in tensor_w),
        key=lambda pair: pair[0],
        reverse=True,
    )

    # Min-heap of (bucket_size, bucket_index). Tuple ordering breaks ties by
    # the lowest index, the same choice min() over enumerate() made. A list of
    # equal sizes with ascending indices is already a valid heap.
    heap = [(0, idx) for idx in range(bins)]
    for size, wi in sized:
        bucket_size, idx = heap[0]
        buckets[idx].append(wi)
        heapq.heapreplace(heap, (bucket_size + size, idx))
    return buckets
def _write_item(stream, data, write_item, storage_key):
    """
    Append ``data`` to ``stream`` and return a WriteResult locating it.

    BYTE_IO items must be ``io.BytesIO`` (their whole buffer is written);
    every other item must be a CPU tensor, serialized with ``torch.save``.
    The offset and length are measured with ``stream.tell()`` and recorded
    in a ``_StorageInfo`` under ``storage_key``.
    """
    start = stream.tell()
    if write_item.type != WriteItemType.BYTE_IO:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, stream)
    else:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    written = stream.tell() - start
    location = _StorageInfo(storage_key, start, written)
    return _result_from_write_item(write_item, written, location)
def _write_files_from_queue(
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
):
    """
    Write files taken from ``file_queue`` until the queue is empty.

    Args:
        file_queue: entries of ``(file_path, storage_key, write_items)``.
            ``file_path`` is opened for writing, and ``storage_key`` is recorded
            as the relative path in each item's ``_StorageInfo``.
        result_queue: receives exactly one list of WriteResult per file written.
        planner: resolves each write item to its data via ``resolve_data``.
        inflight_threshhold: byte budget for overlapped GPU->CPU copies. A value
            ``<= 0``, or no CUDA, selects the serial loader.
        use_fsync: flush and fsync every file before reporting its results.
    """
    while True:
        try:
            file_name, storage_key, write_items = file_queue.get_nowait()
        except queue.Empty:
            # Queue drained: this worker is done. Only the dequeue is guarded so
            # that errors raised while writing are not mistaken for "empty".
            return

        loader: _TensorLoader
        if torch.cuda.is_available() and inflight_threshhold > 0:
            loader = _OverlappingCpuLoader(
                planner.resolve_data,
                inflight_threshhold=inflight_threshhold,
            )
        else:
            loader = _SerialCpuLoader(planner.resolve_data)

        tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
        for write_item in tensor_w:
            loader.add(_item_size(write_item), write_item)
        # Start tensor staging before writing byte items so the two can overlap.
        loader.start_loading()

        bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
        write_results = []
        with open(file_name, "wb") as stream:
            for write_item in bytes_w:
                data = planner.resolve_data(write_item)
                write_results.append(
                    _write_item(stream, data, write_item, storage_key)
                )

            for tensor, write_item in loader.values():
                assert not tensor.is_cuda
                write_results.append(
                    _write_item(stream, tensor, write_item, storage_key)
                )

            if use_fsync:
                # os.fsync only persists what the OS has already received; the
                # tail of the data may still sit in Python's write buffer.
                stream.flush()
                os.fsync(stream.fileno())
        result_queue.put(write_results)
class FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of data files plus a `.metadata` file holding the
    pickled metadata. With ``single_file_per_rank`` each rank writes
    ``thread_count`` data files; otherwise it writes one data file per write item.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank (per writer thread)
                instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of
                saving them. Defaults to 10MB.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint
        will be consistent in the case of a failure.
        """
        super().__init__()
        self.path = Path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files
        self.thread_count = thread_count
        self.per_thread_copy_ahead = per_thread_copy_ahead

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        pass

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        # Every rank makes sure the checkpoint directory exists.
        self.path.mkdir(parents=True, exist_ok=True)
        return plan

    def prepare_global_plan(
        self, global_plan: List[SavePlan]
    ) -> List[SavePlan]:
        # Give each plan a distinct file-name prefix ("__<plan index>_") so the
        # data files written for different plans never collide.
        return [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            for i, plan in enumerate(global_plan)
        ]

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        """
        Write every item of ``plan`` to data files under ``self.path``.

        The files are written by ``thread_count`` writers, one of which is the
        calling thread. Returns an already-completed Future holding the
        WriteResults of all files.

        Raises:
            RuntimeError: if fewer files reported results than were queued,
                meaning a worker thread failed.
        """
        storage_plan: _StoragePrefix = plan.storage_data

        if self.single_file_per_rank:
            file_groups = _split_by_size_and_type(self.thread_count, plan.items)
        else:
            file_groups = [[item] for item in plan.items]

        file_queue: queue.Queue = queue.Queue()
        for file_index, group in enumerate(file_groups):
            file_name = f"{storage_plan.prefix}{file_index}{DEFAULT_SUFFIX}"
            file_queue.put((self.path / file_name, file_name, group))

        result_queue: queue.Queue = queue.Queue()
        threads = []
        try:
            # The calling thread is one of the ``thread_count`` writers.
            for _ in range(1, self.thread_count):
                t = threading.Thread(
                    target=_write_files_from_queue,
                    args=(
                        file_queue,
                        result_queue,
                        planner,
                        self.per_thread_copy_ahead,
                        self.sync_files,
                    ),
                )
                t.start()
                threads.append(t)

            _write_files_from_queue(
                file_queue=file_queue,
                result_queue=result_queue,
                planner=planner,
                inflight_threshhold=self.per_thread_copy_ahead,
                use_fsync=self.sync_files,
            )
        finally:
            # Never return, or propagate an error, while writers are still running.
            for t in threads:
                t.join()

        res: List[WriteResult] = []
        completed_files = 0
        while True:
            try:
                batch = result_queue.get_nowait()
            except queue.Empty:
                break
            res += batch
            completed_files += 1

        # Exceptions in worker threads do not reach this thread. Each finished
        # file puts exactly one batch, so a missing batch means a failed file.
        if completed_files != len(file_groups):
            raise RuntimeError(
                f"Only {completed_files} of {len(file_groups)} checkpoint files "
                "were written; a writer thread failed (see its traceback)."
            )

        fut: Future[List[WriteResult]] = Future()
        fut.set_result(res)
        return fut

    def finish(
        self, metadata: Metadata, results: List[List[WriteResult]]
    ) -> None:
        """
        Attach the per-item storage info to ``metadata`` and persist it.

        The metadata is pickled to `.metadata.tmp`, flushed and fsynced, and
        then renamed to `.metadata`, so readers never observe a partially
        written metadata file.
        """
        # Later entries win on duplicate indices, as with successive dict.update.
        metadata.storage_data = {
            wr.index: wr.storage_data for wr_list in results for wr in wr_list
        }

        tmp_path = self.path / ".metadata.tmp"
        with tmp_path.open("wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            # fsync only covers bytes the OS has received; flush Python's buffer first.
            metadata_file.flush()
            os.fsync(metadata_file.fileno())
        tmp_path.rename(self.path / ".metadata")
class SlicedBufferedReader(io.BufferedReader):
    """
    Buffered reader exposing bytes ``[offset, offset + len)`` of ``base_stream``
    as if they were a standalone stream starting at position 0.

    ``seek``/``tell`` positions are relative to the slice (``SEEK_END`` refers to
    the slice end), and ``read`` never returns bytes past the slice end. Other
    inherited read methods (``readinto``, ``readline``, ``peek``, ...) are not
    clamped to the slice.
    """

    def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
        super().__init__(base_stream)
        self.offset = offset
        # Slice length in bytes. The parameter name shadows the builtin but is
        # part of the public constructor signature.
        self.len = len
        self.seek(0)

    def seek(self, pos: int, whence: int = os.SEEK_SET) -> int:
        """Seek within the slice and return the new slice-relative position."""
        if whence == os.SEEK_SET:
            pos = self.offset + pos
        elif whence == os.SEEK_END:
            # Standard SEEK_END semantics: end of slice plus (usually negative) pos.
            whence = os.SEEK_SET
            pos = self.offset + self.len + pos
        return super().seek(pos, whence) - self.offset

    def tell(self) -> int:
        return super().tell() - self.offset

    def read(self, size: Union[int, None] = -1) -> bytes:
        """Read up to ``size`` bytes (all remaining if negative/None), stopping at the slice end."""
        remaining = max(0, self.len - self.tell())
        if size is None or size < 0 or size > remaining:
            size = remaining
        return super().read(size)
class FileSystemReader(StorageReader):
    """
    StorageReader for a checkpoint directory containing a pickled `.metadata`
    file plus data files, where each entry is located by a ``_StorageInfo``
    (relative path, byte offset, length).
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.path = Path(path)
        # Populated from the checkpoint metadata in ``set_up_storage_reader``.
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}

    def _slice_file(self, file, sinfo: _StorageInfo):
        # closefd=False: the slice shares ``file``'s descriptor, which is
        # closed by the caller's ``with`` block, not by the slice.
        return SlicedBufferedReader(
            io.FileIO(file.fileno(), closefd=False), sinfo.offset, sinfo.length
        )

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Load every item of ``plan``, opening each data file once.

        BYTE_IO items are handed to ``planner.load_bytes``. Tensor items are
        loaded on CPU, narrowed to the requested region, copied into the
        planner's target tensor and committed with ``planner.commit_tensor``.
        Returns an already-completed Future.
        """
        # Group requests by file so each file is opened only once.
        per_file: Dict[str, List[ReadItem]] = {}
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            per_file.setdefault(item_md.relative_path, []).append(read_item)

        for relative_path, reqs in per_file.items():
            with (self.path / relative_path).open("rb") as file:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(file, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
                        read_bytes.seek(0)
                        planner.load_bytes(req, read_bytes)
                    else:
                        # NOTE(review): torch.load unpickles the stored data;
                        # only load checkpoints from trusted sources.
                        tensor = cast(
                            Tensor, torch.load(file_slice, map_location="cpu")
                        )
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only read checkpoints from trusted sources.
        with (self.path / ".metadata").open("rb") as metadata_file:
            return pickle.load(metadata_file)

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(
        self, global_plan: List[LoadPlan]
    ) -> List[LoadPlan]:
        return global_plan
|