state_dict_saver.py

from typing import Optional

import torch.distributed as dist

from .planner import SavePlanner
from .default_planner import DefaultSavePlanner
from .storage import (
    StorageWriter,
)
from .metadata import Metadata, STATE_DICT_TYPE
from .utils import _DistWrapper

__all__ = ["save_state_dict"]


def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
  19. """
  20. Saves a distributed model in SPMD style.
  21. This function is different from ``torch.save()`` as it handles
  22. ``ShardedTensor`` by having each rank only save their local shards.
  23. .. warning::
  24. There is no guarantees of Backwards Compatibility across PyTorch versions
  25. for saved state_dicts.
  26. .. warning::
  27. If using the `process_group` argument, make sure that only its ranks
  28. call `save_state_dict` and that all data in state_dict belong to it.
  29. .. note::
  30. This function can be used to save a state_dict with an intialized process
  31. group by passing ``no_dist=True``. This can be used to produce a checkpoint
  32. that can consumed by load_state_dict is a SPMD fashion.

    Args:
        state_dict (Dict[str, Any]): A state_dict to save.
        storage_writer (StorageWriter):
            Instance of StorageWriter used to perform writes.
        process_group (ProcessGroup):
            ProcessGroup to be used for cross-rank synchronization.
        coordinator_rank (int): Rank to use to coordinate the checkpoint.
            rank0 is used by default.
        no_dist (bool): If ``True``, distributed checkpoint will not save
            in SPMD style. (Default: ``False``)
        planner (SavePlanner): Instance of SavePlanner to coordinate the save.
            If ``None``, ``DefaultSavePlanner`` is used. (Default: ``None``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
        >>> torch.distributed.checkpoint.save_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_writer=fs_storage_writer,
        >>> )
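
        Saving without an initialized process group, per the note above (a
        sketch; the checkpoint path is illustrative):

        >>> # xdoctest: +SKIP
        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/2")
        >>> torch.distributed.checkpoint.save_state_dict(
        >>>     state_dict=my_model.state_dict(),
        >>>     storage_writer=fs_storage_writer,
        >>>     no_dist=True,
        >>> )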

    .. note::
        save_state_dict uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes
        place. In this case, the device used is given by
        ``torch.cuda.current_device()`` and it is the user's responsibility
        to ensure that this is set so that each rank has an individual GPU,
        via ``torch.cuda.set_device()``.
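
        A common way to satisfy this before calling save_state_dict (a
        sketch, assuming one process per GPU):

        >>> # xdoctest: +SKIP
        >>> torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())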
  61. """
    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None

    def local_step():
        # Each rank sets up its planner and storage writer, then plans the
        # writes for its local shards.
        assert planner is not None
        planner.set_up_planner(state_dict, distW.is_coordinator)
        storage_writer.set_up_storage_writer(distW.is_coordinator)
        local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        # The coordinator merges the local plans into a global plan and
        # produces the checkpoint metadata.
        nonlocal global_metadata
        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    def write_data():
        # Each rank executes its portion of the central plan and waits for
        # its writes to complete.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, planner)

        all_writes.wait()
        return all_writes.value()

    def finish_checkpoint(all_results):
        # The coordinator commits the checkpoint once all ranks have written.
        assert global_metadata is not None
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    return distW.all_reduce("write", write_data, finish_checkpoint)
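

# A minimal end-to-end sketch of how save_state_dict is typically driven
# (illustrative only, not part of this module): it assumes a launcher such
# as torchrun has started one process per GPU, and build_model() and the
# checkpoint path are hypothetical placeholders.
#
#     import torch
#     import torch.distributed as dist
#     import torch.distributed.checkpoint as dcp
#
#     dist.init_process_group("nccl")
#     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
#
#     model = build_model()  # hypothetical helper
#     dcp.save_state_dict(
#         state_dict=model.state_dict(),
#         storage_writer=dcp.FileSystemWriter("/checkpoint/step_100"),
#     )
#
#     dist.destroy_process_group()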