123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- import torch._functorch.vmap as _vmap_impl
- import torch._functorch.eager_transforms as _impl
- import torch._functorch.make_functional as _nn_impl
- from torch._functorch.vmap import in_dims_t, out_dims_t
- from torch._functorch.eager_transforms import argnums_t
- import torch.nn as nn
- import textwrap
- from typing import Any, Callable, Optional, Tuple, Union
- import warnings
- """
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the torch.func.* APIs that emit deprecation warnings -- we're trying
to move people to the torch.func.* equivalents.

NB: We don't use *args, **kwargs in the signatures because that would change
the documented signatures.
- """
def get_warning(api, new_api=None, replace_newlines=False):
    """Build the deprecation message for ``functorch.<api>``.

    Args:
        api: name of the deprecated functorch API (e.g. ``'vmap'``).
        new_api: replacement API to recommend; defaults to
            ``'torch.func.<api>'``.
        replace_newlines: if True, remove every newline so the message is a
            single line (warn_deprecated uses this for runtime warnings);
            otherwise the line breaks are kept (setup_docs embeds that form
            in docstrings).

    Returns:
        The warning text as a ``str``.
    """
    if new_api is None:
        new_api = f'torch.func.{api}'
    # Each fragment ends with " \n", so removing the newlines still leaves a
    # single space between the joined fragments.
    warning = (
        "We've integrated functorch into PyTorch. As the final step of the \n"
        f"integration, functorch.{api} is deprecated as of PyTorch \n"
        "2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
        f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n"
        "and/or the torch.func migration guide for more details \n"
        "https://pytorch.org/docs/master/func.migrating.html"
    )
    if replace_newlines:
        warning = warning.replace("\n", "")
    return warning
def warn_deprecated(api, new_api=None):
    """Emit the single-line functorch deprecation warning for ``functorch.<api>``.

    Args:
        api: name of the deprecated functorch API (e.g. ``'grad'``).
        new_api: replacement API to recommend; when None, get_warning
            defaults it to ``torch.func.<api>``.
    """
    warning = get_warning(api, new_api, replace_newlines=True)
    # Every functorch wrapper calls this helper directly, so the frames are:
    # 1 = this line, 2 = the wrapper (e.g. ``grad``), 3 = the user's call.
    # stacklevel=3 points the warning at user code instead of at this file.
    warnings.warn(warning, stacklevel=3)
def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
    """Copy a torch.func docstring onto a functorch wrapper and append a
    deprecation note.

    Args:
        functorch_api: the deprecated ``functorch.*`` wrapper whose
            ``__doc__`` is overwritten.
        torch_func_api: function whose docstring is copied. Defaults to the
            attribute of ``torch._functorch.eager_transforms`` that has the
            same ``__name__`` as ``functorch_api``.
        new_api_name: replacement API named in the note; when None,
            get_warning defaults it to ``torch.func.<name>``.
    """
    api_name = functorch_api.__name__
    if torch_func_api is None:
        torch_func_api = getattr(_impl, api_name)
    warning = get_warning(api_name, new_api_name)
    # Format the message as an RST ``.. warning::`` directive. First indent
    # the message under the directive, then indent the whole note.
    # NOTE(review): the indent string is a single space; confirm it matches
    # the indentation of the copied torch.func docstrings so the note renders
    # as part of them.
    warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ")
    warning_note = textwrap.indent(warning_note, " ")
    # __doc__ is None when docstrings are stripped (python -OO). Treat that as
    # an empty docstring instead of raising TypeError at import time.
    base_doc = torch_func_api.__doc__ or ""
    functorch_api.__doc__ = base_doc + warning_note
def vmap(
        func: Callable,
        in_dims: in_dims_t = 0,
        out_dims: out_dims_t = 0,
        randomness: str = 'error',
        *,
        chunk_size=None) -> Callable:
    # Deprecated alias: warn (recommending torch.vmap), then delegate to the
    # torch._functorch implementation. The public __doc__ is attached at
    # import time by setup_docs().
    warn_deprecated('vmap', 'torch.vmap')
    return _vmap_impl.vmap(
        func,
        in_dims=in_dims,
        out_dims=out_dims,
        randomness=randomness,
        chunk_size=chunk_size,
    )
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated alias of torch.func.grad; __doc__ is attached by setup_docs().
    warn_deprecated('grad')
    return _impl.grad(func, argnums=argnums, has_aux=has_aux)
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated alias of torch.func.grad_and_value; __doc__ is attached by
    # setup_docs().
    warn_deprecated('grad_and_value')
    return _impl.grad_and_value(func, argnums=argnums, has_aux=has_aux)
def vjp(func: Callable, *primals, has_aux: bool = False):
    # Deprecated alias of torch.func.vjp. The primals are forwarded
    # positionally, exactly as received; __doc__ is attached by setup_docs().
    warn_deprecated('vjp')
    return _impl.vjp(
        func,
        *primals,
        has_aux=has_aux,
    )
def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    # Deprecated alias of torch.func.jvp; __doc__ is attached by setup_docs().
    warn_deprecated('jvp')
    return _impl.jvp(
        func,
        primals,
        tangents,
        strict=strict,
        has_aux=has_aux,
    )
def jacrev(func: Callable, argnums: argnums_t = 0, *, has_aux: bool = False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy: bool = False):
    # Deprecated alias of torch.func.jacrev; __doc__ is attached by setup_docs().
    # ``argnums`` uses the shared ``argnums_t`` alias, as grad/jacfwd do. The
    # previous ``Tuple[int]`` annotation only admitted one-element tuples.
    warn_deprecated('jacrev')
    return _impl.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
                        _preallocate_and_copy=_preallocate_and_copy)
def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
    # Deprecated alias of torch.func.jacfwd; __doc__ is attached by setup_docs().
    warn_deprecated('jacfwd')
    return _impl.jacfwd(
        func,
        argnums=argnums,
        has_aux=has_aux,
        randomness=randomness,
    )
def hessian(func, argnums=0):
    # Deprecated alias of torch.func.hessian; __doc__ is attached by setup_docs().
    warn_deprecated('hessian')
    hessian_fn = _impl.hessian(func, argnums=argnums)
    return hessian_fn
def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    # Deprecated alias of torch.func.functionalize. ``remove`` is forwarded
    # unchanged; __doc__ is attached by setup_docs().
    warn_deprecated('functionalize')
    functionalized = _impl.functionalize(func, remove=remove)
    return functionalized
def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated. The warning recommends torch.func.functional_call rather
    # than a same-named torch.func API; __doc__ is attached by setup_docs().
    warn_deprecated('make_functional', 'torch.func.functional_call')
    return _nn_impl.make_functional(
        model,
        disable_autograd_tracking=disable_autograd_tracking,
    )
def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated. The warning recommends torch.func.functional_call;
    # __doc__ is attached by setup_docs().
    warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')
    return _nn_impl.make_functional_with_buffers(
        model,
        disable_autograd_tracking=disable_autograd_tracking,
    )
def combine_state_for_ensemble(models):
    # Deprecated. The warning recommends torch.func.stack_module_state;
    # __doc__ is attached by setup_docs().
    warn_deprecated('combine_state_for_ensemble', 'torch.func.stack_module_state')
    combined = _nn_impl.combine_state_for_ensemble(models)
    return combined
# Give each functorch wrapper the docstring of its torch.func counterpart plus
# a deprecation note. Wrappers whose implementation is not in
# torch._functorch.eager_transforms, or whose recommended replacement is not
# torch.func.<name>, pass both explicitly.
setup_docs(vmap, _vmap_impl.vmap, 'torch.vmap')
setup_docs(grad)
setup_docs(grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)
setup_docs(make_functional, _nn_impl.make_functional,
           'torch.func.functional_call')
# Must use make_functional_with_buffers' own docstring, not make_functional's.
setup_docs(make_functional_with_buffers, _nn_impl.make_functional_with_buffers,
           'torch.func.functional_call')
setup_docs(combine_state_for_ensemble, _nn_impl.combine_state_for_ensemble,
           'torch.func.stack_module_state')
|