averagers.py

import warnings
from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.utils as utils

__all__ = ['ModelAverager', 'PeriodicModelAverager']


class ModelAverager(ABC):
    r"""Base class for all model averagers.

    Args:
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)
    """

    def __init__(self, process_group=None):
        self.process_group = (
            process_group if process_group is not None else dist.group.WORLD
        )
        self.step = 0

    @abstractmethod
    def average_parameters(self, params):
        raise NotImplementedError
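

# A minimal sketch of a custom averager (hypothetical; ``_EveryStepModelAverager``
# is not part of this module's public API and only illustrates the
# ``ModelAverager`` contract): a concrete subclass needs nothing beyond an
# ``average_parameters`` implementation. This variant averages on every call,
# reusing the same all-reduce helper that ``PeriodicModelAverager`` uses below.
class _EveryStepModelAverager(ModelAverager):
    def average_parameters(self, params):
        # All-reduce and average the parameters across ``self.process_group``.
        utils.average_parameters_or_parameter_groups(params, self.process_group)
        self.step += 1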


class PeriodicModelAverager(ModelAverager):
    r"""
    Averages parameters periodically after the warm-up stage.

    This can be used for running `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
    by running :class:`~torch.nn.parallel.DistributedDataParallel` (DDP)
    using the subgroups created by :meth:`~torch.distributed.new_subgroups`.

    Args:
        period (int): The number of steps per model averaging.
                      Usually the period should be greater than ``1`` to reduce the communication cost.
                      Otherwise, only DDP needs to be used.
        warmup_steps (int): The number of warm-up steps. During this stage,
                            model averaging is skipped.
        process_group: The process group to be used for all-reduce.
                       If ``None``, the default process group, which
                       is created by :func:`torch.distributed.init_process_group`,
                       will be used. (default: ``None``)

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> import torch
        >>> import torch.distributed as dist
        >>> import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
        >>> import torch.distributed.algorithms.model_averaging.averagers as averagers
        >>> import torch.nn as nn
        >>>
        >>> dist.init_process_group("nccl", rank=rank, world_size=16)
        >>> torch.cuda.set_device(rank)
        >>> module = nn.Linear(1, 1, bias=False).cuda()
        >>> model = nn.parallel.DistributedDataParallel(
        >>>     module, device_ids=[rank], output_device=rank
        >>> )
        >>> # Register a post-localSGD communication hook.
        >>> state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
        >>> model.register_comm_hook(state, post_localSGD.post_localSGD_hook)
        >>>
        >>> # In the first 100 steps, run global gradient averaging like normal DDP at every step.
        >>> # After 100 steps, run model averaging every 4 steps.
        >>> # Note that ``warmup_steps`` must be the same as ``start_localSGD_iter`` used in ``PostLocalSGDState``.
        >>> averager = averagers.PeriodicModelAverager(period=4, warmup_steps=100)
        >>> for step in range(0, 200):
        >>>     optimizer.zero_grad()
        >>>     loss = loss_fn(output, labels)
        >>>     loss.backward()
        >>>     optimizer.step()
        >>>     # Will average model parameters globally every 4 steps. Thus,
        >>>     # inter-node communication only occurs every 4 iterations after
        >>>     # the initial ``warmup_steps`` period.
        >>>     averager.average_parameters(model.parameters())
    """

    def __init__(
        self,
        period,
        warmup_steps=0,
        process_group=None
    ):
        super().__init__(process_group)
        if warmup_steps < 0:
            raise ValueError("Arg ``warmup_steps`` must be a non-negative number.")
        self.warmup_steps = warmup_steps
        if period < 1:
            raise ValueError("Arg ``period`` must be a positive value.")
        elif period == 1:
            warnings.warn(
                "When period is 1, no need to use model averaging because the communication cost "
                "of all-reducing parameters will be no less than the cost of all-reducing gradients "
                "by DistributedDataParallel in the backward pass. Therefore, only "
                "DistributedDataParallel should be used for this case."
            )
        self.period = period

    def average_parameters(self, params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
        """
        Averages parameters or parameter groups of an optimizer if ``step`` is no less than ``warmup_steps``
        and ``step - warmup_steps`` is divisible by ``period``, where ``step`` is increased by 1
        at each iteration in the training loop.

        Args:
            params: The parameters of a model or the parameter groups of an optimizer.
        """
        if (
            self.step >= self.warmup_steps
            and (self.step - self.warmup_steps) % self.period == 0
        ):
            utils.average_parameters_or_parameter_groups(params, self.process_group)
        self.step += 1
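

# A short worked example of the gating condition above (hypothetical numbers,
# kept as a comment so the module has no import-time side effects): with
# ``warmup_steps=100`` and ``period=4``, ``average_parameters`` is a no-op for
# steps 0-99 and then all-reduces at steps 100, 104, 108, ..., i.e. whenever
# ``step >= warmup_steps`` and ``(step - warmup_steps) % period == 0``.
#
#     averager = PeriodicModelAverager(period=4, warmup_steps=100)
#     for step in range(200):
#         ...  # forward/backward/optimizer.step()
#         averager.average_parameters(model.parameters())
#     # Averaging runs at steps 100, 104, ..., 196 (25 times in 200 steps).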