microbatch.py

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """
    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor
    should not be chunked on the batch dimension and instead be replicated
    as-is across all micro-batches. This is useful for tensors which might
    not have any 'batch' semantics for the model.
    """
    def __init__(self, inp: Tensor):
        if not torch.is_tensor(inp):
            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
        self._tensor = inp

    @property
    def tensor(self):
        return self._tensor
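

# Example sketch (illustrative shapes only; ``scatter`` is defined later in
# this file): wrapping a tensor in NoChunk makes scatter() replicate it to
# every micro-batch instead of splitting it along the batch dimension.
#
#     x = torch.randn(8, 4)                      # has a batch dimension
#     mask = torch.ones(4)                       # no batch semantics
#     batches = scatter(x, NoChunk(mask), chunks=4)
#     assert all(b[1] is mask for b in batches)  # replicated as-is
#     assert batches[0][0].shape == (2, 4)       # x was chunked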


class Batch:
    """
    An abstraction representing a microbatch in the pipeline.
    """
    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify there is at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch."""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of the first tensor found.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

    def call(self, function: Function) -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        if not self.atomic:
            i = index
            self._values = self._values[:i] + (value,) + self._values[i + 1:]  # type: ignore[operator]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]
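

# Example sketch (illustrative values only): an atomic Batch wraps a single
# tensor, while a non-atomic Batch wraps a sequence containing at least one
# tensor. call() applies a function and re-wraps the result as a Batch.
#
#     atomic = Batch(torch.ones(2, 2))
#     assert atomic.atomic and len(atomic) == 1
#     doubled = atomic.call(lambda t: t * 2)     # still a Batch
#
#     mixed = Batch((torch.zeros(3), 'label'))
#     assert not mixed.atomic and mixed[1] == 'label'
#     mixed[0] = torch.ones(3)                   # rebuilds the value tuple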


def check(first_device, *inputs) -> None:
    """
    Checks whether the input contains at least one tensor and whether each
    tensor is on the same device as the first partition.

    Raises:
        TypeError: input does not contain at least one tensor
        ValueError: a tensor is not on the same device as the first partition
    """
    if not any(torch.is_tensor(input) for input in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')
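

# Example sketch (illustrative inputs only): check() returns silently for
# well-formed inputs and raises otherwise.
#
#     check(torch.device('cpu'), torch.randn(2), 'some-flag')   # passes
#     check(torch.device('cpu'), 'no-tensors-here')             # TypeError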


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced
    num_chunks = -1
    for input in inputs:
        if torch.is_tensor(input):
            # Chunk only tensors.
            tensors = input.chunk(chunks)

            # Validate number of chunks equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for i, tensor in enumerate(tensors):
                batches[i].append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            for i in range(chunks):
                if isinstance(input, NoChunk):
                    # Extract the tensor out.
                    batches[i].append(input.tensor)
                else:
                    batches[i].append(input)

    # Truncate to actual number of chunks
    batches = batches[:num_chunks]

    return [Batch(x) for x in batches]
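

# Example sketch (illustrative shapes only): scatter() chunks every plain
# tensor along dim 0 and replicates every non-tensor into each micro-batch.
#
#     a = torch.arange(6).reshape(6, 1)
#     b = torch.ones(6)
#     batches = scatter(a, b, 'tag', chunks=3)
#     assert len(batches) == 3
#     assert batches[0][0].shape == (2, 1)   # a was chunked
#     assert batches[0][1].shape == (2,)     # b was chunked
#     assert batches[0][2] == 'tag'          # replicated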


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch."""
    output: Any

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        output_buf: List[Any] = []
        for i in range(len(outputs[0])):
            output_type = type(outputs[0][i])
            current_outputs = []
            for batch in outputs:
                if output_type != type(batch[i]):
                    raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
                current_outputs.append(batch[i])

            if torch.is_tensor(outputs[0][i]):
                output_buf.append(torch.cat(current_outputs))
            else:
                output_buf.append(current_outputs)

        output = tuple(output_buf)

    return output
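

# Example sketch (illustrative shapes only): for tensor outputs, gather()
# inverts scatter(); chunking along dim 0 and concatenating the chunks back
# reproduces the original mini-batch.
#
#     x = torch.randn(8, 4)
#     merged = gather(scatter(x, chunks=4))
#     assert torch.equal(merged, x)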