_functions.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import warnings
  2. import torch
  3. from . import comm
  4. from torch.autograd import Function
  5. from torch._utils import _get_device_index
  6. from typing import List, Optional
  7. class Broadcast(Function):
  8. @staticmethod
  9. def forward(ctx, target_gpus, *inputs):
  10. assert all(i.device.type != 'cpu' for i in inputs), (
  11. 'Broadcast function not implemented for CPU tensors'
  12. )
  13. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  14. ctx.target_gpus = target_gpus
  15. if len(inputs) == 0:
  16. return tuple()
  17. ctx.num_inputs = len(inputs)
  18. ctx.input_device = inputs[0].get_device()
  19. outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  20. non_differentiables = []
  21. for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
  22. if not input_requires_grad:
  23. for output in outputs:
  24. non_differentiables.append(output[idx])
  25. ctx.mark_non_differentiable(*non_differentiables)
  26. return tuple([t for tensors in outputs for t in tensors])
  27. @staticmethod
  28. def backward(ctx, *grad_outputs):
  29. return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
  30. class ReduceAddCoalesced(Function):
  31. @staticmethod
  32. def forward(ctx, destination, num_inputs, *grads):
  33. ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)]
  34. grads_ = [grads[i:i + num_inputs]
  35. for i in range(0, len(grads), num_inputs)]
  36. return comm.reduce_add_coalesced(grads_, destination)
  37. @staticmethod
  38. def backward(ctx, *grad_outputs):
  39. return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
  40. class Gather(Function):
  41. @staticmethod
  42. def forward(ctx, target_device, dim, *inputs):
  43. assert all(i.device.type != 'cpu' for i in inputs), (
  44. 'Gather function not implemented for CPU tensors'
  45. )
  46. if (target_device == 'cpu'):
  47. ctx.target_device = 'cpu'
  48. else:
  49. target_device = _get_device_index(target_device, True)
  50. ctx.target_device = target_device
  51. ctx.dim = dim
  52. ctx.input_gpus = tuple(i.get_device() for i in inputs)
  53. if all(t.dim() == 0 for t in inputs) and dim == 0:
  54. inputs = tuple(t.view(1) for t in inputs)
  55. warnings.warn('Was asked to gather along dimension 0, but all '
  56. 'input tensors were scalars; will instead unsqueeze '
  57. 'and return a vector.')
  58. ctx.unsqueezed_scalar = True
  59. else:
  60. ctx.unsqueezed_scalar = False
  61. ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
  62. return comm.gather(inputs, ctx.dim, ctx.target_device)
  63. @staticmethod
  64. def backward(ctx, grad_output):
  65. scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
  66. if ctx.unsqueezed_scalar:
  67. scattered_grads = tuple(g[0] for g in scattered_grads)
  68. return (None, None) + scattered_grads
  69. class Scatter(Function):
  70. @staticmethod
  71. def forward(ctx, target_gpus, chunk_sizes, dim, input):
  72. target_gpus = [_get_device_index(x, True) for x in target_gpus]
  73. ctx.dim = dim
  74. ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
  75. streams = None
  76. if torch.cuda.is_available() and ctx.input_device == -1:
  77. # Perform CPU to GPU copies in a background stream
  78. streams = [_get_stream(device) for device in target_gpus]
  79. outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
  80. # Synchronize with the copy stream
  81. if streams is not None:
  82. for i, output in enumerate(outputs):
  83. with torch.cuda.device(target_gpus[i]):
  84. main_stream = torch.cuda.current_stream()
  85. main_stream.wait_stream(streams[i])
  86. output.record_stream(main_stream)
  87. return outputs
  88. @staticmethod
  89. def backward(ctx, *grad_output):
  90. return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
  91. # background streams used for copying
  92. _streams: Optional[List[Optional[torch.cuda.Stream]]] = None
  93. def _get_stream(device: int):
  94. """Gets a background stream for copying between CPU and GPU"""
  95. global _streams
  96. if device == -1:
  97. return None
  98. if _streams is None:
  99. _streams = [None] * torch.cuda.device_count()
  100. if _streams[device] is None:
  101. _streams[device] = torch.cuda.Stream(device)
  102. return _streams[device]