scatter_gather.py

import torch
from ._functions import Scatter, Gather
import warnings

__all__ = ['scatter', 'scatter_kwargs', 'gather']


def is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
    return _is_namedtuple(obj)


def _is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )
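

# A minimal sketch of what the duck-typing check above accepts: namedtuples
# carry both `_asdict` and `_fields`, plain tuples and lists carry neither.
# `_example_is_namedtuple_check` and `Point` are hypothetical names used only
# for illustration; the function is defined but never called on import.
def _example_is_namedtuple_check():
    from collections import namedtuple
    Point = namedtuple('Point', ['x', 'y'])
    assert _is_namedtuple(Point(1, 2))   # namedtuple instances pass
    assert not _is_namedtuple((1, 2))    # a plain tuple lacks _asdict/_fields
    assert not _is_namedtuple([1, 2])    # non-tuples are rejected outright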


def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
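

# A minimal usage sketch for scatter, assuming at least two visible CUDA
# devices: the tensor in a tuple input is chunked along `dim`, while the
# non-tensor leaf is replicated once per target GPU. `_example_scatter_usage`
# is a hypothetical helper, defined but not called on import.
def _example_scatter_usage():
    batch = torch.randn(8, 4)
    chunks = scatter((batch, 'label'), target_gpus=[0, 1], dim=0)
    # chunks[0] holds (batch[0:4] on cuda:0, 'label'),
    # chunks[1] holds (batch[4:8] on cuda:1, 'label').
    assert len(chunks) == 2                       # one entry per GPU
    part, tag = chunks[0]
    assert part.shape[0] == 4 and tag == 'label'  # half the batch + the copy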


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for a kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
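

# A minimal sketch of the padding behavior in scatter_kwargs, assuming two
# CUDA devices: when one of the two scattered sequences is shorter, it is
# padded with empty tuples (or dicts) so each replica gets an (args, kwargs)
# pair. `_example_scatter_kwargs_usage` is a hypothetical helper, not called
# on import.
def _example_scatter_kwargs_usage():
    mask = torch.ones(6, 6)
    inputs, kwargs = scatter_kwargs(
        inputs=(), kwargs={'mask': mask}, target_gpus=[0, 1], dim=0)
    # inputs was empty, so it is padded with one empty tuple per replica.
    assert inputs == ((), ())
    assert all('mask' in d for d in kwargs)       # one kwargs dict per GPU
    assert kwargs[0]['mask'].shape == (3, 6)      # the 6-row mask, halved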


def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device.
    Use 'cpu' for CPU to avoid a deprecation warning.
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in outputs]))
                             for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res
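

# A minimal usage sketch for gather, assuming two CUDA devices: per-GPU dicts
# are merged key by key, and the tensors under each key are concatenated along
# `dim` on the target device ('cpu' here, per the docstring above).
# `_example_gather_usage` is a hypothetical helper, not called on import.
def _example_gather_usage():
    per_gpu = [{'logits': torch.randn(4, 10, device=f'cuda:{i}')}
               for i in range(2)]
    merged = gather(per_gpu, target_device='cpu', dim=0)
    # The two (4, 10) shards are concatenated into one (8, 10) tensor.
    assert merged['logits'].shape == (8, 10)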