import logging
from collections import abc, defaultdict
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.cuda import FloatTensor
from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup
from torch.optim.sgd import SGD

log = logging.getLogger(__name__)


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}


def _is_supported_device(tensor: torch.Tensor):
    return tensor.is_cuda or tensor.device.type in ("xla", "cpu")


class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
    """
    Lazily serves copies of a master tensor to each requested device. This class extends
    _MultiDeviceReplicator to allow support for "cpu" as a device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert _is_supported_device(master_tensor)
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}


class ShardedGradScaler(GradScaler):
    """
    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
    functionality from GradScaler:

    * Supports PyTorch DDP and FSDP implementations
    * Supports CPU offloaded tensors (as used in fully sharded data parallel [FSDP])
    * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns
    * Syncs inf/nan for scaled gradient tensors on any torch.device (where tensors are placed)
      across nodes

    Example::

        # Creates a ShardedGradScaler once at the beginning of training.
        scaler = ShardedGradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()
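
    A minimal sketch of gradient clipping on unscaled gradients (assuming ``model``,
    ``optimizer``, and a clipping threshold ``max_norm`` are defined as in the loop above)::

        scaler.scale(loss).backward()
        # Unscale in-place so clipping sees true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # step() notices unscale_() was already called and skips the internal unscale.
        scaler.step(optimizer)
        scaler.update()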

    See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
            process group for sharding
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        backoff_factor: float = 0.5,
        growth_factor: float = 2.0,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Optional[ProcessGroup] = dist.group.WORLD,
    ):
        super().__init__(
            init_scale=init_scale,
            backoff_factor=backoff_factor,
            growth_factor=growth_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        if self._enabled:
            self.process_group = process_group
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def scale(
        self, outputs: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        if not self._enabled:
            return outputs

        if isinstance(outputs, torch.Tensor):
            assert _is_supported_device(outputs)
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            scaled_output = outputs * self._scale.to(
                device=outputs.device, non_blocking=True
            )
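            # Return in the dtype of `outputs`: with FSDP mixed precision the loss may be
            # fp16/bf16, and the scaled loss should keep that dtype.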
            return scaled_output.type(outputs.dtype)

        stash: List[_GeneralMultiDeviceReplicator] = []

        def apply_scale(
            val: Union[torch.Tensor, abc.Iterable]
        ) -> Union[torch.Tensor, abc.Iterable]:
            if isinstance(val, torch.Tensor):
                assert _is_supported_device(val)
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                scaled_val = val * stash[0].get(val.device)
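                # As above, return in the original dtype of `val`.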
                return scaled_val.type(val.dtype)
            elif isinstance(val, abc.Iterable):
                iterator = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterator)
                else:
                    return iterator
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

    def _foreach_non_finite_check_and_unscale_cpu_(
        self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
    ) -> None:
        if len(grads) == 0:
            return
        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."

        expected_device = grads[0].device
        for grad in grads:
            for tensor in grad:
                if tensor.device != expected_device:
                    log.error(
                        "tensor device is %s and expected device is %s",
                        tensor.device,
                        expected_device,
                    )
                    raise ValueError("Gradients must be on the same device.")
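
                # These tensors may not be safe for a fused multi-tensor kernel, so check
                # for inf/nan and unscale each gradient tensor individually on CPU.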
                if (
                    torch.isinf(tensor).any().item() is True
                    or torch.isnan(tensor).any().item() is True
                ):
                    found_inf.data = torch.tensor([1.0])
                    break
                else:
                    tensor.data *= inv_scale.item()

    def _unscale_grads_(
        self,
        optimizer: SGD,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool = True,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)
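
        # To set up _amp_foreach_non_finite_check_and_unscale_, bucket the gradients by
        # device and dtype in a single pass, since their devices/dtypes are not known up front.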
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
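                        # Coalescing a sparse grad sums values that share an index; for
                        # scaled fp16 values that sum can overflow, so the inf/nan check
                        # is done on the coalesced _values().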
                        if param.grad.dtype is torch.float16:
                            # coalesce() is not supported for fp16, so upcast to fp32 first.
                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
                            param.grad = param_grad_fp32.type(torch.float16)
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    if grads[0].device.type == "cpu":
                        self._foreach_non_finite_check_and_unscale_cpu_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
                    else:
                        torch._amp_foreach_non_finite_check_and_unscale_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
        return per_device_found_inf._per_device_tensors

    def unscale_(self, optimizer: SGD) -> None:
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")
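
        # FP32 division can be imprecise, so compute the reciprocal of the scale in FP64.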
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device
        )

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, True
        )
        optimizer_state["stage"] = OptState.UNSCALED
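
        # Synchronize the found_inf flags across ranks so every shard agrees on whether
        # this iteration's step should be skipped.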
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        future_handles = []
        # (cpu_tensor, cuda_staging_copy) pairs whose reduced values must be copied back.
        cpu_staging = []

        for v in optimizer_state["found_inf_per_device"].values():
            if v.device.type == "cpu":
                # The collective cannot run on this CPU tensor directly, so stage it on GPU.
                v_on_cuda = v.cuda()
                cpu_staging.append((v, v_on_cuda))
                future_handles.append(
                    dist.all_reduce(
                        v_on_cuda, async_op=True, group=self.process_group
                    ).get_future()
                )
            else:
                future_handles.append(
                    dist.all_reduce(
                        v, async_op=True, group=self.process_group
                    ).get_future()
                )

        # Make sure the async calls are done before moving on, then copy the reduced
        # values back to the CPU tensors they were staged from.
        if future_handles:
            torch.futures.wait_all(future_handles)
        for v, v_on_cuda in cpu_staging:
            v.copy_(v_on_cuda.cpu())

    def step(self, optimizer: SGD, *args, **kwargs) -> Optional[float]:
        return super().step(optimizer, *args, **kwargs)

    def _amp_update_scale_cpu_(self, found_inf) -> None:
        """
        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
        """
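        # With the defaults (backoff_factor=0.5, growth_factor=2.0, growth_interval=2000):
        # a detected inf/NaN halves the scale immediately, while 2000 consecutive clean
        # iterations double it.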
        if found_inf.item() >= 1.0:
            self._scale *= self._backoff_factor
            self._growth_tracker = 0
        else:
            successful = self._growth_tracker + 1
            if successful == self._growth_interval:
                self._scale *= self._growth_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker = successful

    def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it is used to fill GradScaler's internal scale tensor, so if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
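
        A minimal sketch of both call forms, assuming a training loop like the class-level
        example::

            scaler.update()                    # derive the next scale from this iteration's inf checks
            scaler.update(new_scale=2.0**10)   # or set the scale explicitly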
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)
        else:
            # Consume the per-optimizer inf/nan flags recorded by unscale_() to decide
            # how to update the scale.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            if _scale.device.type == "cpu":
                self._amp_update_scale_cpu_(found_inf_combined)
            else:
                torch._amp_update_scale_(
                    self._scale,
                    self._growth_tracker,
                    found_inf_combined,
                    self._growth_factor,
                    self._backoff_factor,
                    self._growth_interval,
                )

        # To prepare for the next iteration, clear the per-optimizer state collected
        # during this one.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)