filesystem.py

from abc import ABC, abstractmethod
import queue
import threading
import collections
from dataclasses import dataclass
import os
import dataclasses
import io
import pickle
from typing import List, Union, Dict, cast

import torch
from torch import Tensor
from torch.futures import Future
from pathlib import Path

from .metadata import (
    Metadata,
    MetadataIndex,
)
from .storage import (
    StorageReader,
    StorageWriter,
    WriteResult,
)
from .planner import (
    LoadItemType,
    LoadPlanner,
    LoadPlan,
    SavePlan,
    SavePlanner,
    ReadItem,
    WriteItem,
    WriteItemType,
)

from torch.distributed._shard._utils import narrow_tensor_by_index

__all__ = [
    "FileSystemWriter",
    "SlicedBufferedReader",
    "FileSystemReader",
]


@dataclass
class _StorageInfo:
    """
    This is the per-entry storage info.
    """

    relative_path: str
    offset: int
    length: int


@dataclass
class _StoragePrefix:
    prefix: str


DEFAULT_SUFFIX = ".distcp"
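
# Illustrative note (not part of the original code): every entry written to a
# checkpoint file is addressed by a _StorageInfo triple. For example, a tensor
# serialized at byte 128 of "__0_0.distcp" and spanning 256 bytes would be
# recorded as:
#
#     _StorageInfo(relative_path="__0_0.distcp", offset=128, length=256)
#
# FileSystemWriter.finish() below pickles a MetadataIndex -> _StorageInfo
# mapping into the `.metadata` file, which is how FileSystemReader later
# locates each entry.
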
def _trim(tensor: torch.Tensor) -> torch.Tensor:
    tensor = tensor.detach().cpu()
    if tensor._typed_storage()._size() != tensor.numel():
        tensor = tensor.clone()
    return tensor


def _result_from_write_item(
    item: WriteItem, size_in_bytes, storage_data
) -> WriteResult:
    return WriteResult(
        index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data
    )


class _TensorLoader(ABC):
    @abstractmethod
    def add(self, size, obj):
        pass

    @abstractmethod
    def start_loading(self):
        pass

    @abstractmethod
    def values(self):
        pass


class _SerialCpuLoader(_TensorLoader):
    def __init__(self, resolve_fun):
        self.resolve_fun = resolve_fun
        self.items = []

    def add(self, size, obj):
        self.items.append((size, obj))

    def start_loading(self):
        pass

    def values(self):
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach()
            tensor = tensor.cpu()
            if tensor.storage().size() != tensor.numel():
                tensor = tensor.clone()
            yield (
                tensor,
                obj,
            )


class _OverlappingCpuLoader(_TensorLoader):
    def __init__(self, resolve_fun, stream=None, inflight_threshhold=1_000_000):
        self.resolve_fun = resolve_fun
        self.items = []
        self.inflight_threshhold = inflight_threshhold
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        self.idx = 0
        self.started = False
        self.stream = stream or torch.cuda.current_stream()
        if self.stream != torch.cuda.current_stream():
            self.stream.wait_stream(torch.cuda.current_stream())

    @property
    def _done(self):
        return self.idx >= len(self.items)

    def _drain(self):
        # If over the in-flight budget, wait for the queued copies to complete
        # and hand back finished tensors until back under the threshold.
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self):
        # Enqueue asynchronous device-to-host copies until the in-flight
        # budget is used up.
        with torch.cuda.stream(self.stream):
            while not self._done and self.in_flight_data < self.inflight_threshhold:
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.is_cuda:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    if tensor.storage().size() != tensor.numel():
                        # This forces the tensor to be both contiguous and with minimal storage.
                        tensor = tensor.clone()

                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self):
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size, obj):
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self):
        if self.started:
            return
        self.started = True
        self.items.sort(key=lambda x: x[0])
        self._refill()

    def values(self):
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained
        yield from self._finish()
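
# Sketch of the loader protocol (illustrative, not part of the original code):
# both loaders are fed (size, obj) pairs and then iterated to obtain CPU
# tensors. `resolve_fun` maps an item to its tensor; a trivial identity
# resolver is assumed here for demonstration.
#
#     loader = _SerialCpuLoader(resolve_fun=lambda t: t)
#     for t in [torch.ones(4), torch.zeros(8)]:
#         loader.add(t.numel() * t.element_size(), t)
#     loader.start_loading()
#     for cpu_tensor, obj in loader.values():
#         ...  # every yielded tensor is detached and on CPU
#
# _OverlappingCpuLoader follows the same protocol but overlaps device-to-host
# copies with file writes by keeping up to `inflight_threshhold` bytes in
# flight on a side CUDA stream.
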
def _item_size(item: WriteItem) -> int:
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.size:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)
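
# Worked example (illustrative): for a WriteItem describing a 4 x 5 float32
# tensor, _item_size returns 4 * 5 * 4 = 80 bytes, since
# torch._utils._element_size(torch.float32) == 4.
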
def _split_by_size_and_type(
    bins, items: List[WriteItem]
) -> List[List[WriteItem]]:
    if bins == 1:
        return [items]

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_w.sort(key=_item_size, reverse=True)

    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    for wi in tensor_w:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)

    return buckets
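
# Walk-through of the greedy bucketing above (illustrative): with bins=2 and
# tensor items of 100, 60, and 50 bytes, items are sorted descending and each
# goes to the currently smallest bucket:
#
#     100 -> bucket 0 (sizes: [100, 0])
#      60 -> bucket 1 (sizes: [100, 60])
#      50 -> bucket 1 (sizes: [100, 110])
#
# BYTE_IO items are round-robined separately, since _item_size requires
# tensor metadata that byte-IO items do not carry.
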
def _write_item(stream, data, write_item, storage_key):
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, stream)
    length = stream.tell() - offset

    return _result_from_write_item(
        write_item, length, _StorageInfo(storage_key, offset, length)
    )


def _write_files_from_queue(
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
):
    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()
            loader: _TensorLoader

            # Overlap device-to-host copies with file writes when CUDA is
            # available; otherwise fall back to serial copies.
            if torch.cuda.is_available() and inflight_threshhold > 0:
                loader = _OverlappingCpuLoader(
                    lambda x: planner.resolve_data(x),
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(
                    lambda x: planner.resolve_data(x),
                )

            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
            for write_item in tensor_w:
                loader.add(_item_size(write_item), write_item)
            loader.start_loading()

            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
            write_results = []

            with open(file_name, "wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert not tensor.is_cuda
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

                if use_fsync:
                    os.fsync(stream.fileno())
            result_queue.put(write_results)
    except queue.Empty:
        pass
class FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: how many bytes to copy from the GPU ahead of saving them. Defaults to 10 MB.

        N.B. If sync_files is disabled, there is no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.path = Path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files
        self.thread_count = thread_count
        self.per_thread_copy_ahead = per_thread_copy_ahead

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        pass

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self.path.mkdir(parents=True, exist_ok=True)
        return plan

    def prepare_global_plan(
        self, global_plan: List[SavePlan]
    ) -> List[SavePlan]:
        new_plans = [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            for i, plan in enumerate(global_plan)
        ]
        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        file_queue: queue.Queue = queue.Queue()
        if self.single_file_per_rank:
            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                file_name = gen_file()
                file_queue.put((self.path / file_name, file_name, bucket))
        else:
            for item in plan.items:
                file_name = gen_file()
                file_queue.put((self.path / file_name, file_name, [item]))

        result_queue: queue.Queue = queue.Queue()

        threads = []
        for _ in range(1, self.thread_count):
            t = threading.Thread(
                target=_write_files_from_queue,
                args=(
                    file_queue,
                    result_queue,
                    planner,
                    self.per_thread_copy_ahead,
                    self.sync_files,
                ),
            )
            t.start()
            threads.append(t)

        # The calling thread participates as the last worker.
        _write_files_from_queue(
            file_queue=file_queue,
            result_queue=result_queue,
            planner=planner,
            inflight_threshhold=self.per_thread_copy_ahead,
            use_fsync=self.sync_files,
        )

        for t in threads:
            t.join()

        res = []
        try:
            while True:
                res += result_queue.get_nowait()
        except queue.Empty:
            pass

        fut: Future[List[WriteResult]] = Future()
        fut.set_result(res)
        return fut

    def finish(
        self, metadata: Metadata, results: List[List[WriteResult]]
    ) -> None:
        storage_md = dict()
        for wr_list in results:
            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
        metadata.storage_data = storage_md

        # Write to a temporary file and rename, so a reader never observes a
        # partially written .metadata file.
        with (self.path / ".metadata.tmp").open("wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            os.fsync(metadata_file.fileno())

        (self.path / ".metadata.tmp").rename(self.path / ".metadata")
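
# Usage sketch (illustrative; paths are hypothetical): FileSystemWriter is
# normally handed to the distributed checkpoint save entry point rather than
# driven directly:
#
#     import torch.distributed.checkpoint as dcp
#
#     writer = FileSystemWriter("/tmp/ckpt", thread_count=4)
#     dcp.save_state_dict(state_dict=model.state_dict(), storage_writer=writer)
#
# With thread_count=4, write_data splits this rank's items into four buckets
# via _split_by_size_and_type and writes them on three worker threads plus the
# calling thread; finish() then atomically renames `.metadata.tmp` to
# `.metadata`, so a crash mid-save cannot leave a readable but incomplete
# checkpoint.
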
class SlicedBufferedReader(io.BufferedReader):
    # TODO override read to handle (-1) correctly
    def __init__(self, base_stream: io.RawIOBase, offset: int, len: int):
        super().__init__(base_stream)
        self.offset = offset
        self.len = len
        self.seek(0)
    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            # Offsets relative to SEEK_END are typically non-positive, so the
            # absolute position is the end of the slice plus the offset.
            __whence = os.SEEK_SET
            __offset = (self.offset + self.len) + __offset
        return super().seek(__offset, __whence)
    def tell(self) -> int:
        return super().tell() - self.offset
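
# Behavior sketch (illustrative; the path is hypothetical): the reader exposes
# a window of a larger file as if it were a standalone stream, which is what
# lets torch.load pull a single entry out of a combined .distcp file.
#
#     with open("/tmp/ckpt/__0_0.distcp", "rb") as f:
#         view = SlicedBufferedReader(
#             io.FileIO(f.fileno(), closefd=False), offset=128, len=256
#         )
#         view.tell()            # 0, although the file position is byte 128
#         data = view.read(256)  # reads exactly the 256-byte entry
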
class FileSystemReader(StorageReader):
    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.path = Path(path)
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()

    def _slice_file(self, file, sinfo: _StorageInfo):
        return SlicedBufferedReader(
            io.FileIO(file.fileno(), closefd=False), sinfo.offset, sinfo.length
        )

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        # group requests by file
        per_file: Dict[str, List[ReadItem]] = dict()
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = item_md.relative_path
            per_file.setdefault(path, []).append(read_item)

        for relative_path, reqs in per_file.items():
            with (self.path / relative_path).open("rb") as file:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(file, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        bytes = io.BytesIO(file_slice.read(item_md.length))
                        bytes.seek(0)
                        planner.load_bytes(req, bytes)
                    else:
                        tensor = cast(
                            Tensor, torch.load(file_slice, map_location="cpu")
                        )
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        with (self.path / ".metadata").open("rb") as metadata_file:
            return pickle.load(metadata_file)

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(
        self, global_plan: List[LoadPlan]
    ) -> List[LoadPlan]:
        return global_plan
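
# Usage sketch (illustrative; the path is hypothetical): like the writer, the
# reader is normally handed to the distributed checkpoint load entry point,
# which populates an existing state_dict in place:
#
#     import torch.distributed.checkpoint as dcp
#
#     reader = FileSystemReader("/tmp/ckpt")
#     state_dict = model.state_dict()
#     dcp.load_state_dict(state_dict=state_dict, storage_reader=reader)
#     model.load_state_dict(state_dict)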