deprecated.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch._functorch.vmap as _vmap_impl
  2. import torch._functorch.eager_transforms as _impl
  3. import torch._functorch.make_functional as _nn_impl
  4. from torch._functorch.vmap import in_dims_t, out_dims_t
  5. from torch._functorch.eager_transforms import argnums_t
  6. import torch.nn as nn
  7. import textwrap
  8. from typing import Any, Callable, Optional, Tuple, Union
  9. import warnings
  10. """
  11. The APIs in this file are exposed as `functorch.*`. They are thin wrappers
  12. around the torch.func.* APIs that have deprecation warnings -- we're trying
  13. to move people to the torch.func.* equivalents.
  14. NB: We don't use *args, **kwargs in the signatures because that changes the
  15. documentation.
  16. """
  17. def get_warning(api, new_api=None, replace_newlines=False):
  18. if new_api is None:
  19. new_api = f'torch.func.{api}'
  20. warning = (
  21. f"We've integrated functorch into PyTorch. As the final step of the \n"
  22. f"integration, functorch.{api} is deprecated as of PyTorch \n"
  23. f"2.0 and will be deleted in a future version of PyTorch >= 2.3. \n"
  24. f"Please use {new_api} instead; see the PyTorch 2.0 release notes \n"
  25. f"and/or the torch.func migration guide for more details \n"
  26. f"https://pytorch.org/docs/master/func.migrating.html"
  27. )
  28. if replace_newlines:
  29. warning = warning.replace("\n", "")
  30. return warning
  31. def warn_deprecated(api, new_api=None):
  32. warning = get_warning(api, new_api, replace_newlines=True)
  33. warnings.warn(warning, stacklevel=2)
  34. def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
  35. api_name = functorch_api.__name__
  36. if torch_func_api is None:
  37. torch_func_api = getattr(_impl, api_name)
  38. warning = get_warning(api_name, new_api_name)
  39. warning_note = "\n.. warning::\n\n" + textwrap.indent(warning, " ")
  40. warning_note = textwrap.indent(warning_note, " ")
  41. functorch_api.__doc__ = torch_func_api.__doc__ + warning_note
  42. def vmap(
  43. func: Callable,
  44. in_dims: in_dims_t = 0,
  45. out_dims: out_dims_t = 0,
  46. randomness: str = 'error',
  47. *,
  48. chunk_size=None) -> Callable:
  49. warn_deprecated('vmap', 'torch.vmap')
  50. return _vmap_impl.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
  51. def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
  52. warn_deprecated('grad')
  53. return _impl.grad(func, argnums, has_aux)
  54. def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
  55. warn_deprecated('grad_and_value')
  56. return _impl.grad_and_value(func, argnums, has_aux)
  57. def vjp(func: Callable, *primals, has_aux: bool = False):
  58. warn_deprecated('vjp')
  59. return _impl.vjp(func, *primals, has_aux=has_aux)
  60. def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
  61. warn_deprecated('jvp')
  62. return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux)
  63. def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
  64. chunk_size: Optional[int] = None,
  65. _preallocate_and_copy=False):
  66. warn_deprecated('jacrev')
  67. return _impl.jacrev(func, argnums, has_aux=has_aux, chunk_size=chunk_size,
  68. _preallocate_and_copy=_preallocate_and_copy)
  69. def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
  70. warn_deprecated('jacfwd')
  71. return _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
  72. def hessian(func, argnums=0):
  73. warn_deprecated('hessian')
  74. return _impl.hessian(func, argnums=argnums)
  75. def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
  76. warn_deprecated('functionalize')
  77. return _impl.functionalize(func, remove=remove)
  78. def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
  79. warn_deprecated('make_functional', 'torch.func.functional_call')
  80. return _nn_impl.make_functional(model, disable_autograd_tracking)
  81. def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
  82. warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')
  83. return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
  84. def combine_state_for_ensemble(models):
  85. warn_deprecated('combine_state_for_ensemble', 'torch.func.stack_module_state')
  86. return _nn_impl.combine_state_for_ensemble(models)
  87. setup_docs(vmap, _vmap_impl.vmap, 'torch.vmap')
  88. setup_docs(grad)
  89. setup_docs(grad_and_value)
  90. setup_docs(vjp)
  91. setup_docs(jvp)
  92. setup_docs(jacrev)
  93. setup_docs(jacfwd)
  94. setup_docs(hessian)
  95. setup_docs(functionalize)
  96. setup_docs(make_functional, _nn_impl.make_functional,
  97. 'torch.func.functional_call')
  98. setup_docs(make_functional_with_buffers, _nn_impl.make_functional,
  99. 'torch.func.functional_call')
  100. setup_docs(combine_state_for_ensemble, _nn_impl.combine_state_for_ensemble,
  101. 'torch.func.stack_module_state')