123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- #!/usr/bin/env python3
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- import sys
- import uuid
- from dataclasses import dataclass, field
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- import torch.distributed.elastic.rendezvous.registry as rdzv_registry
- from torch.distributed.elastic import events, metrics
- from torch.distributed.elastic.agent.server.api import WorkerSpec
- from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
- from torch.distributed.elastic.multiprocessing import SignalException, Std
- from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
- from torch.distributed.elastic.rendezvous import RendezvousParameters
- from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
- from torch.distributed.elastic.utils.logging import get_logger
# Names exported by ``from <this module> import *``.
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]

# Module-level logger. get_logger() is called without an explicit name;
# presumably it derives one from the calling module — confirm against
# torch.distributed.elastic.utils.logging.
logger = get_logger()
@dataclass
class LaunchConfig:
    """
    Launch configuration for an elastic job: rendezvous settings plus how the
    workers on this node are started, restarted, logged and monitored.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
            be launched on. Elastic agent ensures that the user function
            starts only when the min_nodes amount enters the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
            will be launched on.
        nproc_per_node: On each node the elastic agent will launch this
            amount of workers that will execute the user defined function.
        rdzv_backend: Name of the rendezvous backend to use
            (e.g. ``etcd``, the default, or ``static``).
        rdzv_endpoint: The endpoint of the rendezvous sync storage.
        rdzv_configs: Key/value pairs with backend-specific rendezvous
            configuration. The dict is copied on construction, so the object
            passed in by the caller is never modified; the copy always
            contains a ``timeout`` key (see ``rdzv_timeout``).
        rdzv_timeout: Legacy argument that specifies the timeout for the
            rendezvous. When set (i.e. not ``-1``) it overrides
            ``rdzv_configs['timeout']``. It is going to be removed in future
            versions, see the note below. If neither this nor
            ``rdzv_configs['timeout']`` is given, the timeout defaults to
            900 seconds.
        run_id: The unique run id of the job. If empty, a random one is
            generated when the job is launched.
        role: User defined role of the worker (defaults to "default_role").
        max_restarts: The maximum amount of restarts that the elastic agent
            will conduct on workers before failure (defaults to 3).
        monitor_interval: The interval in seconds that is used by the
            elastic agent as a period of monitoring workers (defaults to 30).
        start_method: The method used by the elastic agent to start the
            workers (spawn, fork, forkserver; defaults to spawn).
        log_dir: Base log directory where log files are written. If not set,
            one is created in a tmp dir but NOT removed on exit.
        redirects: Configuration to redirect stdout/stderr to log files.
            Pass a single ``Std`` enum to redirect all workers, or a mapping
            keyed by local_rank to selectively redirect.
        tee: Configuration to "tee" stdout/stderr to console + log file.
        metrics_cfg: Configuration to initialize metrics.
        local_addr: Address of the local node if any. If not set, a lookup
            on the local machine's FQDN will be performed.

    .. note::
        ``rdzv_timeout`` is a legacy argument that will be removed in future.
        Set the timeout via ``rdzv_configs['timeout']`` instead.
    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 30
    start_method: str = "spawn"
    log_dir: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self) -> None:
        """Normalize ``rdzv_configs`` so it always carries a ``timeout``."""
        default_timeout = 900
        # Work on a private copy: writing ``timeout`` into the caller's dict
        # (possibly shared between several configs) is a surprising side effect.
        self.rdzv_configs = dict(self.rdzv_configs)
        if self.rdzv_timeout != -1:
            # The legacy argument, when given, wins over rdzv_configs["timeout"].
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments positionally; keyword (named)
       arguments are not supported. ``entrypoint`` can be a function or a
       command.
    2. The return value is a map of each worker's output keyed by its
       global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            config = LaunchConfig(...)

            # entrypoint is a function.
            outputs = elastic_launch(config, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

            # entrypoint is a command and ``script.py`` is the python module.
            outputs = elastic_launch(config, "script.py")(args)
            outputs = elastic_launch(config, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Only remember what to launch; nothing runs until the instance is called.
        self._entrypoint = entrypoint
        self._config = config

    def __call__(self, *args):
        """Launch the stored entrypoint with ``args`` and return ``launch_agent``'s result."""
        forwarded_args = list(args)
        return launch_agent(self._config, self._entrypoint, forwarded_args)
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve the entrypoint name according to the following rules:

    1. If ``entrypoint`` is callable, use ``entrypoint.__name__``. Callables
       without a ``__name__`` (e.g. ``functools.partial`` objects or instances
       defining ``__call__``) fall back to the name of their type.
    2. If ``entrypoint`` is a string, check its value:
        2.1 if it equals ``sys.executable`` (the running Python interpreter),
            use the first element of ``args`` that is a non-empty string not
            starting with a hyphen (so flags such as "-u" are skipped), or ""
            if there is no such element.
        2.2 otherwise, use the ``entrypoint`` value itself.
    3. Otherwise, return an empty string.
    """
    if callable(entrypoint):
        # Not every callable has ``__name__``; fall back to the type name
        # instead of raising AttributeError.
        return getattr(entrypoint, "__name__", type(entrypoint).__name__)
    if isinstance(entrypoint, str):
        if entrypoint == sys.executable:
            # Skip interpreter flags, empty strings and non-string values
            # (indexing ``arg[0]`` on those would raise).
            return next(
                (
                    arg
                    for arg in args
                    if isinstance(arg, str) and arg and not arg.startswith("-")
                ),
                "",
            )
        return entrypoint
    return ""
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Return ``(master_addr, master_port)`` for the ``static`` rendezvous backend.

    For any other backend ``(None, None)`` is returned.

    Raises:
        ValueError: if the backend is ``static`` and the endpoint is missing
            (``None``, empty or whitespace-only) or does not specify a port.
    """
    if rdzv_parameters.backend != "static":
        return (None, None)
    # Treat a missing (None) endpoint like an empty one so callers get the
    # ValueError below rather than an AttributeError from ``.strip()``.
    endpoint = (rdzv_parameters.endpoint or "").strip()
    if not endpoint:
        raise ValueError(
            "Endpoint is missing. Try to add --master-addr and --master-port"
        )
    # -1 is a sentinel meaning "the endpoint string did not contain a port".
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
        )
    return (master_addr, master_port)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """
    Start a ``LocalElasticAgent`` for ``entrypoint`` on this node and run it.

    Args:
        config: Launch configuration. When ``config.run_id`` is empty, a
            random id is generated and stored back into ``config.run_id``.
        entrypoint: Function or command each worker runs.
        args: Positional arguments forwarded to ``entrypoint``.

    Returns:
        The ``return_values`` of the agent's run result.

    Raises:
        ValueError: For the ``static`` backend, when the rendezvous endpoint
            lacks an address or port (raised before any rendezvous handler
            is created).
        ChildFailedError: The agent run finished but reported failed workers.
        SignalException: The agent was interrupted by a signal. The
            rendezvous handler is deliberately NOT shut down in this case.
        Exception: Any other error, re-raised after a "failed" agent event is
            recorded; the rendezvous handler is shut down.
    """
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning(f"config has no run_id, generated a random run_id: {run_id}")
        # Written back so the log, the rendezvous and the caller share one id.
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        f"Starting elastic_operator with launch configs:\n"
        f" entrypoint : {entrypoint_name}\n"
        f" min_nodes : {config.min_nodes}\n"
        f" max_nodes : {config.max_nodes}\n"
        f" nproc_per_node : {config.nproc_per_node}\n"
        f" run_id : {config.run_id}\n"
        f" rdzv_backend : {config.rdzv_backend}\n"
        f" rdzv_endpoint : {config.rdzv_endpoint}\n"
        f" rdzv_configs : {config.rdzv_configs}\n"
        f" max_restarts : {config.max_restarts}\n"
        f" monitor_interval : {config.monitor_interval}\n"
        f" log_dir : {config.log_dir}\n"
        f" metrics_cfg : {config.metrics_cfg}\n"
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # Resolve the master address/port before creating the rendezvous handler,
    # so an invalid static endpoint fails before a handler exists.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        redirects=config.redirects,
        tee=config.tee,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        start_method=config.start_method,
        log_dir=config.log_dir,
    )

    # Cleared only when the agent is stopped by a signal (see below).
    close_rendezvous = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # This event means agent.run() returned, NOT that every worker succeeded.
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record: if the error
            # files for the failed children exist, @record copies the first
            # error (root cause) to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        # Worker failure rather than agent failure: propagate without
        # recording a "failed" agent event.
        raise
    except Exception as exc:
        if isinstance(exc, SignalException):
            # Do NOT shut down the rdzv_handler on a signal: that would close
            # the rendezvous for this run id permanently and prevent any
            # additional scaling events.
            close_rendezvous = False
        events.record(agent.get_event_failed())
        raise
    finally:
        if close_rendezvous:
            spec.rdzv_handler.shutdown()
|