# _foreach_utils.py

from collections import defaultdict
from typing import List, Dict, Tuple, Optional, Union

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad


# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
#   - tensorlists CAN be None
#   - all tensors in the first specified list cannot be None
#   - given an index i, all specified tensorlist[i]s match in dtype and device
# with_indices (bool, optional): whether to track previous indices as the last list per dictionary entry.
#   It comes in handy if there are Nones or literals in the tensorlists that are getting scattered out.
#   Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
#   original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
#   may be necessary. Check out torch/optim/sgd.py for an example.
@no_grad()
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                       with_indices: Optional[bool] = False) -> \
        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
    assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
        "all specified tensorlists must match in length")
    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
        lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
    for i, t in enumerate(tensorlistlist[0]):
        key = (t.device, t.dtype)
        for j in range(len(tensorlistlist)):
            # a tensorlist may be empty/None
            if tensorlistlist[j]:
                per_device_and_dtype_tensors[key][j].append(tensorlistlist[j][i])
        if with_indices:
            # tack on previous index
            per_device_and_dtype_tensors[key][j + 1].append(i)
    return per_device_and_dtype_tensors
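
# Illustrative usage sketch (not part of the upstream file; the tensors and names below are
# hypothetical). Two parallel tensorlists are grouped per (device, dtype) so that each bucket
# can be handed to a single torch._foreach_* call; with_indices=True appends a final list of
# original positions to each bucket, which is useful for scattering results back.
#
#   params = [torch.randn(2), torch.randn(3, dtype=torch.float64)]
#   grads  = [torch.randn(2), torch.randn(3, dtype=torch.float64)]
#   grouped = _group_tensors_by_device_and_dtype([params, grads], with_indices=True)
#   for (device, dtype), (p_bucket, g_bucket, indices) in grouped.items():
#       # indices[k] is the position of p_bucket[k] in the original params list
#       torch._foreach_add_(p_bucket, g_bucket)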


def _has_foreach_support(tensors: List[Tensor], device: torch.device) -> bool:
    if device.type not in ['cpu', 'cuda'] or torch.jit.is_scripting():
        return False
    return all([t is None or type(t) == torch.Tensor for t in tensors])
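

# Illustrative gating sketch (an assumption, not code from the upstream file): callers such as
# optimizers typically check _has_foreach_support before taking the fused foreach path, and fall
# back to a plain per-tensor loop otherwise.
#
#   if _has_foreach_support(grads, device=params[0].device):
#       torch._foreach_mul_(grads, 0.9)
#   else:
#       for g in grads:
#           g.mul_(0.9)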