import torch
import functools
import warnings
from typing import Any, Optional
from torch.types import _dtype

__all__ = ['autocast_decorator', 'autocast']


def autocast_decorator(autocast_instance, func):
    @functools.wraps(func)
    def decorate_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)
    decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode'  # type: ignore[attr-defined]
    return decorate_autocast
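
# Usage sketch (illustrative only; ``run_step``, ``model`` and ``loss_fn`` are placeholder names).
# ``autocast.__call__`` delegates to ``autocast_decorator``, so
#
#     @torch.autocast(device_type="cuda")
#     def run_step(input, target):
#         return loss_fn(model(input), target)
#
# behaves like wrapping the body of ``run_step`` in ``with torch.autocast(device_type="cuda"):``.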


class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s). Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for the corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @autocast()
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or another dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op
    and incurs no additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with autocast():
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use it with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)

            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest disabling the JIT autocast pass;
        # see https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)

        # Runs the model
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``. Disabling autocast gives you explicit control over
    the execution type. In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with autocast():
            e_float16 = torch.mm(a_float32, b_float32)
            with autocast(enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
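
    One option in that case (a minimal sketch; ``MyModel`` is a placeholder name, not part of this
    module) is to apply autocast on the model's ``forward`` method, so that it is enabled in each
    side thread that runs the forward pass::

        class MyModel(nn.Module):
            ...
            @autocast()
            def forward(self, input):
                ...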

    Args:
        device_type(str, required):  Device type to use. Possible values are 'cuda', 'cpu', 'xpu' and 'hpu'.
        enabled(bool, optional):  Whether autocasting should be enabled in the region.
            Default: ``True``
        dtype(torch_dtype, optional):  Data type for ops to run in, e.g. ``torch.float16`` or
            ``torch.bfloat16``. If not given, the current default autocast dtype for ``device_type`` is used.
        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.
            Default: ``True``
    """

    def __init__(self, device_type: str,
                 dtype: Optional[_dtype] = None,
                 enabled: bool = True,
                 cache_enabled: Optional[bool] = None):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return
        self.device = device_type
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError("User specified autocast device_type must be 'cuda', 'cpu', 'xpu' or 'hpu'")
        self._cache_enabled = torch.is_autocast_cache_enabled()
        if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
            warnings.warn("User provided device_type of 'cuda', but CUDA is not available. Disabling")
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'hpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == 'hpu':
            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)
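

# Minimal nesting sketch (illustrative only; assumes a CUDA build with autocast available).
# It shows how ``__enter__`` saves the surrounding autocast state and ``__exit__`` restores it,
# so an inner ``enabled=False`` region does not leak into the outer one:
#
#     with torch.autocast(device_type="cuda", dtype=torch.float16):
#         with torch.autocast(device_type="cuda", enabled=False):
#             ...  # autocast is disabled only inside this block
#         ...      # the outer float16 autocast state is restored here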
|