autocast_mode.py

import torch
import functools
import collections

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

from typing import Any

__all__ = ["autocast", "custom_fwd", "custom_bwd"]


class autocast(torch.amp.autocast_mode.autocast):
    r"""
    See :class:`torch.autocast`.

    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
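

# Example (an illustrative sketch, not part of the upstream module): using
# ``torch.cuda.amp.autocast`` as a context manager so that eligible CUDA ops run in the
# autocast dtype (float16 by default). The model and tensor names below are placeholders,
# and the snippet assumes a CUDA device is available.
#
#     model = torch.nn.Linear(8, 8).cuda()
#     inp = torch.randn(4, 8, device="cuda")
#     with torch.cuda.amp.autocast():
#         out = model(inp)  # the underlying matmul runs in float16 under autocast
#     assert out.dtype is torch.float16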


# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and
# np.ndarrays, which may be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_cuda
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, str):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = map(lambda v: _cast(v, dtype), value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value
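

# Illustrative sketch (not part of the upstream module): ``_cast`` recurses through
# containers, converting only floating-point CUDA tensors and passing strings, np.ndarrays,
# and non-floating-point tensors through unchanged. Assumes a CUDA device is available.
#
#     t = torch.randn(2, device="cuda")                      # float32 CUDA tensor
#     i = torch.ones(2, dtype=torch.int64, device="cuda")    # integer tensor
#     out = _cast({"x": t, "idx": i, "tag": "batch0"}, torch.float16)
#     # out["x"].dtype is torch.float16; out["idx"] and out["tag"] pass through unchanged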


# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).  See the :ref:`example page<amp-custom-examples>` for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None):  If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        # args[0] is the ``ctx`` object passed to ``forward``; record the autocast dtype and
        # state on it so ``custom_bwd`` can replay them in ``backward``.
        args[0]._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd


# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """
    Helper decorator for backward methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).

    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
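

# Example (an illustrative sketch, not part of the upstream module): a custom autograd
# Function whose ``forward`` casts incoming floating-point CUDA tensors to float32 and whose
# ``backward`` reruns under the same autocast state that ``forward`` saw. ``MyMM`` and the
# input names are made up for illustration; a CUDA device is assumed.
#
#     class MyMM(torch.autograd.Function):
#         @staticmethod
#         @custom_fwd(cast_inputs=torch.float32)
#         def forward(ctx, a, b):
#             ctx.save_for_backward(a, b)
#             return a.mm(b)
#
#         @staticmethod
#         @custom_bwd
#         def backward(ctx, grad):
#             a, b = ctx.saved_tensors
#             return grad.mm(b.t()), a.t().mm(grad)
#
#     # Inside an autocast region, floating-point CUDA inputs are cast to float32 (a no-op
#     # here), forward runs with autocast disabled, and backward later reruns with that
#     # same (disabled) autocast state.
#     a = torch.randn(2, 3, device="cuda", requires_grad=True)
#     b = torch.randn(3, 4, device="cuda", requires_grad=True)
#     with autocast():
#         MyMM.apply(a, b).sum().backward()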