static_tcp_rendezvous.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import datetime
  9. import logging
  10. from typing import Tuple, cast, Optional
  11. # pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
  12. from torch.distributed import Store, TCPStore, PrefixStore
  13. from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
  14. from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
  15. log = logging.getLogger(__name__)
  16. _default_timeout_seconds = 600
  17. class StaticTCPRendezvous(RendezvousHandler):
  18. """
  19. Static rendezvous that is a wrapper around the TCPStore.
  20. Creates TCPStore based on the input parameters with the
  21. listener on the agent with group_rank=0
  22. """
  23. def __init__(
  24. self,
  25. master_addr: str,
  26. master_port: int,
  27. rank: int,
  28. world_size: int,
  29. run_id: str,
  30. timeout: int,
  31. ):
  32. self.master_addr = master_addr
  33. self.master_port = master_port
  34. self.rank = rank
  35. self.world_size = world_size
  36. self.run_id = run_id
  37. self.timeout = datetime.timedelta(seconds=timeout)
  38. self._store: Optional[Store] = None
  39. def get_backend(self) -> str:
  40. return "static"
  41. def next_rendezvous(self) -> Tuple[Store, int, int]:
  42. log.info("Creating TCPStore as the c10d::Store implementation")
  43. if not self._store:
  44. is_master = self.rank == 0
  45. self._store = TCPStore( # type: ignore[call-arg]
  46. self.master_addr,
  47. self.master_port,
  48. self.world_size,
  49. is_master,
  50. self.timeout,
  51. multi_tenant=True,
  52. )
  53. store = PrefixStore(self.run_id, self._store)
  54. return store, self.rank, self.world_size
  55. def is_closed(self):
  56. return False
  57. def set_closed(self):
  58. pass
  59. def num_nodes_waiting(self):
  60. return 0
  61. def get_run_id(self) -> str:
  62. return self.run_id
  63. def shutdown(self) -> bool:
  64. return True
  65. def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
  66. if "rank" not in params.config:
  67. raise ValueError(
  68. "rank is absent in RendezvousParameters."
  69. "Try add --node-rank to the cmd request"
  70. )
  71. endpoint = params.endpoint.strip()
  72. if not endpoint:
  73. raise ValueError(
  74. "endpoint is absent in RendezvousParameters"
  75. "Try add --master-port and --master-addr to the cmd request"
  76. )
  77. master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
  78. if master_port == -1:
  79. raise ValueError(
  80. f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
  81. )
  82. world_size = params.max_nodes
  83. rank = cast(int, params.config.get("rank"))
  84. run_id = params.run_id
  85. if "timeout" in params.config:
  86. timeout = int(params.config["timeout"])
  87. else:
  88. timeout = _default_timeout_seconds
  89. return StaticTCPRendezvous(
  90. master_addr, master_port, rank, world_size, run_id, timeout
  91. )