api.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import os
  8. import socket
  9. from string import Template
  10. from typing import List, Any
  11. def get_env_variable_or_raise(env_name: str) -> str:
  12. r"""
  13. Tries to retrieve environment variable. Raises ``ValueError``
  14. if no environment variable found.
  15. Args:
  16. env_name (str): Name of the env variable
  17. """
  18. value = os.environ.get(env_name, None)
  19. if value is None:
  20. msg = f"Environment variable {env_name} expected, but not set"
  21. raise ValueError(msg)
  22. return value
  23. def get_socket_with_port() -> socket.socket:
  24. addrs = socket.getaddrinfo(
  25. host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
  26. )
  27. for addr in addrs:
  28. family, type, proto, _, _ = addr
  29. s = socket.socket(family, type, proto)
  30. try:
  31. s.bind(("localhost", 0))
  32. s.listen(0)
  33. return s
  34. except OSError as e:
  35. s.close()
  36. raise RuntimeError("Failed to create a socket")
  37. class macros:
  38. """
  39. Defines simple macros for caffe2.distributed.launch cmd args substitution
  40. """
  41. local_rank = "${local_rank}"
  42. @staticmethod
  43. def substitute(args: List[Any], local_rank: str) -> List[str]:
  44. args_sub = []
  45. for arg in args:
  46. if isinstance(arg, str):
  47. sub = Template(arg).safe_substitute(local_rank=local_rank)
  48. args_sub.append(sub)
  49. else:
  50. args_sub.append(arg)
  51. return args_sub