autocast_mode.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import torch
  2. import functools
  3. import warnings
  4. from typing import Any, Optional
  5. from torch.types import _dtype
  # Public names exported by ``from <this module> import *``.
__all__ = [
    "autocast_decorator",
    "autocast",
]
  # Decorator support for autocast-style context managers.
def autocast_decorator(autocast_instance, func):
    """Return a wrapper that runs ``func`` inside ``autocast_instance``.

    Args:
        autocast_instance: context manager entered around every call of the
            returned wrapper (and exited afterwards, even if ``func`` raises).
        func: callable to wrap.

    Returns:
        A wrapper carrying ``func``'s metadata (via ``functools``) that forwards
        all positional and keyword arguments and returns ``func``'s result.
        The wrapper also carries a ``__script_unsupported`` attribute whose
        value states that the decorator is not supported in script mode.
    """
    unsupported_in_script = '@autocast() decorator is not supported in script mode'

    def decorate_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)

    wrapper = functools.update_wrapper(decorate_autocast, func)
    # Set after copying func's metadata so this marker always wins.
    setattr(wrapper, '__script_unsupported', unsupported_in_script)
    return wrapper
  # Device-generic mixed-precision context manager / decorator.
class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s). Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or other dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
    and incurs no additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)

            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest disabling the JIT autocast pass,
        # see https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Models Run
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(..., enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``. Disabling autocast gives you explicit control over
    the execution type. In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).

    Args:
        device_type(str, required): Device type to use: ``'cuda'``, ``'cpu'``, ``'xpu'`` or ``'hpu'``.
            Any other value raises ``RuntimeError``. With ``'cuda'`` and no CUDA available,
            autocast is disabled with a warning.
        dtype(torch_dtype, optional): dtype that autocast runs eligible ops in. If ``None`` (default),
            the device's current autocast dtype is used. CPU supports only ``torch.bfloat16``;
            XPU and HPU support ``torch.bfloat16`` and ``torch.float16``; an unsupported dtype
            on these devices disables autocast with a warning. On CUDA, ``torch.bfloat16``
            raises ``RuntimeError`` if the current device does not support it.
        enabled(bool, optional): Whether autocasting should be enabled in the region.
            Default: ``True``
        cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
            If ``None`` (default), the current global setting
            (``torch.is_autocast_cache_enabled()``) is kept.
    """

    def __init__(self, device_type: str,
                 dtype: Optional[_dtype] = None,
                 enabled: bool = True,
                 cache_enabled: Optional[bool] = None):
        """Record and validate the autocast configuration.

        The global autocast state is not modified here; that happens in
        :meth:`__enter__`.

        Raises:
            RuntimeError: if ``device_type`` is not 'cuda', 'cpu', 'xpu' or 'hpu',
                or if enabled bfloat16 autocast is requested on a CUDA device for
                which ``torch.cuda.is_bf16_supported()`` is False.
        """
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return
        self.device = device_type
        # Default to the device's current autocast dtype; an explicit ``dtype``
        # argument overrides it further down.
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError(
                "User specified autocast device_type must be 'cuda', 'cpu', 'xpu' or 'hpu', "
                f"but got '{self.device}'")
        # Inherit the current global weight-cache setting unless overridden below.
        self._cache_enabled = torch.is_autocast_cache_enabled()
        # Check the device first so non-CUDA autocast never probes CUDA availability.
        if self.device == 'cuda' and enabled and torch.cuda.amp.common.amp_definitely_not_available():
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        # Per-device dtype validation: on CPU/XPU/HPU an unsupported dtype
        # disables autocast with a warning; on CUDA, unsupported bf16 raises.
        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'hpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            # Only probe device bf16 support while autocast is still enabled, so a
            # context disabled above (no CUDA, or enabled=False) never errors here.
            if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled

    def __enter__(self):
        """Save this device's current autocast state and install this instance's.

        Also increments the autocast nesting counter, which :meth:`__exit__`
        uses to decide when to drop the autocast cache.
        """
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self
        # NOTE(review): the previous state is stored on the instance, so entering
        # the same instance again before exiting it (e.g. a recursive function
        # decorated with it) overwrites the saved state - confirm whether that
        # usage needs supporting.
        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == 'hpu':
            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        """Restore the autocast state saved by :meth:`__enter__`.

        Returns False, so exceptions raised inside the region propagate.
        """
        if torch._jit_internal.is_scripting():
            return
        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        """Use this instance as a decorator: ``func`` runs inside it on every call."""
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)