remote_device.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. from typing import Optional, Union
  2. import torch
  class _remote_device:
      """
      Represents a device on a remote worker.

      Args:
          remote_device (str or torch.device): Represents a device on a remote worker.
              The string format should be one of the following:

              1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                 E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                 In addition, the device field can be optional and the default value is "cpu".
              2. "rank:<rank>/<device>", where <rank> is the rank of the
                 process and device can be parsed as torch.device type.
                 E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
              3. <workername> and <rank> are optional and formats like "cpu"
                 and "cuda:1", just represent local devices.

      Raises:
          ValueError: If the string has more than one "/", has an empty worker
              name, or has a worker field containing ":" that is not exactly
              "rank:<non-negative integer>".
          TypeError: If ``remote_device`` is neither a ``str`` nor a
              ``torch.device``.

      Any exception raised by ``torch.device(...)`` while validating the device
      field is propagated unchanged.
      """

      def __init__(self, remote_device: Union[str, torch.device]):
          PARSE_ERROR = (
              f"Could not parse remote_device: {remote_device}. The valid format is "
              "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
          )
          self._worker_name: Optional[str] = None
          self._rank: Optional[int] = None
          self._device: Optional[Union[str, int, torch.device]] = None

          if isinstance(remote_device, torch.device):
              self._device = remote_device
          elif isinstance(remote_device, str):
              fields = remote_device.split("/")
              if len(fields) == 2:
                  self._worker_name, self._device = fields
              elif len(fields) == 1:
                  # A lone field is ambiguous: treat it as a local device if
                  # torch.device accepts it, otherwise as a worker name whose
                  # device defaults to "cpu".
                  if _remote_device._is_valid_local_device(fields[0]):
                      self._device = fields[0]
                  else:
                      self._worker_name = fields[0]
                      self._device = "cpu"
              else:
                  raise ValueError(PARSE_ERROR)
          else:
              raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

          # Do some basic sanity check (no empty string), e.g. "/cpu" or "".
          if self._worker_name is not None and not self._worker_name:
              raise ValueError(PARSE_ERROR)

          # Validate the device and normalize it to a torch.device instance.
          self._device = torch.device(self._device)

          # Check for rank based format.
          if self._worker_name is not None:
              fields = self._worker_name.split(":")
              if len(fields) == 2:
                  # rank:<rank>/device format, extract rank.
                  # isdecimal() (not isdigit()) is used because isdigit() also
                  # accepts characters such as superscript digits, which int()
                  # rejects with an unrelated error message.
                  if fields[0] != "rank" or not fields[1].isdecimal():
                      raise ValueError(PARSE_ERROR)
                  self._rank = int(fields[1])
                  self._worker_name = None
              elif len(fields) > 2:
                  raise ValueError(PARSE_ERROR)

      @staticmethod
      def _is_valid_local_device(device):
          """Return ``True`` if ``torch.device(device)`` succeeds, else ``False``."""
          # Deliberately broad: any failure to construct the device simply
          # means "not a local device string".
          try:
              torch.device(device)
          except Exception:
              return False
          return True

      def worker_name(self) -> Optional[str]:
          """
          Returns the name of remote worker representing the remote device.
          Returns ``None`` if no worker name is available.
          """
          return self._worker_name

      def rank(self) -> Optional[int]:
          """
          Returns the rank of remote worker representing the remote device.
          Returns ``None`` if no rank is available.
          """
          return self._rank

      def device(self) -> torch.device:
          """
          Returns the local device on the remote worker.
          """
          return self._device  # type: ignore[return-value]

      def __repr__(self):
          """
          Return the string form: "<workername>/<device>", "rank:<rank>/<device>"
          or "<device>"; the "/<device>" part is omitted if no device is set.

          Raises:
              RuntimeError: If worker name, rank and device are all ``None``.
          """
          if self._worker_name is not None:
              owner = f'{self._worker_name}'
          elif self._rank is not None:
              # Keep the "rank:" prefix so the output matches the accepted
              # input format even when no device is set.
              owner = f'rank:{self._rank}'
          else:
              owner = None

          if self._device is None:
              if owner is None:
                  raise RuntimeError('Invalid state!')
              return owner
          if owner is None:
              return str(self._device)
          return f'{owner}/{self._device}'

      def __eq__(self, other):
          """Two remote devices are equal when worker name, device and rank all match."""
          if not isinstance(other, _remote_device):
              return False
          return (
              self._worker_name == other._worker_name
              and self._device == other._device
              and self._rank == other._rank
          )

      def __hash__(self):
          """Hash over the same three fields that ``__eq__`` compares."""
          return hash((self._worker_name, self._device, self._rank))