# torch/_vmap_internals.py

import functools
import warnings
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten

in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]], flat_args: List
) -> int:
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return batch_sizes[0]
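
# Illustrative sketch (not part of the upstream module): for a of shape (3, 5) and
# b of shape (7, 3), _validate_and_get_batch_size([0, 1], [a, b]) returns 3, since
# a.size(0) == b.size(1) == 3; mismatched mapped sizes raise the ValueError above.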


def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(
    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value
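
# Illustrative sketch: _as_tuple(0, 3, msg) == (0, 0, 0), while _as_tuple((0, 1), 3, msg)
# raises ValueError with the message produced by msg(). It is used below to expand a
# single out_dim into one entry per output.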


# Creates BatchedTensors for every Tensor in args that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
    in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )
    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )
    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"0 <= in_dim < {arg.dim()}."
            )
    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec), batch_size
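
# Illustrative sketch (assuming standard pytree prefix-broadcasting semantics): for
# args = (x, (y, z)) and in_dims = (0, None), _broadcast_to_and_flatten expands in_dims
# to [0, None, None], so x is batched over dim 0 while y and z pass through unbatched.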


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims,
        num_outputs,
        lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
    )
    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
    if allow_none_pass_through:
        return tuple(
            (
                torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
                if out is not None
                else None
            )
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )
    else:
        return tuple(
            torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )
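
# Illustrative sketch: with batch_size B and a per-example output of shape (N,),
# out_dims=0 yields a result of shape (B, N) while out_dims=1 yields (N, B), i.e.
# the batch dimension is materialized at the requested position.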


# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A Python function that returns multiple values returns a single tuple,
# so we are effectively checking that `outputs` is a single Tensor or a tuple of
# Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
    if isinstance(outputs, Tensor):
        return
    if not isinstance(outputs, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(outputs)} as the return."
        )
    for idx, output in enumerate(outputs):
        if isinstance(output, Tensor):
            continue
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(output)} for return {idx}."
        )


def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    if not isinstance(out_dims, tuple) or not all(
        [isinstance(out_dim, int) for out_dim in out_dims]
    ):
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
            f"an int or a tuple of int representing where in the outputs the "
            f"vmapped dimension should appear."
        )


def _get_name(func: Callable):
    if hasattr(func, "__name__"):
        return func.__name__
    # Not all callables have a __name__. For example, a callable created via
    # functools.partial or an nn.Module instance doesn't have one.
    return repr(func)


# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    Please use torch.vmap instead of this API.
    """
    warnings.warn(
        "Please use torch.vmap instead of torch._vmap_internals.vmap. ",
        stacklevel=2,
    )
    return _vmap(func, in_dims, out_dims)


# A version of vmap but without the initial "experimental prototype" warning
def _vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    allow_none_pass_through: bool = False,
) -> Callable:
    # The `allow_none_pass_through` argument is a temporary workaround and may be removed.
    # Currently it lets us vmap over calls to `autograd.grad`, which may return None
    # for any input that is unused. See the issue discussing this:
    # https://github.com/facebookresearch/functorch/issues/159.
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = torch._C._vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(
                in_dims, args, vmap_level, func
            )
            batched_outputs = func(*batched_inputs)
            if not allow_none_pass_through:
                _validate_outputs(batched_outputs, func)
            return _unwrap_batched(
                batched_outputs,
                out_dims,
                vmap_level,
                batch_size,
                func,
                allow_none_pass_through=allow_none_pass_through,
            )
        finally:
            torch._C._vmapmode_decrement_nesting()

    return wrapped
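

# Minimal usage sketch (illustrative only, not part of the upstream module; prefer
# torch.vmap in real code). Batches an element-wise multiply over the leading dim.
if __name__ == "__main__":
    x = torch.randn(3, 5)
    y = torch.randn(3, 5)
    out = _vmap(torch.mul)(x, y)  # same values as x * y, computed via the vmap machinery
    assert out.shape == (3, 5)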