state_dict_loader.py

from typing import Any, Dict, Optional

import torch.distributed as dist

from .storage import (
    StorageReader,
)

from .planner import LoadPlanner
from .default_planner import DefaultLoadPlanner

from .utils import _DistWrapper

__all__ = ["load_state_dict"]


def load_state_dict(
    state_dict: Dict[str, Any],
    storage_reader: StorageReader,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[LoadPlanner] = None,
) -> None:
  18. """
  19. Loads a distributed ``state_dict`` in SPMD style.
  20. Each rank will try to read the least amount of data necessary
  21. to fullfill the requested `state_dict`. When loading :class:`ShardedTensor`
  22. instances, each rank only reads data for their local shards.
  23. .. warning::
  24. All tensors in ``state_dict`` must be allocated on their
  25. destination device *prior to* calling this function.
  26. All non-tensor data is loaded using `torch.load()` and modified in place
  27. on state_dict.
  28. .. warning::
  29. Users must call `load_state_dict` on the root module to ensure load
  30. pos-processing and non-tensor data properly propagates.
  31. .. note:
  32. This function can be used for local inference and load a checkpoint
  33. produced by ``save_state_dict`` without having a process group initialized
  34. by passing ``no_dist=True`` and by using Tensors instead of ShardedTensors.
  35. Args:
  36. state_dict (Dict[str, Any]) : The state_dict to load. Note that this
  37. state dict will updated in place.
  38. storage_reader (StorageReader): StorageReader used to load data from.
  39. process_group (ProcessGroup):
  40. ProcessGroup to be used for cross-rank synchronization.
  41. coordinator_rank (int):
  42. Rank to use to coordinate the checkpoint.
  43. rank0 is used by default.
  44. no_dist (bool): If ``True``, distributed checkpoint will not save
  45. in SPMD style. (Default: ``False``)
  46. Returns:
  47. None.

    Examples
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()
        >>> optimizer = Adagrad(my_model.parameters())
        >>> model_state_dict = my_model.state_dict()
        >>> fs_storage_loader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")

        >>> torch.distributed.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_loader,
        >>> )

        >>> # module.load_state_dict() function might have customized steps
        >>> # to flush the state_dict, must call it to
        >>> # ensure correct behavior.
        >>> my_model.load_state_dict(model_state_dict)

    .. note::
        load_state_dict uses collectives to coordinate reads across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that each
        rank has an individual GPU, via ``torch.cuda.set_device()``.
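
        A minimal sketch of that per-rank setup (``rank`` and ``world_size``
        are assumed to come from the launcher):

        >>> # xdoctest: +SKIP
        >>> torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        >>> torch.cuda.set_device(rank)
        >>> torch.distributed.checkpoint.load_state_dict(
        >>>     state_dict=model_state_dict,
        >>>     storage_reader=fs_storage_loader,
        >>> )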
  69. """
    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
    if planner is None:
        planner = DefaultLoadPlanner()

    def local_step():
        # Runs on every rank: read the checkpoint metadata and build this
        # rank's local read plan.
        assert planner is not None
        metadata = storage_reader.read_metadata()
        planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)

        local_plan = planner.create_local_plan()
        local_plan = storage_reader.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        # Runs on the coordinator rank only: merge all local plans into the
        # global plan that is scattered back to every rank.
        assert planner is not None
        all_local_plans = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan = distW.reduce_scatter("plan", local_step, global_step)

    def read_data():
        # Runs on every rank: execute this rank's portion of the central plan.
        assert planner is not None
        final_local_plan = planner.finish_plan(central_plan)
        all_reads = storage_reader.read_data(final_local_plan, planner)

        all_reads.wait()
        return None

    _ = distW.all_gather("read", read_data)