distributed.py

import os

import torch
import torch.distributed as dist


def _redefine_print(is_main):
    """Disables printing when not in the main process."""
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Pass force=True to print from non-main processes anyway.
        force = kwargs.pop("force", False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def setup_ddp(args):
    # Set the local_rank, rank, and world_size values as args fields.
    # This is done differently depending on how we're running the script. We
    # currently support either torchrun or the custom run_with_submitit.py.
    # If you're confused (like I was), this might help a bit:
    # https://discuss.pytorch.org/t/what-is-the-difference-between-rank-and-local-rank/61940/2
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        args.world_size = 1
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    dist.init_process_group(
        backend="nccl",
        rank=args.rank,
        world_size=args.world_size,
        init_method=args.dist_url,
    )
    dist.barrier()
    _redefine_print(is_main=(args.rank == 0))


def reduce_across_processes(val):
    # Sum a scalar value across all processes and return it as a tensor.
    t = torch.tensor(val, device="cuda")

    dist.barrier()
    dist.all_reduce(t)
    return t
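
For context, here is a minimal sketch of how these helpers might be wired into a training script launched with torchrun. The argument name --dist-url (defaulting to "env://") and the use of argparse are assumptions for illustration; they are not part of the file above. torchrun exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT, so setup_ddp() takes the first branch and init_process_group reads the rendezvous info from the environment.

# Hypothetical usage sketch (not part of distributed.py).
# Launch with e.g.:  torchrun --nproc_per_node=2 train.py
import argparse

from distributed import setup_ddp, reduce_across_processes

parser = argparse.ArgumentParser()
# Assumed flag: "env://" tells init_process_group to read MASTER_ADDR /
# MASTER_PORT from the environment variables that torchrun sets.
parser.add_argument("--dist-url", default="env://")
args = parser.parse_args()

setup_ddp(args)  # sets args.rank, args.gpu, args.world_size, args.distributed

# ... build the model, wrap it in DistributedDataParallel, run a step ...
local_loss = 0.123  # placeholder per-process value

if args.distributed:
    # all_reduce defaults to SUM, so divide by world_size to get the average.
    total = reduce_across_processes(local_loss)
    avg_loss = (total / args.world_size).item()
else:
    avg_loss = local_loss

# Because of _redefine_print, only rank 0 actually prints this
# (pass force=True to print from every rank).
print(f"average loss across processes: {avg_loss:.4f}")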