123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import functools
- import warnings
- from typing import Any, Callable, List, Optional, Tuple, Union
- import torch
- from torch import Tensor
- from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
# Type of the `in_dims` argument: a single int, or a tuple. The validation in
# `_create_batched_inputs` accepts a (potentially nested) tuple matching the
# structure of the inputs, which is why the tuple element type is left open.
in_dims_t = Union[int, Tuple]
# Type of the `out_dims` argument: a single int, or a tuple of ints.
out_dims_t = Union[int, Tuple[int, ...]]
- # Checks that all args-to-be-batched have the same batch dim size
- def _validate_and_get_batch_size(
- flat_in_dims: List[Optional[int]], flat_args: List
- ) -> int:
- batch_sizes = [
- arg.size(in_dim)
- for in_dim, arg in zip(flat_in_dims, flat_args)
- if in_dim is not None
- ]
- if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
- raise ValueError(
- f"vmap: Expected all tensors to have the same size in the mapped "
- f"dimension, got sizes {batch_sizes} for the mapped dimension"
- )
- return batch_sizes[0]
- def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
- if isinstance(batched_outputs, tuple):
- return len(batched_outputs)
- return 1
- # If value is a tuple, check it has length `num_elements`.
- # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
- def _as_tuple(
- value: Any, num_elements: int, error_message_lambda: Callable[[], str]
- ) -> Tuple:
- if not isinstance(value, tuple):
- return (value,) * num_elements
- if len(value) != num_elements:
- raise ValueError(error_message_lambda())
- return value
def _create_batched_inputs(
    in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
) -> Tuple[Tuple, int]:
    """Validate `in_dims` against `args` and wrap every argument to be batched.

    Each flattened argument whose in_dim is not None is passed through
    `torch._add_batch_dim(arg, in_dim, vmap_level)`; arguments with an in_dim
    of None are left untouched.

    Args:
        in_dims: An int or a (potentially nested) tuple matching the structure
            of `args`.
        args: The positional inputs of the vmapped call; must be non-empty.
        vmap_level: The nesting level handed to `torch._add_batch_dim`.
        func: The function being vmapped; used only to name it in errors.

    Returns:
        A pair `(batched_args, batch_size)`, where `batched_args` has the same
        structure as `args`.

    Raises:
        ValueError: If `in_dims` has the wrong type or structure, `args` is
            empty, an in_dim is neither int nor None, an int in_dim targets a
            non-Tensor, or an in_dim is outside `0 <= in_dim < arg.dim()`.
    """

    def header() -> str:
        # Evaluated lazily so `_get_name(func)` only runs on an error path.
        return f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "

    if not isinstance(in_dims, (int, tuple)):
        raise ValueError(
            header()
            + f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )

    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            header()
            + f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if in_dim is None:
            # Not mapped over: nothing to validate for this argument.
            continue
        if not isinstance(in_dim, int):
            raise ValueError(
                header()
                + f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if not isinstance(arg, Tensor):
            raise ValueError(
                header()
                + f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if not 0 <= in_dim < arg.dim():
            raise ValueError(
                header()
                + f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"0 <= in_dim < {arg.dim()}."
            )

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)

    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    wrapped_args = []
    for in_dim, arg in zip(flat_in_dims, flat_args):
        if in_dim is not None:
            arg = torch._add_batch_dim(arg, in_dim, vmap_level)
        wrapped_args.append(arg)
    return tree_unflatten(wrapped_args, args_spec), batch_size
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    """Undo the batching (and any batch dimensions) tied to `vmap_level`.

    Args:
        batched_outputs: What the vmapped function returned: a single Tensor
            or a tuple of outputs.
        out_dims: An int applied to every output, or a tuple with one entry
            per output.
        vmap_level: The nesting level passed to `torch._remove_batch_dim`.
        batch_size: The batch size passed to `torch._remove_batch_dim`.
        func: The function being vmapped; used only to name it in errors.
        allow_none_pass_through: When True, `None` entries of a tuple output
            are returned as `None` instead of being unwrapped.

    Returns:
        The unwrapped result itself when `batched_outputs` is a single Tensor
        (despite the `Tuple` annotation), otherwise a tuple of unwrapped
        outputs.

    Raises:
        ValueError: If `out_dims` is a tuple whose length differs from the
            number of outputs.
    """
    num_outputs = _num_outputs(batched_outputs)

    def mismatch_message() -> str:
        return (
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
            f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}."
        )

    out_dims_as_tuple = _as_tuple(out_dims, num_outputs, mismatch_message)

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        return torch._remove_batch_dim(  # type: ignore[return-value]
            batched_outputs, vmap_level, batch_size, out_dims_as_tuple[0]
        )

    def strip_batch_dim(out, out_dim):
        # None is only tolerated when the caller explicitly opted in; otherwise
        # it is handed to torch._remove_batch_dim like any other output.
        if allow_none_pass_through and out is None:
            return None
        return torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)

    return tuple(
        strip_batch_dim(out, out_dim)
        for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
    )
- # Checks that `fn` returned one or more Tensors and nothing else.
- # NB: A python function that return multiple arguments returns a single tuple,
- # so we are effectively checking that `outputs` is a single Tensor or a tuple of
- # Tensors.
- def _validate_outputs(outputs: Any, func: Callable) -> None:
- if isinstance(outputs, Tensor):
- return
- if not isinstance(outputs, tuple):
- raise ValueError(
- f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
- f"Tensors, got type {type(outputs)} as the return."
- )
- for idx, output in enumerate(outputs):
- if isinstance(output, Tensor):
- continue
- raise ValueError(
- f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
- f"Tensors, got type {type(output)} for return {idx}."
- )
- def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
- if isinstance(out_dims, int):
- return
- if not isinstance(out_dims, tuple) or not all(
- [isinstance(out_dim, int) for out_dim in out_dims]
- ):
- raise ValueError(
- f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
- f"an int or a tuple of int representing where in the outputs the "
- f"vmapped dimension should appear."
- )
- def _get_name(func: Callable):
- if hasattr(func, "__name__"):
- return func.__name__
- # Not all callables have __name__, in fact, only static functions/methods do.
- # A callable created via functools.partial or an nn.Module, to name some
- # examples, don't have a __name__.
- return repr(func)
# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    Please use torch.vmap instead of this API.

    Emits a warning pointing callers at torch.vmap and then returns
    `_vmap(func, in_dims, out_dims)`.
    """
    notice = "Please use torch.vmap instead of torch._vmap_internals.vmap. "
    # stacklevel=2 attributes the warning to the caller of vmap().
    warnings.warn(notice, stacklevel=2)
    return _vmap(func, in_dims=in_dims, out_dims=out_dims)
# A version of vmap but without the initial warning
def _vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    allow_none_pass_through: bool = False,
) -> Callable:
    """Return a wrapper that runs `func` over batched inputs.

    Args:
        func: The function to vectorize.
        in_dims: Which dimension of each input is mapped over (None for inputs
            that are not mapped).
        out_dims: Where the mapped dimension should appear in each output.
        allow_none_pass_through: When True, output validation is skipped and
            `None` outputs are passed through unchanged.

    Returns:
        A callable taking positional arguments only, decorated with
        `functools.wraps(func)`.
    """
    # The `allow_none_pass_through` argument is a temporary workaround that may
    # be removed. Currently it enables us to wrap the call in `autograd.grad` to
    # the autograd engine, which may return None if any of the inputs are
    # unused. See the issue discussing this:
    # https://github.com/facebookresearch/functorch/issues/159.

    @functools.wraps(func)
    def batched_call(*args):
        # Validate out_dims before entering a new vmap nesting level.
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        level = torch._C._vmapmode_increment_nesting()
        try:
            inputs, batch_size = _create_batched_inputs(in_dims, args, level, func)
            outputs = func(*inputs)
            if not allow_none_pass_through:
                _validate_outputs(outputs, func)
            return _unwrap_batched(
                outputs,
                out_dims,
                level,
                batch_size,
                func,
                allow_none_pass_through=allow_none_pass_through,
            )
        finally:
            # Always leave the nesting level, even if func or validation raised.
            torch._C._vmapmode_decrement_nesting()

    return batched_call
|