# sharded_grad_scaler.py

import logging
from collections import abc, defaultdict
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.cuda import FloatTensor  # type: ignore[attr-defined]
from torch.cuda.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState
from torch.distributed.distributed_c10d import ProcessGroup
from torch.optim.sgd import SGD

log = logging.getLogger(__name__)


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}


def _is_supported_device(tensor: torch.Tensor):
    return tensor.is_cuda or tensor.device.type in ("xla", "cpu")

class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
    """
    Lazily serves a copy of the master tensor on each requested device. This class
    extends _MultiDeviceReplicator to also allow "cpu" as a device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert _is_supported_device(master_tensor)
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
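
# Note: the base _MultiDeviceReplicator lazily caches one copy of the master tensor
# per device on first request; the subclass above only relaxes the device check so
# that CPU-offloaded gradients (e.g. FSDP's CPU offload) can reuse the same machinery.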

class ShardedGradScaler(GradScaler):
    """
    ShardedGradScaler helps perform gradient scaling in a shard-aware manner. It extends
    the functionality of GradScaler:
    * Supports PyTorch DDP and FSDP implementations
    * Supports CPU-offloaded tensors (as used in fully sharded data parallel [FSDP])
    * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns
    * Syncs inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across
      nodes

    Example::

        # Creates a ShardedGradScaler once at the beginning of training.
        scaler = ShardedGradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss. Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See :class:`GradScaler` for explanation of scaling/unscaling and more use cases.

    Args:
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
        process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD):
            process group for sharding
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        backoff_factor: float = 0.5,
        growth_factor: float = 2.0,
        growth_interval: int = 2000,
        enabled: bool = True,
        process_group: Optional[ProcessGroup] = dist.group.WORLD,
    ):
        super().__init__(
            init_scale=init_scale,
            backoff_factor=backoff_factor,
            growth_factor=growth_factor,
            growth_interval=growth_interval,
            enabled=enabled,
        )
        if self._enabled:
            self.process_group = process_group
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def scale(
        self, outputs: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        if not self._enabled:
            return outputs

        if isinstance(outputs, torch.Tensor):
            assert _is_supported_device(outputs)
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            scaled_output = outputs * self._scale.to(
                device=outputs.device, non_blocking=True
            )
            # Here we ensure the return dtype is the same as the outputs dtype.
            # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
            # format (fp16, bf16) and so the scaled loss should be of the same dtype.
            return scaled_output.type(outputs.dtype)
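
        # For an iterable of outputs, lazily replicate the scale tensor to each
        # output's device (via _GeneralMultiDeviceReplicator) and scale recursively.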
        stash: List[_GeneralMultiDeviceReplicator] = []

        def apply_scale(
            val: Union[torch.Tensor, abc.Iterable]
        ) -> Union[torch.Tensor, abc.Iterable]:
            if isinstance(val, torch.Tensor):
                assert _is_supported_device(val)
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_GeneralMultiDeviceReplicator(self._scale))
                scaled_val = val * stash[0].get(val.device)
                # Here we ensure the return dtype is the same as the outputs dtype.
                # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision
                # format (fp16, bf16) and so the scaled loss should be of the same dtype.
                return scaled_val.type(val.dtype)
            elif isinstance(val, abc.Iterable):
                iterator = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterator)
                else:
                    return iterator
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)  # type: ignore[return-value]

    def _foreach_non_finite_check_and_unscale_cpu_(
        self, grads: List, found_inf: torch.Tensor, inv_scale: torch.Tensor
    ) -> None:
        if len(grads) == 0:
            return
        assert inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."
        assert found_inf.numel() == 1, "found_inf must be a 1-element tensor."

        expected_device = grads[0].device
        for grad in grads:
            for tensor in grad:
                if tensor.device != expected_device:
                    log.error(
                        "tensor device is %s and expected device is %s",
                        tensor.device,
                        expected_device,
                    )
                    raise ValueError("Gradients must be on the same device.")

                # The check for non_overlapping_and_dense doesn't exist in the Python world,
                # as remarked here https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/AmpKernels.cu#L108
                # We assume the tensor is not MTA (multi tensor apply) safe and iterate
                # through each item regardless of dtype.
                if (
                    torch.isinf(tensor).any().item() is True
                    or torch.isnan(tensor).any().item() is True
                ):
                    found_inf.data = torch.tensor([1.0])
                    break
                else:
                    tensor.data *= inv_scale.item()
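
    # _unscale_grads_ mirrors GradScaler._unscale_grads_, but routes CPU gradients
    # through the Python fallback above instead of the fused
    # torch._amp_foreach_non_finite_check_and_unscale_ kernel used for other devices.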
    def _unscale_grads_(
        self,
        optimizer: SGD,
        inv_scale: torch.Tensor,
        found_inf: torch.Tensor,
        allow_fp16: bool = True,
    ) -> Dict[torch.device, torch.Tensor]:
        per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale)
        per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be thousands of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.
        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # mypy struggles with defaultdict type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            # coalesce() is not supported for torch.float16
                            param_grad_fp32 = param.grad.type(torch.float32).coalesce()
                            param.grad = param_grad_fp32.type(torch.float16)
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    per_device_and_dtype_grads[to_unscale.device][
                        to_unscale.dtype
                    ].append(to_unscale)

            for device, per_dtype_grads in per_device_and_dtype_grads.items():
                for grads in per_dtype_grads.values():
                    if grads[0].device.type == "cpu":
                        self._foreach_non_finite_check_and_unscale_cpu_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
                    else:
                        torch._amp_foreach_non_finite_check_and_unscale_(
                            grads,
                            per_device_found_inf.get(device),
                            per_device_inv_scale.get(device),
                        )
        return per_device_found_inf._per_device_tensors
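
    # unscale_ extends GradScaler.unscale_ by all-reducing the per-device found_inf
    # flags across self.process_group, so that every rank agrees on whether the
    # upcoming optimizer step should be skipped.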
    def unscale_(self, optimizer: SGD) -> None:
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full(
            (1,), 0.0, dtype=torch.float32, device=self._scale.device
        )

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, True
        )
        optimizer_state["stage"] = OptState.UNSCALED

        # Synchronize the detected inf across the ranks.
        future_handles = []
        # CPU found_inf flags are all-reduced on CUDA (backends such as NCCL only support
        # CUDA tensors); the reduced values are copied back to CPU only after wait_all
        # below, so the copy cannot observe a value from before the reduction finished.
        cpu_flags = []

        for v in optimizer_state["found_inf_per_device"].values():
            if v.device.type == "cpu":
                v_on_cuda = v.cuda()
                cpu_flags.append((v, v_on_cuda))
                future_handles.append(
                    dist.all_reduce(
                        v_on_cuda, async_op=True, group=self.process_group
                    ).get_future()
                )
            else:
                future_handles.append(
                    dist.all_reduce(
                        v, async_op=True, group=self.process_group
                    ).get_future()
                )

        # Make sure that the calls are done before moving out.
        if future_handles:
            torch.futures.wait_all(future_handles)
        for v, v_on_cuda in cpu_flags:
            v.copy_(v_on_cuda.cpu())
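
    # step() defers to GradScaler.step(), which calls unscale_() if it has not already
    # been called for this optimizer and skips optimizer.step() when any found_inf
    # flag is nonzero.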
    def step(self, optimizer: SGD, *args, **kwargs) -> Optional[float]:
        return super().step(optimizer, *args, **kwargs)

    def _amp_update_scale_cpu_(self, found_inf) -> None:
        """
        If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
        Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
        """
        if found_inf.item() >= 1.0:
            self._scale *= self._backoff_factor  # type: ignore[arg-type]
            self._growth_tracker = 0
        else:
            successful = self._growth_tracker + 1  # type: ignore[operator]
            if successful == self._growth_interval:  # type: ignore[arg-type]
                self._scale *= self._growth_factor  # type: ignore[arg-type]
                self._growth_tracker = 0
            else:
                self._growth_tracker = successful
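
    # Worked example of the scale-update rule above: with the defaults (init_scale=2.**16,
    # backoff_factor=0.5, growth_factor=2.0, growth_interval=2000), one iteration with an
    # inf/NaN gradient halves the scale to 2.**15, while 2000 consecutive clean iterations
    # double it and reset the growth tracker.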
    def update(self, new_scale: Optional[Union[float, FloatTensor]] = None) -> None:
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
        used directly; it's used to fill GradScaler's internal scale tensor. So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
            been invoked for all optimizers used this iteration.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")  # type: ignore[var-annotated]

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]
            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            if _scale.device.type == "cpu":
                self._amp_update_scale_cpu_(found_inf_combined)
            else:
                torch._amp_update_scale_(
                    self._scale,  # type: ignore[arg-type]
                    self._growth_tracker,  # type: ignore[arg-type]
                    found_inf_combined,
                    self._growth_factor,  # type: ignore[arg-type]
                    self._backoff_factor,  # type: ignore[arg-type]
                    self._growth_interval,  # type: ignore[arg-type]
                )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
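
# A minimal FSDP usage sketch (illustrative only; it assumes torch.distributed is
# already initialized and that `model`, `loss_fn`, and an iterable `data` of
# (input, target) batches exist -- those names are placeholders, not part of this module):
#
#     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#     fsdp_model = FSDP(model.cuda())
#     optimizer = torch.optim.SGD(fsdp_model.parameters(), lr=1e-3)
#     scaler = ShardedGradScaler()
#
#     for input, target in data:
#         optimizer.zero_grad()
#         with torch.autocast("cuda", dtype=torch.float16):
#             loss = loss_fn(fsdp_model(input), target)
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.update()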