- # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
- #
- # This source code is licensed under the BSD license found in the
- # LICENSE file in the root directory of this source tree.
- import collections
- import copy
- import enum
- import inspect
- import io
- import logging
- from itertools import chain
- from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
- import torch
- import torch.distributed as dist
- from torch.distributed.algorithms.join import Join, Joinable, JoinHook
- from torch.distributed.optim.utils import functional_optim_map
- from torch.optim import Optimizer
- logger = logging.getLogger(__name__)
- __all__ = ["ZeroRedundancyOptimizer"]
- # Credits: classy_vision/generic/distributed_util.py
- def _recursive_copy_to_device(
- value: Any,
- non_blocking: bool,
- device: torch.device,
- ) -> Any:
- r"""
- Recursively searches lists, tuples, dicts and copies tensors to device if
- possible. Non-tensor values are passed as-is in the result.
- .. note:: These are all copies, so if there are two objects that reference
- the same object, then after this call, there will be two different objects
- referenced on the device.
- """
- if isinstance(value, torch.Tensor):
- return value.to(device, non_blocking=non_blocking)
- if isinstance(value, (list, tuple)):
- values = [
- _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
- for val in value
- ]
- return values if isinstance(value, list) else tuple(values)
- if isinstance(value, collections.abc.Mapping):
- return {
- key: _recursive_copy_to_device(
- val, non_blocking=non_blocking, device=device
- )
- for key, val in value.items()
- }
- return value
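- # A minimal usage sketch (illustrative values, not part of this module): the
- # helper walks nested containers such as optimizer state and copies only the
- # tensors it finds, leaving non-tensor values untouched.
- #   >>> state = {"step": 1, "exp_avg": torch.randn(4), "nested": [torch.ones(2)]}
- #   >>> cpu_state = _recursive_copy_to_device(
- #   >>>     state, non_blocking=False, device=torch.device("cpu")
- #   >>> )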
- def _is_trainable(param: torch.Tensor) -> bool:
- r"""
- Returns whether a parameter is trainable, where trainability is equivalent to
- requiring a gradient.
- """
- return param.requires_grad
- def _broadcast_object(
- obj: Any,
- src_rank: int,
- group: object = dist.group.WORLD,
- device: torch.device = torch.device("cpu"),
- ) -> Any:
- r"""
- Broadcasts an object to the given group, sending the object if called from
- the source rank and receiving the object otherwise.
- Arguments:
- obj: object to broadcast; only used if called on the source rank.
- src_rank (int): source rank.
- group (``ProcessGroup``, optional): group used for the broadcast
- (default: ``dist.group.WORLD``).
- device (``torch.device``, optional): device to send from or receive
- to (default: ``torch.device("cpu")``).
- Returns:
- The broadcasted object.
- """
- if dist.get_rank() == src_rank:
- # Send the object
- buffer = io.BytesIO()
- torch.save(obj, buffer)
- data = bytearray(buffer.getbuffer())
- length_tensor = torch.LongTensor([len(data)]).to(device)
- data_send_tensor = torch.ByteTensor(data).to(device)
- dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
- dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
- else:
- # Receive the object
- length_tensor = torch.LongTensor([0]).to(device)
- dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
- data_recv_tensor = torch.empty(
- [int(length_tensor.item())], dtype=torch.uint8, device=device
- )
- dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
- buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
- obj = torch.load(buffer, map_location=device)
- return obj
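- # Hedged usage sketch (assumes an initialized default process group): the
- # source rank serializes an arbitrary picklable object via `torch.save`, and
- # every other rank receives it; non-source ranks still pass a placeholder
- # `obj`, which is ignored.
- #   >>> # xdoctest: +SKIP
- #   >>> payload = {"lr": 0.01} if dist.get_rank() == 0 else None
- #   >>> payload = _broadcast_object(payload, src_rank=0)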
- class _ZeROJoinHook(JoinHook):
- def __init__(self, zero):
- assert isinstance(zero, ZeroRedundancyOptimizer), (
- "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
- "instance as the state"
- )
- self.zero = zero
- super().__init__()
- def main_hook(self):
- """
- Performs an optimizer step, which updates the joined process's shard of
- the parameters and broadcasts those parameters.
- """
- self.zero.step()
- class _DDPBucketAssignment:
- r"""
- This represents a :class:`DistributedDataParallel` bucket assignment,
- meaning a (possibly non-strict) subset of the parameters corresponding to
- a DDP bucket assigned to a rank to update.
- Attributes:
- bucket_index (int): index of the bucket determined by the DDP gradient
- bucket all-reduce order.
- parameters (List[torch.Tensor]): model parameters in the bucket
- assigned to this rank.
- offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
- giving the index of the first element in the passed-in
- ``parameters``; this equivalently indexes into the
- :class:`GradBucket` 's :meth:`gradients`.
- device (torch.device): device on which the parameters are stored.
- tensor (torch.Tensor): flattened tensor giving the data of the
- parameter subset assigned to the rank.
- """
- def __init__(
- self,
- bucket_index: int,
- parameters: List[torch.Tensor],
- offset: int,
- ):
- self.bucket_index = bucket_index
- self.parameters = parameters
- self.offset = offset
- if len(self.parameters) == 0:
- raise ValueError("Empty bucket assignment")
- # DDP guarantees all parameters in the bucket have the same device
- self.device: torch.device = self.parameters[0].device
- self.tensor: Optional[torch.Tensor] = None
- class _OverlapStatus(enum.IntEnum):
- r"""
- This defines the three possible statuses that
- :class:`ZeroRedundancyOptimizer` can be in when overlapping with
- :class:`DistributedDataParallel`.
- ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
- is waiting for DDP to finalize its bucketing.
- ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
- its bucketing is finalized. The ZeRO instance can now collect the
- necessary information about the DDP bucketing.
- ``INITIALIZED``: The ZeRO instance is fully initialized and can now
- optimize parameters.
- """
- UNINITIALIZED = 0
- DDP_HAS_REBUILT_BUCKETS = 1
- INITIALIZED = 2
- class _OverlapInfo:
- r"""
- This contains the information needed by :class:`ZeroRedundancyOptimizer`
- to overlap with :class:`DistributedDataParallel`.
- Arguments:
- world_size (int): world size of the process group being used.
- Attributes:
- shard_buckets (bool): if ``True``, then the assignment of each
- :class:`DistributedDataParallel` bucket is partitioned across
- possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
- across possibly multiple ranks) to approximate uniformity following
- a threshold given by the total parameter size divided by the world
- size; if ``False``, then each bucket is wholly assigned to a single
- :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
- this should be set to the value passed into the hook constructor.
- status (_OverlapStatus): current status; see :class:`_OverlapStatus`
- for more information.
- params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
- gives the model parameters in the ``i``th bucket.
- params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
- gives the model parameters assigned to the ``i``th rank, where the
- parameters are grouped by increasing bucket indices.
- offsets (Dict[int, int]): maps from bucket index to the offset in
- ``self.params_per_rank[rank]`` giving the index of the first
- parameter in that bucket, where ``rank`` is this process's own
- rank; the keys of this :class:`dict` are the bucket indices
- assigned to this rank.
- num_bucket_assignments (int): total number of bucket assignments across
- all ranks; this is equal to the number of
- :class:`DistributedDataParallel` gradient buckets if
- ``shard_buckets=False`` and possibly greater otherwise.
- total_size (int, optional): total size of all buckets (i.e. sum of
- ``param.numel()`` for all ``param`` across all buckets) if
- ``shard_buckets=True``; otherwise, ``None``.
- broadcast_handles (List[Work]): :class:`list` of async work handles for
- the parameter broadcasts.
- bucket_index_to_future (Dict[int, torch.futures.Future]):
- :class:`dict` mapping bucket index to the corresponding all-reduce
- future.
- bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
- mapping bucket index to the corresponding bucket.
- bucket_indices_seen (List[int]): :class:`list` of the bucket indices
- seen on this iteration.
- """
- def __init__(self, world_size: int) -> None:
- self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
- self.shard_buckets: bool = False
- # Modified per bucket reconstruction
- self.params_per_bucket: List[List[torch.Tensor]] = []
- self.params_per_rank: List[List[torch.Tensor]] = [[] for _ in range(world_size)]
- self.offsets: Dict[int, int] = {}
- # Group Ranks
- self.assigned_ranks_per_bucket: List[Set[int]] = []
- self.num_bucket_assignments: int = 0
- self.total_size: Optional[int] = None
- # Modified per iteration
- self.broadcast_handles: List[Any] = []
- self.bucket_indices_seen: List[int] = []
- # Used by `hook_with_zero_step()`
- self.bucket_index_to_future: Dict[int, torch.futures.Future] = {}
- self.bucket_index_to_bucket: Dict[int, dist.GradBucket] = {}
- def wait_for_broadcasts(self) -> None:
- r"""
- Waits for all parameter broadcasts. This should be called once all
- broadcasts have been scheduled, meaning ``self.broadcast_handles`` is
- filled. This clears ``self.broadcast_handles`` in preparation for the
- next iteration.
- """
- assert (
- len(self.broadcast_handles) == self.num_bucket_assignments
- ), f"Missing at least one broadcast handle on rank {dist.get_rank()}"
- _ = list(map(lambda x: x.wait(), self.broadcast_handles))
- self.broadcast_handles.clear()
- def clear_per_iter_info(self) -> None:
- r"""
- Clears the data structures that are modified per-iteration. This should
- be called at the end of an iteration.
- """
- self.bucket_indices_seen.clear()
- self.bucket_index_to_future.clear()
- self.bucket_index_to_bucket.clear()
- class ZeroRedundancyOptimizer(Optimizer, Joinable):
- r"""
- This class wraps an arbitrary :class:`optim.Optimizer
- <torch.optim.Optimizer>` and shards its states across ranks in the group as
- described by ZeRO_. The local optimizer instance in each rank is only
- responsible for updating approximately ``1 / world_size`` parameters and
- hence only needs to keep ``1 / world_size`` optimizer states. After
- parameters are updated locally, each rank will broadcast its parameters to
- all other peers to keep all model replicas in the same state.
- ``ZeroRedundancyOptimizer`` can be used in conjunction with
- :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
- memory consumption.
- ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
- of parameters at each rank. Each parameter belongs to a single rank and is
- not divided among ranks. The partition is arbitrary and might not match
- the parameter registration or usage order.
- Arguments:
- params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
- or :class:`dict` s giving all parameters, which will be sharded
- across ranks.
- Keyword Args:
- optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
- optimizer.
- process_group (``ProcessGroup``, optional): ``torch.distributed``
- ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
- :meth:`torch.distributed.init_process_group`).
- parameters_as_bucket_view (bool, optional): if ``True``, parameters are
- packed into buckets to speed up communication, and ``param.data``
- fields point to bucket views at different offsets; if ``False``,
- each individual parameter is communicated separately, and each
- ``param.data`` stays intact (default: ``False``).
- overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
- overlapped with :class:`DistributedDataParallel` 's gradient
- synchronization; this requires (1) either a functional optimizer
- for the ``optimizer_class`` argument or one with a functional
- equivalent and (2) registering a DDP communication hook
- constructed from one of the functions in ``ddp_zero_hook.py``;
- parameters are packed into buckets matching those in
- :class:`DistributedDataParallel`, meaning that the
- ``parameters_as_bucket_view`` argument is ignored.
- If ``False``, :meth:`step` runs disjointly after the backward pass
- (per normal).
- (default: ``False``)
- **defaults: any trailing arguments, which are forwarded to the local
- optimizer.
- Example::
- >>> # xdoctest: +SKIP
- >>> import torch.nn as nn
- >>> from torch.distributed.optim import ZeroRedundancyOptimizer
- >>> from torch.nn.parallel import DistributedDataParallel as DDP
- >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
- >>> ddp = DDP(model, device_ids=[rank])
- >>> opt = ZeroRedundancyOptimizer(
- >>> ddp.parameters(),
- >>> optimizer_class=torch.optim.Adam,
- >>> lr=0.01
- >>> )
- >>> ddp(inputs).sum().backward()
- >>> opt.step()
- .. warning::
- Currently, ``ZeroRedundancyOptimizer`` requires that all of the
- passed-in parameters are the same dense type.
- .. warning::
- If you pass ``overlap_with_ddp=True``, be wary of the following: Given
- the way that overlapping :class:`DistributedDataParallel` with
- :class:`ZeroRedundancyOptimizer` is currently implemented, the first
- two or three training iterations do not perform parameter updates in
- the optimizer step, depending on whether ``static_graph=False`` or
- ``static_graph=True``, respectively. This is because it needs
- information about the gradient bucketing strategy used by
- :class:`DistributedDataParallel`, which is not finalized until the
- second forward pass if ``static_graph=False`` or until the third
- forward pass if ``static_graph=True``. To adjust for this, one option
- is to prepend dummy inputs.
- .. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
- .. _ZeRO: https://arxiv.org/abs/1910.02054
- """
- def __init__(
- self,
- params,
- optimizer_class: Type[Optimizer],
- process_group: Optional[Any] = None,
- parameters_as_bucket_view: bool = False,
- overlap_with_ddp: bool = False,
- **defaults: Any,
- ):
- # Perform type and assumption checks on the input parameters
- params = self._verify_and_init_params(params)
- self._verify_same_dense_param_type()
- # NOTE: The parent constructor uses `add_param_group()` which is
- # partially overloaded in ZeroRedundancyOptimizer, so we use the
- # `initialized` flag to dissociate the behaviour of `add_param_group()`
- # between the parent and child.
- self.initialized = False
- Optimizer.__init__(self, params, defaults)
- Joinable.__init__(self)
- # Now, all parameters are held in both `self._all_params` and
- # `self.param_groups`
- # Internal data structures (`_cache` indicates lazily evaluated)
- self._param_to_rank_cache: Dict[torch.Tensor, int] = {}
- self._param_to_index_cache: Dict[torch.Tensor, int] = {}
- self._partition_parameters_cache: List[List[Dict]] = []
- self._index_to_param_cache: List[torch.Tensor] = []
- self._device_to_params_per_rank_cache: Dict[
- torch.device, List[List[torch.Tensor]]
- ] = {}
- self._bucket_assignments_per_rank_cache: List[
- Dict[int, _DDPBucketAssignment]
- ] = []
- self._is_trainable_mask = self._get_is_trainable_mask()
- # Default device for collective communication and buckets
- self._default_device = self._all_params[0].device
- self.process_group = (
- process_group if process_group is not None else dist.group.WORLD
- )
- self.world_size: int = dist.get_world_size(self.process_group)
- self.rank: int = dist.get_rank(self.process_group)
- self.global_rank: int = dist.distributed_c10d.get_global_rank(
- self.process_group, self.rank
- )
- self._overlap_with_ddp: bool = overlap_with_ddp
- self._optim_defaults = defaults
- self._optim_constructor = self._get_optimizer_constructor(optimizer_class)
- # If `overlap_with_ddp=True`, local optimizer initialization is delayed
- # to run time after the necessary information has been collected
- if not overlap_with_ddp:
- self._init_local_optimizer()
- else:
- self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
- if parameters_as_bucket_view:
- logger.warning(
- "`parameters_as_bucket_view=True` will be ignored since "
- "`overlap_with_ddp=True`; instead, a different bucketing "
- "strategy will be used"
- )
- # `self._buckets` is used if `parameters_as_bucket_view=True`, in
- # which case parameter data is flattened into contiguous bucket tensors
- self.parameters_as_bucket_view = parameters_as_bucket_view
- self._buckets: List[List[torch.Tensor]] = []
- self._build_param_buckets()
- # Optional consolidated optimizer state, only populated if this rank
- # is the target in `consolidate_state_dict()`
- self._all_state_dicts: List[Dict[str, Any]] = []
- self.initialized = True
- def _clear_cache(self) -> None:
- r"""
- Clears the cached data structures giving partition information.
- """
- self._partition_parameters_cache.clear()
- self._param_to_rank_cache.clear()
- self._index_to_param_cache.clear()
- self._param_to_index_cache.clear()
- self._device_to_params_per_rank_cache.clear()
- self._bucket_assignments_per_rank_cache.clear()
- def add_param_group(self, param_group: dict) -> None:
- r"""
- Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
- This can be useful when fine tuning a pre-trained network, as frozen
- layers can be made trainable and added to the :class:`Optimizer` as
- training progresses.
- Arguments:
- param_group (dict): specifies the parameters to be optimized and
- group-specific optimization options.
- .. warning:: This method handles updating the shards on all partitions
- but needs to be called on all ranks. Calling this on a subset of
- the ranks will cause the training to hang because communication
- primitives are called depending on the managed parameters and
- expect all the ranks to participate on the same set of parameters.
- """
- if self.initialized and self._overlap_with_ddp:
- raise RuntimeError(
- "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
- "supports a single parameter group"
- )
- super().add_param_group(param_group)
- # NOTE: The rest of the method assumes that the call to the parent's
- # `add_param_group()` appends the new parameter group and preserves
- # the previous parameter-group ordering
- if self.initialized:
- # Force a re-partitioning of the parameters
- self._clear_cache()
- param_groups = self._partition_parameters()[self.rank]
- # NOTE: All parameters in the old parameter groups should be
- # assigned to the same ranks so that the local optimizers do not
- # need to be reinitialized
- # Add the parameters assigned to this rank from the new parameter
- # group to the local optimizer, if any
- if len(param_groups) == len(self.optim.param_groups) + 1:
- self.optim.add_param_group(param_groups[-1])
- # Update the bucketing strategy accordingly
- if self.parameters_as_bucket_view:
- self._build_param_buckets()
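- # Sketch of unfreezing a layer mid-training (assumes `model.head` names a
- # module whose parameters were not passed to the constructor; that name is an
- # assumption, not part of this API). The call must happen on every rank.
- #   >>> # xdoctest: +SKIP
- #   >>> for p in model.head.parameters():
- #   >>>     p.requires_grad_(True)
- #   >>> opt.add_param_group({"params": model.head.parameters(), "lr": 1e-3})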
- def consolidate_state_dict(self, to: int = 0) -> None:
- r"""
- Consolidate a list of ``state_dict`` s (one per rank) on the target
- rank.
- Arguments:
- to (int): the rank that receives the optimizer states (default: 0).
- Raises:
- RuntimeError: if ``overlap_with_ddp=True`` and this method is
- called before this :class:`ZeroRedundancyOptimizer` instance
- has been fully initialized, which happens once
- :class:`DistributedDataParallel` gradient buckets have been
- rebuilt.
- .. warning:: This needs to be called on all ranks.
- """
- self._check_overlap_initialized()
- # Sync the exposed `param_groups` attributes to the local optimizer in
- # case they have been updated
- self._sync_param_groups(self.param_groups, self.optim.param_groups)
- # Pull the sharded state from all ranks and store them in rank order
- empty_messenger = torch.tensor(
- [0], dtype=torch.uint8, device=self._default_device
- )
- # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
- # due to compatibility issues with NCCL backend; a possible follow-up
- # is to move all sharded state management to RPC RRef
- self._all_state_dicts = []
- for rank in range(self.world_size):
- global_rank = dist.distributed_c10d.get_global_rank(
- self.process_group, rank
- )
- if self.rank == to:
- # Consolidate all local `state_dict`s on this rank, storing on
- # CPU to save GPU memory
- if rank == self.rank:
- # Directly append own optimizer state
- self._all_state_dicts.append(
- _recursive_copy_to_device(
- self.optim.state_dict(),
- non_blocking=True,
- device=torch.device("cpu"),
- )
- )
- else:
- # Receive the optimizer state from the source rank
- local_state_dict = _broadcast_object(
- empty_messenger,
- src_rank=global_rank,
- group=self.process_group,
- device=self._default_device,
- )
- self._all_state_dicts.append(
- _recursive_copy_to_device(
- local_state_dict,
- non_blocking=True,
- device=torch.device("cpu"),
- )
- )
- else:
- if rank == self.rank:
- # Send the optimizer state to the target rank
- _ = _broadcast_object(
- self.optim.state_dict(),
- src_rank=self.global_rank,
- group=self.process_group,
- device=self._default_device,
- )
- elif rank != to:
- # Discard the received object; `broadcast()` is used for
- # compatibility reasons
- _ = _broadcast_object(
- empty_messenger,
- src_rank=global_rank,
- group=self.process_group,
- device=self._default_device,
- )
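- # Typical checkpointing pattern (a sketch; reuses the `opt` instance from the
- # class-level example and assumes an initialized process group): every rank
- # must call `consolidate_state_dict()`, but only the target rank should call
- # `state_dict()` and save the result.
- #   >>> # xdoctest: +SKIP
- #   >>> opt.consolidate_state_dict(to=0)
- #   >>> if dist.get_rank() == 0:
- #   >>>     torch.save(opt.state_dict(), "zero_checkpoint.pt")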
- def _verify_params_per_rank(
- self,
- params_per_rank: List[List[torch.Tensor]],
- ) -> None:
- r"""
- Verifies ``params_per_rank`` for :meth:`_partition_parameters`,
- checking that ``params_per_rank`` has length equal to the world size
- and that it does not contain any parameters not passed into the
- :class:`ZeroRedundancyOptimizer` constructor.
- The parameters in ``params_per_rank`` being a strict subset of those
- passed into the constructor is valid since some parameters may be
- frozen.
- Raises:
- ValueError: if ``params_per_rank`` does not have length equal to
- the world size or if it contains a parameter that was not
- passed into the :class:`ZeroRedundancyOptimizer` constructor.
- """
- if len(params_per_rank) != self.world_size:
- raise ValueError(
- "`params_per_rank` must have length equal to the world size"
- )
- all_params_set = set(self._all_params)
- for params in params_per_rank:
- for param in params:
- if param not in all_params_set:
- raise ValueError(
- "Passing a new parameter in `params_per_rank` that "
- "was not passed into the ZeroRedundancyOptimizer "
- "constructor"
- )
- def _partition_param_group(
- self, param_group: Dict[str, Any], params_per_rank: List[List[torch.Tensor]]
- ) -> None:
- r"""
- Partitions the parameter group ``param_group`` according to
- ``params_per_rank`` by modifying ``self._partition_parameters_cache``.
- This method should only be used as a subroutine for
- :meth:`_partition_parameters`.
- Arguments:
- param_group (dict[str, Any]): a parameter group as normally defined
- in an optimizer state.
- params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
- length world size containing :class:`list` s of parameters to
- assign to each rank.
- """
- for rank, params in enumerate(params_per_rank):
- rank_param_group = copy.copy(param_group)
- rank_param_group["params"] = params
- self._partition_parameters_cache[rank].append(rank_param_group)
- def _partition_parameters(
- self,
- params_per_rank: Optional[List[List[torch.Tensor]]] = None,
- ) -> List[List[Dict]]:
- r"""
- Partitions parameters across distributed data parallel ranks.
- Arguments:
- params_per_rank (list[list[torch.Tensor]], optional): a
- :class:`list` of length world size containing :class:`list` s
- of parameters to assign to each rank; this provides a way to
- specify a partition manually.
- If ``None``, the parameters are partitioned according to an
- internal algorithm.
- (default: ``None``)
- Returns:
- A :class:`list` where each element of the list contains the
- ``param_groups`` for a rank (which itself is a :class:`list` of
- :class:`dict`); element 0 corresponds to rank 0, etc.; each rank
- stores the ``param_groups`` for all ranks for the collective
- communication in :meth:`step`.
- Raises:
- ValueError: see :meth:`_verify_params_per_rank`.
- RuntimeError: if ``params_per_rank`` is not ``None`` and this
- :class:`ZeroRedundancyOptimizer` instance is using more than
- one parameter group.
- """
- if params_per_rank is None:
- # Partition the parameters optimizing for uniformity
- if len(self._partition_parameters_cache) == 0:
- self._partition_parameters_cache = [[] for _ in range(self.world_size)]
- sizes = [0] * self.world_size
- for param_group in self.param_groups:
- param_group_params_per_rank: List[List] = [
- [] for _ in range(self.world_size)
- ]
- # Sort the parameters by size (largest first)
- params_sorted = sorted(
- param_group["params"], key=lambda t: t.numel(), reverse=True
- )
- for param in params_sorted:
- # Greedily add the parameter to rank with smallest size so far
- rank = self._get_min_index(sizes)
- param_group_params_per_rank[rank].append(param)
- sizes[rank] += param.numel()
- # Apply the constructed partition of the parameter group
- self._partition_param_group(
- param_group, param_group_params_per_rank
- )
- return self._partition_parameters_cache
- # Partition the parameters according to `params_per_rank`
- assert len(self._partition_parameters_cache) == 0, (
- "Specifying `params_per_rank` should only be done when the "
- "parameters have not been partitioned yet"
- )
- if len(self.param_groups) != 1:
- raise RuntimeError(
- "Specifying `params_per_rank` only supports a single " "parameter group"
- )
- self._verify_params_per_rank(params_per_rank)
- self._partition_parameters_cache = [[] for _ in range(self.world_size)]
- # Apply the passed-in partition of the parameter group
- param_group = self.param_groups[0]
- self._partition_param_group(param_group, params_per_rank)
- return self._partition_parameters_cache
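- # Worked illustration of the sorted-greedy packing above (illustrative sizes
- # only): with a world size of 2 and parameter sizes [6, 5, 4, 1], the sizes
- # are visited in descending order and each parameter goes to the currently
- # smallest rank: 6 -> rank 0, 5 -> rank 1, 4 -> rank 1, 1 -> rank 0, for
- # per-rank totals of 7 and 9.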
- @property
- def _param_to_rank(self) -> Dict[torch.Tensor, int]:
- r"""
- :class:`dict` mapping parameters to their assigned data parallel rank
- in the partition.
- """
- if len(self._param_to_rank_cache) == 0:
- for rank, param_groups in enumerate(self._partition_parameters()):
- for param_group in param_groups:
- for param in param_group["params"]:
- self._param_to_rank_cache[param] = rank
- return self._param_to_rank_cache
- @property
- def _param_to_index(self) -> Dict[torch.Tensor, int]:
- r"""
- :class:`dict` mapping parameters to their indices in the global
- optimizer state.
- NOTE: This assumes that the global optimizer state's indexing (in
- ``state_dict``) follows a linear ordering over the parameter groups.
- """
- if len(self._param_to_index_cache) == 0:
- self._param_to_index_cache = {
- p: i
- for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
- }
- return self._param_to_index_cache
- @property
- def _index_to_param(self) -> List[torch.Tensor]:
- r"""
- List mapping parameter indices in the global optimizer scheme to the
- actual params.
- """
- if len(self._index_to_param_cache) == 0:
- self._index_to_param_cache = list(
- chain(*(g["params"] for g in self.param_groups))
- )
- return self._index_to_param_cache
- def _broadcast_params_from_rank(self, rank: int):
- r"""
- Broadcasts the shard of parameters from a given rank to all other
- ranks asynchronously.
- Arguments:
- rank (int): the source rank.
- Returns:
- A :class:`list` of async work handles for the ``broadcast()`` s
- performed to synchronize the parameters.
- """
- assert not self._overlap_with_ddp, (
- "`_broadcast_params_from_rank()` should not be used if "
- "`overlap_with_ddp=True`; instead, the broadcasting should "
- "happen in the DDP communication hook"
- )
- handles = []
- if self.parameters_as_bucket_view:
- for dev_i_buckets in self._buckets:
- bucket = dev_i_buckets[rank]
- global_rank = dist.distributed_c10d.get_global_rank(
- self.process_group, rank
- )
- handles.append(
- dist.broadcast(
- tensor=bucket,
- src=global_rank,
- group=self.process_group,
- async_op=True,
- )
- )
- else:
- param_groups = self._partition_parameters()[rank]
- global_rank = dist.distributed_c10d.get_global_rank(
- self.process_group, rank
- )
- for param_group in param_groups:
- for param in param_group["params"]:
- handles.append(
- dist.broadcast(
- tensor=param.data,
- src=global_rank,
- group=self.process_group,
- async_op=True,
- )
- )
- return handles
- def _sync_params(self):
- r"""
- Syncs all parameter shards across the ranks.
- This rank sends its shard of the parameters to all other ranks and
- receives a shard from each other rank. This is done using
- ``broadcast()``. Parameters are sent bucket-by-bucket if
- ``parameters_as_bucket_view=True`` and sent parameter-by-parameter
- otherwise.
- """
- handles = []
- for rank in range(self.world_size):
- handles.extend(self._broadcast_params_from_rank(rank))
- _ = list(map(lambda x: x.wait(), handles))
- @property
- def _device_to_params_per_rank(
- self,
- ) -> Dict[torch.device, List[List[torch.Tensor]]]:
- r"""
- :class:`dict` mapping each device to a :class:`list` of the per-rank parameter
- lists filtered to only include the parameters stored on that device.
- Each per-rank parameter list gives the parameters assigned to that rank
- to update.
- This is used for constructing the parameter buckets if
- ``parameters_as_bucket_view=True``.
- Let ``dev_i`` denote the ``i``th device for this rank. Then:
- ``dev_0`` maps to a list containing:
- rank 0's assigned parameters stored on ``dev_0``,
- rank 1's assigned parameters stored on ``dev_0``,
- ...
- ``dev_1`` maps to a list containing:
- rank 0's assigned parameters stored on ``dev_1``,
- rank 1's assigned parameters stored on ``dev_1``,
- ...
- ...
- """
- assert self.parameters_as_bucket_view, (
- "`_device_to_params_per_rank` should only be used if "
- "`parameters_as_bucket_view=True`"
- )
- if len(self._device_to_params_per_rank_cache) == 0:
- for rank, param_groups in enumerate(self._partition_parameters()):
- for param_group in param_groups:
- for param in param_group["params"]:
- device = param.device
- if device not in self._device_to_params_per_rank_cache:
- self._device_to_params_per_rank_cache[device] = [
- [] for _ in range(self.world_size)
- ]
- self._device_to_params_per_rank_cache[device][rank].append(
- param
- )
- return self._device_to_params_per_rank_cache
- def _get_min_index(
- self,
- values: List[int],
- disallowed_indices: Optional[Set[int]] = None,
- ) -> int:
- r"""
- Returns ``values.index(min(values))``, except only uses one pass. It
- also excludes any indices in ``disallowed_indices`` if provided.
- Arguments:
- values (List[int]): :class:`list` of values.
- disallowed_indices (Optional[Set[int]]): indices that are
- disallowed from being the returned min index.
- """
- min_index = -1
- min_value = float("inf")
- for i, value in enumerate(values):
- if disallowed_indices and i in disallowed_indices:
- continue
- if value < min_value:
- min_value = value
- min_index = i
- assert min_index >= 0, "All indices are disallowed"
- return min_index
- def _assign_bucket_subset_to_rank(
- self,
- bucket_index: int,
- bucket_params: List[torch.Tensor],
- bucket_offset: int,
- assigned_rank: int,
- assigned_ranks_per_bucket: List[Set[int]],
- ) -> None:
- r"""
- Assigns the model parameters given by ``bucket_params``, representing a
- (possibly non-strict) subset of the parameters corresponding to a
- :class:`DistributedDataParallel` bucket, to the rank with the least
- size assigned so far and collects relevant information.
- Arguments:
- bucket_index (int): index of the :class:`DistributedDataParallel`
- gradient bucket.
- bucket_params (List[torch.Tensor]): subset of the parameters
- corresponding to the bucket to assign.
- bucket_offset (int): offset giving the index of the first element
- in ``bucket_params`` in the bucket's full parameter list.
- assigned_rank (int): group rank to assign to.
- assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
- assigned to each bucket.
- """
- overlap_info = self._overlap_info
- if len(bucket_params) == 0:
- raise ValueError("Empty bucket assignment")
- params_per_rank = overlap_info.params_per_rank
- offsets = overlap_info.offsets
- self._bucket_assignments_per_rank_cache[assigned_rank][
- bucket_index
- ] = _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
- if self.global_rank == assigned_rank:
- offsets[bucket_index] = len(params_per_rank[assigned_rank])
- params_per_rank[assigned_rank].extend(bucket_params)
- assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
- self._overlap_info.num_bucket_assignments += 1
- @property
- def _bucket_assignments_per_rank(self) -> List[Dict[int, _DDPBucketAssignment]]:
- r"""
- :class:`list` of length world size consisting of :class:`dict` s
- mapping bucket indices to :class:`_DDPBucketAssignment` s for each
- rank.
- """
- assert self._overlap_with_ddp, (
- "`_bucket_assignments_per_rank` " "only be used if `overlap_with_ddp=True`"
- )
- if len(self._bucket_assignments_per_rank_cache) > 0:
- return self._bucket_assignments_per_rank_cache
- overlap_info = self._overlap_info
- assert overlap_info.status == _OverlapStatus.INITIALIZED
- self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
- params_per_bucket = overlap_info.params_per_bucket
- if overlap_info.shard_buckets:
- # Define the assignment threshold to approximate uniformity
- assert overlap_info.total_size is not None, "`total_size` was not computed"
- threshold = overlap_info.total_size / self.world_size # type: ignore[operator]
- size_per_rank = [0 for _ in range(self.world_size)]
- num_buckets = len(params_per_bucket)
- overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
- assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
- if not overlap_info.shard_buckets:
- # Assign each DDP bucket entirely to a single rank
- for bucket_index, bucket_params in enumerate(params_per_bucket):
- assert len(bucket_params) > 0, "Empty bucket"
- assigned_rank = self._get_assigned_rank(bucket_index)
- self._assign_bucket_subset_to_rank(
- bucket_index,
- bucket_params,
- 0,
- assigned_rank,
- assigned_ranks_per_bucket,
- )
- else:
- # Assign each DDP bucket to possibly multiple ranks
- # Specifically, sort the DDP buckets by increasing size, and for
- # each bucket, iteratively assign the maximal unassigned subset
- # with size less than `threshold` to the rank with the least total
- # size so far -- each such assignment is represented by a
- # `_DDPBucketAssignment` instance and only contains parameters from
- # a single DDP bucket
- params_per_bucket_enum = sorted(
- enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
- )
- for bucket_index, bucket_params in params_per_bucket_enum:
- assert len(bucket_params) > 0, "Empty bucket"
- bucket_offset = 0
- assignment_size = 0
- for param_index, param in enumerate(bucket_params):
- param_numel = param.numel()
- if (
- assignment_size + param_numel >= threshold
- and param_index > bucket_offset
- ):
- assigned_rank = self._get_min_index(
- size_per_rank, assigned_ranks_per_bucket[bucket_index]
- )
- # Include up to but not including the parameter that
- # exceeded the threshold
- self._assign_bucket_subset_to_rank(
- bucket_index,
- bucket_params[bucket_offset:param_index],
- bucket_offset,
- assigned_rank,
- assigned_ranks_per_bucket,
- )
- size_per_rank[assigned_rank] += assignment_size
- bucket_offset = param_index
- assignment_size = 0
- assignment_size += param_numel
- # Assign the remainder of the bucket so that no assignment
- # spans across two buckets
- assigned_rank = self._get_min_index(
- size_per_rank, assigned_ranks_per_bucket[bucket_index]
- )
- self._assign_bucket_subset_to_rank(
- bucket_index,
- bucket_params[bucket_offset:],
- bucket_offset,
- assigned_rank,
- assigned_ranks_per_bucket,
- )
- size_per_rank[assigned_rank] += assignment_size
- return self._bucket_assignments_per_rank_cache
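- # Worked illustration of the `shard_buckets=True` branch above (illustrative
- # sizes only): with a world size of 2 and buckets holding parameters of sizes
- # [4, 4] and [6, 6], `total_size=20` gives `threshold=10`. The smaller bucket
- # is visited first and assigned whole to rank 0 (its size 8 never reaches the
- # threshold). In the second bucket, the first parameter (6) goes to rank 1,
- # the currently least-loaded rank, and the remaining parameter (6) falls back
- # to rank 0, since an assignment never spans buckets and a rank is not reused
- # within the same bucket.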
- def _local_step(
- self,
- gradients: Optional[List[Optional[torch.Tensor]]] = None,
- closure: Optional[Callable[[], float]] = None,
- **kwargs: Any,
- ) -> Optional[float]:
- r"""
- Performs a single optimizer step without syncing parameters across
- ranks.
- Arguments:
- gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
- of length equal to the number of parameters assigned to this
- rank containing gradient tensors or ``None`` as its elements;
- a ``None`` in the :class:`list` indicates that the
- corresponding parameter should not be updated.
- If the argument itself is ``None``, then all parameters are
- updated, and the gradients are assumed to be already populated.
- (default: ``None``)
- closure (Callable): a closure that re-evaluates the model and
- returns the loss; optional for most optimizers and should be
- ``None`` if ``gradients`` is not ``None``; (default: ``None``)
- Returns:
- Optional loss depending on the underlying local optimizer.
- .. warning::
- The argument ``gradients`` should only be specified (i.e. not
- ``None``) if ``overlap_with_ddp=True``, in which case
- :class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
- """
- Join.notify_join_context(self)
- # Check if the model trainability has changed
- is_trainable_mask = self._get_is_trainable_mask()
- if is_trainable_mask != self._is_trainable_mask:
- if self._overlap_with_ddp:
- raise RuntimeError(
- "ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
- "does not support changing parameter trainability at run "
- "time"
- )
- logger.warning(
- "ZeroRedundancyOptimizer detected that the trainable "
- "parameters changed; rebuilding the parameter buckets if "
- "enabled"
- )
- self._build_param_buckets()
- self._is_trainable_mask = is_trainable_mask
- # Sync the exposed `param_groups` attributes to the local optimizer in
- # case they have been updated
- self._sync_param_groups(self.param_groups, self.optim.param_groups)
- # Run the optimizer step on this shard only
- if gradients is None:
- loss = (
- self.optim.step(**kwargs)
- if closure is None
- else self.optim.step(closure=closure, **kwargs)
- )
- else:
- assert self._overlap_with_ddp, (
- "Specifying `gradients` should not "
- "be used when `overlap_with_ddp=False`"
- )
- assert closure is None, (
- "`closure` is not supported when using " "a local functional optimizer"
- )
- loss = self.optim.step(gradients=gradients)
- # Sync any updated attributes in the local optimizer to the exposed
- # `param_groups`
- self._sync_param_groups(self.optim.param_groups, self.param_groups)
- return loss
- def step(
- self,
- closure: Optional[Callable[[], float]] = None,
- **kwargs: Any,
- ) -> Optional[float]:
- r"""
- Performs a single optimizer step and syncs parameters across all ranks.
- Arguments:
- closure (Callable): a closure that re-evaluates the model and
- returns the loss; optional for most optimizers.
- Returns:
- Optional loss depending on the underlying local optimizer.
- .. note:: Any extra parameters are passed to the base optimizer as-is.
- """
- if self._overlap_with_ddp:
- logger.warning(
- "`step()` should not be included in the training loop when "
- "`overlap_with_ddp=True`"
- )
- return None
- # Perform the local optimizer step
- loss = self._local_step(closure=closure, **kwargs)
- # Sync all of the updated parameter shards across the ranks
- self._sync_params()
- return loss
- def join_hook(self, **kwargs):
- r"""
- Returns the ZeRO join hook, which enables training on uneven inputs by
- shadowing the collective communications in the optimizer step.
- Gradients must be properly set before this hook is called.
- Arguments:
- kwargs (dict): a :class:`dict` containing any keyword arguments
- to modify the behavior of the join hook at run time; all
- :class:`Joinable` instances sharing the same join context
- manager are forwarded the same value for ``kwargs``.
- This hook does not support any keyword arguments; i.e. ``kwargs`` is
- unused.
- """
- return _ZeROJoinHook(self)
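- # Sketch of uneven-input training with the join context manager (reuses the
- # `ddp`, `opt`, and `inputs` names from the class-level example; those names
- # are assumptions): passing both the DDP instance and this optimizer lets
- # ranks that exhaust their data early shadow the collectives issued by
- # `step()` on the other ranks.
- #   >>> # xdoctest: +SKIP
- #   >>> from torch.distributed.algorithms.join import Join
- #   >>> with Join([ddp, opt]):
- #   >>>     for input in inputs:
- #   >>>         ddp(input).sum().backward()
- #   >>>         opt.step()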
- @property
- def join_device(self) -> torch.device:
- return self._default_device
- @property
- def join_process_group(self) -> Any:
- return self.process_group
- def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
- r"""
- Load the state pertaining to the given rank from the input
- ``state_dict``, updating the local optimizer as needed.
- Arguments:
- state_dict (dict): optimizer state; should be an object returned
- from a call to :meth:`state_dict`.
- Raises:
- RuntimeError: if ``overlap_with_ddp=True`` and this method is
- called before this :class:`ZeroRedundancyOptimizer` instance
- has been fully initialized, which happens once
- :class:`DistributedDataParallel` gradient buckets have been
- rebuilt.
- """
- self._check_overlap_initialized()
- for index, value in state_dict["state"].items():
- param = self._index_to_param[index]
- if self._param_to_rank[param] != self.rank:
- # Clear any state irrelevant to this rank
- state_dict["state"][index] = None
- else:
- # Load the parameter state to the local optimizer
- self.optim.state[param] = _recursive_copy_to_device(
- value, non_blocking=True, device=param.device
- )
- # Force zero-dimensional tensors (like Adam "step") on CPU
- for state_name, state_value in self.optim.state[param].items():
- if torch.is_tensor(state_value) and state_value.dim() == 0:
- self.optim.state[param][state_name] = state_value.cpu()
- super().load_state_dict(state_dict)
- # Sync the input state with the exposed and local optimizer states
- self._sync_param_groups(state_dict["param_groups"], self.param_groups)
- self._sync_param_groups(self.param_groups, self.optim.param_groups)
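- # Loading sketch (assumes the checkpoint produced in the consolidation
- # example above is visible to every rank): each rank loads the full global
- # state and keeps only the shard it owns.
- #   >>> # xdoctest: +SKIP
- #   >>> global_state = torch.load("zero_checkpoint.pt", map_location="cpu")
- #   >>> opt.load_state_dict(global_state)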
- def state_dict(self) -> Dict[str, Any]:
- r"""
- Returns the last global optimizer state known to this rank.
- .. warning::
- If the state has not been consolidated to this rank, this raises a
- runtime error, and even if it has, the state may not be up-to-date,
- depending on when :meth:`consolidate_state_dict` was last called.
- Raises:
- RuntimeError: if ``overlap_with_ddp=True`` and this method is
- called before this :class:`ZeroRedundancyOptimizer` instance
- has been fully initialized, which happens once
- :class:`DistributedDataParallel` gradient buckets have been
- rebuilt; or if this method is called without a preceding call
- to :meth:`consolidate_state_dict`.
- """
- self._check_overlap_initialized()
- if len(self._all_state_dicts) == 0:
- raise RuntimeError(
- "Optimizer state has not been consolidated on this rank. "
- f"Please call `consolidate_state_dict(to={self.rank})` on "
- "all ranks beforehand if you meant to save the global state."
- )
- # Get the possibly-stale global optimizer state that uses global
- # parameter indexing
- state_dict = super().state_dict()
- # Update the global optimizer state with local state information,
- # factoring in the translation from local to global indexing
- for rank, local_state_dict in enumerate(self._all_state_dicts):
- local_param_groups = local_state_dict["param_groups"]
- global_param_groups = self._partition_parameters()[rank]
- assert len(local_param_groups) == len(
- global_param_groups
- ), "Mismatch between number of local and global parameter groups"
- for local_param_group, global_param_group in zip(
- local_param_groups, global_param_groups
- ):
- # `local_param_group` stores local indices, while
- # `global_param_group` stores the tensors directly
- local_param_indices = local_param_group["params"]
- global_params = global_param_group["params"]
- assert len(local_param_indices) == len(
- global_params
- ), "Mismatch between number of local and global parameters in parameter group"
- for local_param_index, global_param in zip(
- local_param_indices, global_params
- ):
- # Update the global parameter state, if any
- if local_param_index in local_state_dict["state"]:
- global_param_index = self._param_to_index[global_param]
- state_dict["state"][global_param_index] = local_state_dict[
- "state"
- ][local_param_index]
- # Sort the parameters in the state
- state_dict["state"] = dict(sorted(state_dict["state"].items()))
- return state_dict
- @staticmethod
- def _sync_param_groups(
- src_param_groups: List[Dict[Any, Any]],
- dst_param_groups: List[Dict[Any, Any]],
- ) -> None:
- r"""
- Syncs the attributes from the source parameter groups to the
- destination parameter groups.
- Example attributes include learning rate or scheduler attributes. The
- two parameter groups should have the same length (i.e. same number of
- parameter groups).
- Arguments:
- src_param_groups (list[dict]): parameter groups giving the
- attribute settings to copy.
- dst_param_groups (list[dict]): parameter groups giving the
- attribute settings to set.
- """
- assert len(src_param_groups) == len(
- dst_param_groups
- ), "Mismatch between number of source and destination parameter groups"
- for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
- # Sync all attributes except the parameters
- for attr in filter(lambda x: x != "params", src_param_group.keys()):
- dst_param_group[attr] = src_param_group[attr]
- def _build_param_buckets(self) -> None:
- r"""
- Builds parameter buckets if ``parameters_as_bucket_view=True`` so
- that for each device that stores this rank's parameters, there is a
- bucket (represented as a tensor) containing all of the parameters on
- that device that are assigned to a given rank in the parameter update
- partition.
- This method is called in the constructor and any time parameter
- trainability is changed.
- .. warning::
- The current implementation assumes that all of the parameters in a
- bucket are of the same dense type when allocating the bucket's
- tensor.
- .. warning::
- If the model parameters are stored across more than one device,
- then the storage partitioning must be the same across all
- processes in order for parameter synchronization to work.
- """
- if not self.parameters_as_bucket_view or self._overlap_with_ddp:
- return
- # `self._buckets[i][j]` are the parameters stored on device i and
- # assigned to rank j
- num_devices = len(self._device_to_params_per_rank)
- self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment]
- for dev_i, (device, params_per_rank) in enumerate(
- self._device_to_params_per_rank.items()
- ):
- for params in params_per_rank:
- bucket_size = 0
- dtype = None
- trainable_params = []
- for param in params:
- if not _is_trainable(param):
- # Clone in case the parameter was previously part of
- # a bucket to avoid the data from being destroyed
- param.data = param.data.detach().clone()
- else:
- bucket_size += param.numel()
- trainable_params.append(param)
- dtype = param.dtype # assumes all same dtype
- if bucket_size == 0:
- # Create a dummy bucket if there are no parameters
- bucket = torch.zeros(1, device=device)
- else:
- # Construct the bucket (assuming all dense and same dtype)
- bucket = torch.empty(bucket_size, dtype=dtype, device=device)
- offset = 0
- for param in trainable_params:
- offset_next = offset + param.numel()
- bucket[offset:offset_next].copy_(param.data.flatten())
- param.data = bucket[offset:offset_next].view_as(param.data)
- offset = offset_next
- self._buckets[dev_i].append(bucket) # type: ignore[arg-type]
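- # The flattening trick above in isolation (a standalone sketch, not this
- # class's API): parameter data is copied into one contiguous tensor and the
- # `.data` fields are re-pointed at views of it, so a single `broadcast()` of
- # the bucket synchronizes every parameter it contains.
- #   >>> p1, p2 = torch.randn(3), torch.randn(2)
- #   >>> bucket = torch.empty(5)
- #   >>> bucket[0:3].copy_(p1); bucket[3:5].copy_(p2)
- #   >>> p1.data, p2.data = bucket[0:3].view_as(p1), bucket[3:5].view_as(p2)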
- def _build_ddp_param_buckets(self) -> None:
- r"""
- For each DDP bucket with parameters assigned to this rank, flattens the
- data of those parameters into a single tensor and saves the tensor to
- the ``tensor`` attribute in the corresponding
- :class:`_DDPBucketAssignment` instance stored in
- ``self._bucket_assignments_per_rank``.
- :class:`DistributedDataParallel` guarantees that the parameters
- corresponding to a gradient bucket have the same device and the same
- dtype.
- """
- for bucket_assignments in self._bucket_assignments_per_rank:
- for bucket_assignment in bucket_assignments.values():
- params = bucket_assignment.parameters
- bucket_size = 0
- dtype = None
- for param in params:
- assert _is_trainable(param), (
- "Model parameter "
- "corresponding to a gradient in a DDP bucket should "
- "require a gradient"
- )
- bucket_size += param.numel()
- dtype = param.dtype # assumes all same dtype
- assert bucket_size > 0, "Empty bucket"
- # Construct the bucket tensor (assuming all dense and same dtype)
- tensor = torch.empty(
- bucket_size, dtype=dtype, device=bucket_assignment.device
- )
- offset = 0
- for param in params:
- offset_next = offset + param.numel()
- tensor[offset:offset_next].copy_(param.data.flatten())
- param.data = tensor[offset:offset_next].view_as(param.data)
- offset = offset_next
- bucket_assignment.tensor = tensor
- def _verify_and_init_params(
- self,
- params: Any,
- ) -> Union[List[torch.Tensor], List[dict]]:
- r"""
- Verifies the type of ``params`` and initializes ``self._all_params``
- as a :class:`list` of all parameters if ``params`` is valid.
- Arguments:
- params (Any): Candidate parameter list or parameter groups to
- verify.
- Raises:
- TypeError: ``params`` has an invalid type.
- ValueError: ``params`` is empty.
- Returns:
- The persistent form of ``params`` to be passed into the parent
- :class:`Optimizer` constructor -- i.e. returns ``params`` as a
- :class:`list` to ensure that it can be iterated over again.
- """
- if isinstance(params, torch.Tensor):
- raise TypeError(
- "`params` argument should be an iterable of "
- f"Tensors, but got {torch.typename(params)}"
- )
- try:
- all_params = list(params)
- except TypeError as e:
- raise TypeError(
- "`params` argument should be an iterable of Tensors"
- f" or dicts, but got {torch.typename(params)}"
- ) from e
- if len(all_params) == 0:
- raise ValueError("ZeroRedundancyOptimizer got an empty parameter " "list")
- all_tensors = True
- all_dicts = True
- for param in all_params:
- all_tensors &= isinstance(param, torch.Tensor)
- all_dicts &= isinstance(param, dict)
- if not all_tensors and not all_dicts:
- raise TypeError(
- "`params` argument should be an iterable of " "Tensors or dicts"
- )
- # Ensure that `self._all_params` contains a list of all parameters
- if all_tensors:
- self._all_params = all_params
- elif all_dicts:
- self._all_params = []
- # `all_params` contains parameter groups (not parameters)
- for param_group in all_params:
- if "params" not in param_group:
- raise ValueError(
- "Each parameter group passed-in via `params` must "
- "have a 'params' key mapping to the parameters in "
- "the group"
- )
- self._all_params.extend(param_group["params"])
- return all_params
- def _verify_same_dense_param_type(self) -> None:
- r"""
- Verifies that all parameters are of the same dense type.
- The method assumes that ``self._all_params`` has been initialized
- and is non-empty.
- Raises:
- ValueError: ``params`` contains sparse parameters or parameters
- of varying dense types.
- NOTE: This method can be removed once support for sparse parameters
- and varying parameter types is added.
- """
- typename = torch.typename(self._all_params[0])
- if self._all_params[0].is_sparse:
- raise ValueError(
- "ZeroRedundancyOptimizer only supports using "
- "the same dense type for all parameters but got "
- f"{typename}"
- )
- for param in self._all_params[1:]:
- other_typename = torch.typename(param)
- if other_typename != typename:
- raise ValueError(
- "ZeroRedundancyOptimizer only supports "
- "using the same dense type for all "
- f"parameters but got both {typename} and "
- f"{other_typename}"
- )
- def _get_is_trainable_mask(self) -> List[bool]:
- r"""
- Returns a boolean mask indicating if each parameter is trainable
- (``requires_grad``) or not.
- """
- return list(map(_is_trainable, self._all_params))
- def _init_local_optimizer(self) -> None:
- r"""
- Initializes this rank's local optimizer, responsible for its subset of
- the parameters.
- The local optimizer is saved in ``self.optim``.
- """
- assert (
- self._optim_constructor is not None
- ), "The local optimizer class has not been set"
- param_groups = self._partition_parameters()[self.rank]
- # `overlap_with_ddp=True` requires a local functional optimizer
- if self._overlap_with_ddp:
- # Functional optimizers only support a single parameter group and
- # require passing in the parameters as a list
- assert len(param_groups) == 1, (
- "Initializing the local "
- "functional optimizer with more than one parameter group"
- )
- params = param_groups[0]["params"]
- # Try to pass `_allow_empty_param_list=True` to avoid erroring
- if (
- "_allow_empty_param_list"
- in inspect.signature(self._optim_constructor).parameters
- ):
- self.optim: Any = self._optim_constructor(
- params, **self._optim_defaults, _allow_empty_param_list=True
- )
- else:
- logger.warning(
- f"{self._optim_constructor} does not support the argument "
- "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
- "error due to an empty parameter list"
- )
- self.optim: Any = self._optim_constructor(params, **self._optim_defaults) # type: ignore[no-redef]
- # Log information about the DDP and ZeRO bucketing
- if dist.get_debug_level() != dist.DebugLevel.OFF:
- local_numel = sum(p.numel() for p in params)
- num_assigned_buckets = len(
- self._bucket_assignments_per_rank[self.global_rank]
- )
- logger.info(
- f"rank {self.global_rank} with {local_numel} parameters "
- f"across {num_assigned_buckets} buckets"
- )
- if self.global_rank == 0:
- logger.info(
- f"{len(self._overlap_info.params_per_bucket)} DDP "
- f"buckets and "
- f"{self._overlap_info.num_bucket_assignments} bucket "
- "assignments"
- )
- else:
- # NOTE: Passing `param_groups` into the local optimizer constructor
- # bypasses the empty parameter list check
- self.optim: Optimizer = self._optim_constructor(param_groups, **self._optim_defaults) # type: ignore[no-redef]
- # TODO: Manually add `self.param_groups` if using a functional
- # optimizer; remove this if/when the functional optimizers support
- # multiple parameter groups
- if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
- assert hasattr(self.optim, "param_group"), (
- "The functional optimizer should set at least one of the "
- "attributes `param_group` or `param_groups`"
- )
- self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined]
- self._sync_param_groups(self.optim.param_groups, self.param_groups)
- def _init_zero_for_overlap(self) -> None:
- r"""
- Performs a delayed initialization of the local optimizer and the
- supporting data structures.
- """
- assert self._overlap_with_ddp, (
- "`_init_zero_for_overlap()` should only be called when "
- "`overlap_with_ddp=True`"
- )
- self._overlap_info.status = _OverlapStatus.INITIALIZED
- self._clear_cache()
- self._partition_parameters(self._overlap_info.params_per_rank)
- self._build_ddp_param_buckets()
- self._init_local_optimizer()
- def _get_assigned_rank(self, bucket_index: int) -> int:
- r"""
- Returns the single rank assigned to a :class:`DistributedDataParallel`
- gradient bucket.
- Arguments:
- bucket_index (int): index of the :class:`DistributedDataParallel`
- bucket for which to get the assigned rank.
- """
- assert not self._overlap_info.shard_buckets, (
- "The bucket assignment requires global bucket information and "
- "will be computed later; there should be no need to use this "
- "method"
- )
- return bucket_index % self.world_size
- def _check_overlap_initialized(self):
- r"""
- Checks that the delayed initialization has occurred (see
- :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and
- raises a ``RuntimeError`` if not. This should preface methods that
- should not be run before that delayed initialization.
- Raises:
- RuntimeError: if ``overlap_with_ddp=True`` and
- :meth:`_init_zero_for_overlap` has not been called.
- """
- if (
- self._overlap_with_ddp
- and self._overlap_info.status != _OverlapStatus.INITIALIZED
- ):
- raise RuntimeError(
- "This method should not be called until this "
- "ZeroRedundancyOptimizer instance has been fully "
- "initialized"
- )
- def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
- r"""
- Returns the proper optimizer constructor, performing the necessary
- validation and transformation depending on ``overlap_with_ddp``.
- Returns:
- - ``optimizer_class`` if ``overlap_with_ddp=False`` and
- ``optimizer_class`` is not a functional optimizer.
- - ``optimizer_class`` if ``overlap_with_ddp=True`` and
- ``optimizer_class`` is already a functional optimizer.
- - The functional equivalent of ``optimizer_class`` if
- ``overlap_with_ddp=True`` and ``optimizer_class`` is not
- already a functional optimizer (assuming the equivalent
- exists).
- Raises:
- ValueError:
- - if ``overlap_with_ddp=True`` but ``optimizer_class`` is
- neither a functional optimizer nor translatable to a
- functional optimizer.
- - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
- functional optimizer.
- """
- functional_optims = functional_optim_map.values()
- if not self._overlap_with_ddp:
- if optimizer_class in functional_optims:
- # Using a functional optimizer is only supported when
- # `overlap_with_ddp=True`
- raise ValueError(
- f"Passing in a functional optimizer {optimizer_class} "
- "when `overlap_with_ddp=False`"
- )
- else:
- return optimizer_class
- else:
- if optimizer_class in functional_optims:
- # Already a functional optimizer
- return optimizer_class
- elif optimizer_class in functional_optim_map:
- # Translate the passed-in optimizer class to its functional
- # equivalent if `overlap_with_ddp=True`
- optim_constructor = functional_optim_map[optimizer_class]
- logger.info(
- f"Using the functional optimizer {optim_constructor} "
- f"instead of {optimizer_class} since "
- "`overlap_with_ddp=True`"
- )
- return optim_constructor
- else:
- raise ValueError(
- "Using `ddp_with_overlap=True` requires using a "
- "functional optimizer, but there is no supported functional "
- f"optimizer equivalent for {optimizer_class}"
- )
|