utils.py

# flake8: noqa C101
import itertools
from typing import Union, Iterable, Dict, Iterator

import torch
import torch.distributed as dist
# The two imports below are not always available depending on the
# USE_DISTRIBUTED compile flag. Make sure they raise an import error
# if we're trying to use them.
from torch.distributed import ProcessGroup, group

__all__ = ["average_parameters", "get_params_to_average", "average_parameters_or_parameter_groups"]


def average_parameters(
    params: Iterator[torch.nn.Parameter], process_group: ProcessGroup
):
    """
    Averages all the given parameters.

    For allreduce efficiency, all the parameters are flattened into a contiguous buffer.
    Thus, it requires extra memory of the same size as the given parameters.
    """
    group_to_use = process_group if process_group is not None else group.WORLD
    # Do not update any parameter if this rank is not in the process group.
    if dist._rank_not_in_group(group_to_use):
        return

    params_it1, params_it2 = itertools.tee(params)
    # If the input parameters have different data types,
    # packing these parameters will trigger an implicit type up-casting.
    # The original parameter data types will be restored during the subsequent unpacking.
    flat_params = torch.cat([p.data.reshape(-1) for p in params_it1])
    flat_params /= dist.get_world_size(group_to_use)
    # Make sure the allreduce will not conflict with any other ongoing process group.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dist.all_reduce(flat_params, group=group_to_use)

    offset = 0
    for p in params_it2:
        p.data = flat_params[offset : offset + p.numel()].view_as(p).type_as(p)
        offset += p.numel()
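
# A minimal usage sketch for average_parameters (not part of the original file).
# It assumes a process group has already been created with
# dist.init_process_group and that `model` is a hypothetical torch.nn.Module
# replicated on every rank:
#
#     average_parameters(model.parameters(), dist.group.WORLD)
#
# After the call, every rank holds the element-wise mean of the parameters
# across the process group.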

def get_params_to_average(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]]):
    """
    Returns a list of parameters that need to be averaged, filtering out the parameters that do not have any gradients.

    Args:
        params: The parameters of a model or the parameter groups of an optimizer.
    """
    filtered_params = []
    for param in params:
        if isinstance(param, torch.nn.Parameter):
            # model.parameters() input
            param_data = param
            if param_data.grad is not None:
                filtered_params.append(param_data)
        elif isinstance(param, dict):
            # optimizer.param_groups input
            for param_data in param["params"]:
                if param_data.grad is not None:
                    filtered_params.append(param_data)
        else:
            raise NotImplementedError(f"Parameter input of type {type(param)} is not supported")
    return filtered_params
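
# Illustrative example (not part of the original file): given a hypothetical
# optimizer, get_params_to_average(optimizer.param_groups) returns a flat list
# containing only the parameters whose .grad is populated; passing
# model.parameters() instead yields the same kind of filtered list.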

def average_parameters_or_parameter_groups(params: Union[Iterable[torch.nn.Parameter], Iterable[Dict[str, torch.nn.Parameter]]], process_group: ProcessGroup):
    """
    Averages parameters of a model or parameter groups of an optimizer.
    """
    average_parameters(iter(get_params_to_average(params)), process_group)
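

# --- Illustrative usage sketch (not part of the original file) ----------------
# A minimal single-process demo, assuming the "gloo" backend is available.
# The model, the dummy forward/backward pass, and the environment variables
# below are illustrative assumptions, not part of the library API. With a
# single rank the averaging is a numerical no-op; launch with torchrun and
# multiple ranks to see parameters actually being averaged.
if __name__ == "__main__":
    import os

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(4, 2)
    # Populate .grad on every parameter so get_params_to_average keeps them.
    model(torch.randn(8, 4)).sum().backward()

    # Either form works; both funnel through get_params_to_average.
    average_parameters_or_parameter_groups(model.parameters(), None)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    average_parameters_or_parameter_groups(optimizer.param_groups, None)

    dist.destroy_process_group()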