# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """
    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor
    should not be chunked on the batch dimension and instead be replicated
    as-is across all micro-batches. This is useful for tensors which might
    not have any 'batch' semantics for the model.
    """
    def __init__(self, inp: Tensor):
        if not torch.is_tensor(inp):
            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
        self._tensor = inp

    @property
    def tensor(self):
        return self._tensor
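

# A minimal illustrative sketch (not used by the pipeline itself): wrapping a
# tensor that has no batch dimension, e.g. a mask shared by every micro-batch,
# so that it is replicated instead of split. The shape below is an assumption
# made purely for demonstration.
def _example_no_chunk() -> None:
    mask = torch.ones(16, 16)
    wrapped = NoChunk(mask)
    # The original tensor is recoverable as-is through the ``tensor`` property.
    assert wrapped.tensor is mask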


class Batch:
    """
    An abstraction representing a microbatch in the pipeline.
    """

    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify there is at least one tensor in the batch.
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch."""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of the first tensor found.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

    def call(self, function: Function) -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        if not self.atomic:
            i = index
            self._values = self._values[:i] + (value,) + self._values[i + 1 :]  # type: ignore[operator]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        if not (index.start is index.stop is index.step is None):
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]
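

# A minimal illustrative sketch (not part of the pipeline API) of how ``Batch``
# behaves in the atomic (single tensor) and non-atomic (sequence of values)
# cases. The tensors and the lambda below are assumptions for demonstration.
def _example_batch() -> None:
    atomic = Batch(torch.zeros(4, 2))
    assert atomic.atomic and len(atomic) == 1
    # For an atomic batch, index 0 and ``tensor`` refer to the same object.
    assert atomic[0] is atomic.tensor

    mixed = Batch((torch.zeros(4, 2), "metadata"))
    assert not mixed.atomic and len(mixed) == 2
    assert mixed.find_tensor_idx() == 0

    # ``call`` unpacks non-atomic values as positional arguments and wraps the
    # result in a new ``Batch``.
    doubled = mixed.call(lambda t, meta: (t * 2, meta))
    assert isinstance(doubled, Batch) and doubled[1] == "metadata"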


def check(first_device, *inputs) -> None:
    """
    Checks whether the input contains at least one tensor and each tensor is
    on the same device as the first partition.

    Raises:
        TypeError: input does not contain at least one tensor
        ValueError: a tensor is not on the same device as the first partition
    """
    if not any(torch.is_tensor(input) for input in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')
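

# A minimal illustrative sketch of ``check``: it accepts any mix of tensors and
# non-tensors as long as at least one tensor is present and every tensor lives
# on the device of the first partition. The CPU device and the inputs below are
# assumptions for demonstration.
def _example_check() -> None:
    first_device = torch.device("cpu")
    check(first_device, torch.zeros(4, 2), "some-flag")  # passes: one CPU tensor
    try:
        check(first_device, "no-tensors-here")
    except TypeError:
        pass  # an input with no tensors at all is rejected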


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced.
    num_chunks = -1
    for input in inputs:
        if torch.is_tensor(input):
            # Chunk only tensors.
            tensors = input.chunk(chunks)

            # Validate that the number of chunks is equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for i, tensor in enumerate(tensors):
                batches[i].append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            for i in range(chunks):
                if isinstance(input, NoChunk):
                    # Extract the tensor out.
                    batches[i].append(input.tensor)
                else:
                    batches[i].append(input)

    # Truncate to the actual number of chunks produced.
    batches = batches[:num_chunks]
    return [Batch(x) for x in batches]
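

# A minimal illustrative sketch of ``scatter``: a plain tensor is split along
# the batch dimension while a tensor wrapped in ``NoChunk`` (like any
# non-tensor value) is replicated into every micro-batch. The shapes and the
# chunk count are assumptions for demonstration.
def _example_scatter() -> None:
    inp = torch.zeros(8, 16)
    mask = torch.ones(16, 16)
    batches = scatter(inp, NoChunk(mask), chunks=4)
    assert len(batches) == 4
    assert batches[0][0].size(0) == 2          # 8 rows split into 4 chunks of 2
    assert all(b[1] is mask for b in batches)  # the mask is shared, not chunked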


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch."""
    output: Any

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        output_buf: List[Any] = []
        for i in range(len(outputs[0])):
            output_type = type(outputs[0][i])
            current_outputs = []
            for batch in outputs:
                if output_type != type(batch[i]):
                    raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
                current_outputs.append(batch[i])

            if torch.is_tensor(outputs[0][i]):
                output_buf.append(torch.cat(current_outputs))
            else:
                output_buf.append(current_outputs)

        output = tuple(output_buf)

    return output
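

# A minimal illustrative sketch of the scatter/gather round trip: splitting a
# mini-batch into micro-batches and concatenating the (here, untouched) outputs
# restores the original batch dimension. The input shape is an assumption for
# demonstration.
def _example_gather() -> None:
    inp = torch.arange(8.0).reshape(8, 1)
    micro_batches = scatter(inp, chunks=4)
    out = gather(micro_batches)
    assert torch.equal(out, inp)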