store.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. from datetime import timedelta
  8. from typing import List
  9. def get_all(store, rank: int, prefix: str, size: int):
  10. r"""
  11. Given a store and a prefix, the method goes through the array of keys
  12. of the following format: ``{prefix}{idx}``, where idx is in a range
  13. from 0 to size, and tries to retrieve the data.
  14. The Rank0 process waits at the end to make sure all other processes
  15. finished the procedure before exiting.
  16. Usage
  17. ::
  18. values = get_all(store, 'torchelastic/data', 3)
  19. value1 = values[0] # retrieves the data for key torchelastic/data0
  20. value2 = values[1] # retrieves the data for key torchelastic/data1
  21. value3 = values[2] # retrieves the data for key torchelastic/data2
  22. """
  23. data_arr = []
  24. for idx in range(size):
  25. data = store.get(f"{prefix}{idx}")
  26. data_arr.append(data)
  27. store.set(f"{prefix}{rank}.FIN", b"FIN")
  28. if rank == 0:
  29. # Rank0 runs the TCPStore daemon, as a result it needs to exit last.
  30. # Otherwise, the barrier may timeout if rank0 process finished the work
  31. # before other processes finished `get_all` method
  32. for node_rank in range(size):
  33. store.get(f"{prefix}{node_rank}.FIN")
  34. return data_arr
  35. def synchronize(
  36. store,
  37. data: bytes,
  38. rank: int,
  39. world_size: int,
  40. key_prefix: str,
  41. barrier_timeout: float = 300,
  42. ) -> List[bytes]:
  43. """
  44. Synchronizes ``world_size`` agents between each other using the underlying c10d store.
  45. The ``data`` will be available on each of the agents.
  46. Note: The data on the path is not deleted, as a result there can be stale data if
  47. you use the same key_prefix twice.
  48. """
  49. store.set_timeout(timedelta(seconds=barrier_timeout))
  50. store.set(f"{key_prefix}{rank}", data)
  51. agent_data = get_all(store, rank, key_prefix, world_size)
  52. return agent_data
  53. def barrier(
  54. store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
  55. ) -> None:
  56. """
  57. A global lock between agents.
  58. Note: Since the data is not removed from the store, the barrier can be used
  59. once per unique ``key_prefix``.
  60. """
  61. data = f"{rank}".encode(encoding="UTF-8")
  62. synchronize(store, data, rank, world_size, key_prefix, barrier_timeout)