123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- from typing import Optional, Union
- import torch
class _remote_device:
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
                3. <workername> and <rank> are optional and formats like "cpu"
                   and "cuda:1", just represent local devices.

            A string without "/" that ``torch.device`` accepts is always
            treated as a local device. So a worker whose name is itself a
            valid device string (e.g. "cpu") must be written with an explicit
            device, e.g. "cpu/cpu".

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string does not match any format above. Examples:
            more than one "/", an empty worker name such as "/cpu", a worker
            name containing ":" that is not "rank:<decimal digits>", or more
            than one ":" in the worker part.
        Exceptions raised by ``torch.device`` for an unparsable device field
        (e.g. "trainer0/not_a_device") are propagated unchanged.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A bare string is either a local device ("cpu", "cuda:1") or a
                # worker name / "rank:<rank>" whose device defaults to "cpu".
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

        # Basic sanity check: reject an empty worker name, e.g. "/cpu".
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Normalize/validate the device; torch.device errors propagate as-is.
        self._device = torch.device(self._device)

        # A worker name of the form "rank:<rank>" selects the rank-based format.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                # Any string accepted by isdecimal() can be parsed by int().
                # isdigit() also admits characters such as "²" that int()
                # rejects, which would bypass PARSE_ERROR.
                if fields[0] == "rank" and fields[1].isdecimal():
                    self._rank = int(fields[1])
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return True if ``torch.device(device)`` succeeds, otherwise False."""
        # Deliberately broad: any failure to build a device simply means the
        # string is not a local device and will be treated as a worker name.
        try:
            torch.device(device)
            return True
        except Exception:
            return False

    def worker_name(self) -> Optional[str]:
        """
        Returns the name of remote worker representing the remote device.
        Returns ``None`` if no worker name is available.
        """
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.
        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """
        Returns the local device on the remote worker.
        """
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        """Render in the same string format that ``__init__`` parses."""
        if self._worker_name is not None:
            prefix: Optional[str] = self._worker_name
        elif self._rank is not None:
            prefix = f'rank:{self._rank}'
        else:
            prefix = None

        if self._device is None:
            # Not reachable through __init__, which always assigns a device.
            if prefix is None:
                raise RuntimeError('Invalid state!')
            return prefix
        if prefix is None:
            return str(self._device)
        return f'{prefix}/{self._device}'

    def __eq__(self, other):
        """Compare worker name, device and rank.

        Returns ``NotImplemented`` for foreign types, so that Python can try
        the reflected comparison before falling back to identity.
        """
        if not isinstance(other, _remote_device):
            return NotImplemented
        return (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        """Hash over exactly the fields compared by ``__eq__``."""
        return hash((self._worker_name, self._device, self._rank))
|