pipe.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. # Copyright 2019 Kakao Brain
  2. #
  3. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  4. #
  5. # This source code is licensed under the BSD license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. """The Pipe interface."""
  8. from collections import OrderedDict
  9. from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, Sequence, Tuple, cast
  10. import torch
  11. from torch import Tensor, nn
  12. from torch.distributed.rpc import RRef
  13. import torch.autograd
  14. import torch.cuda
  15. from . import microbatch
  16. from .batchnorm import DeferredBatchNorm
  17. from .pipeline import Pipeline
  18. from .skip.layout import inspect_skip_layout
  19. from .skip.skippable import verify_skippables
  20. from .stream import AbstractStream, new_stream
  21. __all__ = ["Pipe", "BalanceError", "PipeSequential", "WithDevice"]
  22. Device = Union[torch.device, int, str]
  23. Devices = Union[Iterable[Device], List[Device]]
  24. Tensors = Sequence[Tensor]
  25. TensorOrTensors = Union[Tensor, Tensors]
  26. if TYPE_CHECKING:
  27. # Typechecking: nn.Module is not a Generic
  28. Module = nn.Module[TensorOrTensors] # type: ignore[type-arg]
  29. NamedModules = OrderedDict[str, Module]
  30. else:
  31. Module = nn.Module
  32. NamedModules = OrderedDict
  33. def _recommend_auto_balance(message: str) -> str:
  34. """Expands a message with recommendation to :mod:`torchpipe.balance`."""
  35. return f"""{message}
  36. If your model is still under development, its optimal balance would change
  37. frequently. In this case, we highly recommend 'torch.distributed.pipeline.sync.balance' for
  38. naive automatic balancing:
  39. from torch.distributed.pipeline.sync import Pipe
  40. from torch.distributed.pipeline.sync.balance import balance_by_time
  41. partitions = torch.cuda.device_count()
  42. sample = torch.empty(...)
  43. balance = balance_by_time(partitions, model, sample)
  44. model = Pipe(model, balance, ...)
  45. """
  46. def _verify_module(module: nn.Sequential) -> None:
  47. if not isinstance(module, nn.Sequential):
  48. raise TypeError("module must be nn.Sequential to be partitioned")
  49. named_children = list(module.named_children())
  50. if len(named_children) != len(module):
  51. raise ValueError("module with duplicate children is not supported")
  52. def _verify_splitting(
  53. module: nn.Sequential, partitions: List[nn.Sequential], devices: List[torch.device]
  54. ) -> None:
  55. num_parameters = len(list(module.parameters()))
  56. num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
  57. if num_parameters == num_child_parameters:
  58. return
  59. for i in range(len(partitions)):
  60. for j in range(i + 1, len(partitions)):
  61. parti = partitions[i]
  62. partj = partitions[j]
  63. if devices[i] == devices[j]:
  64. continue
  65. for p in parti.parameters():
  66. for q in partj.parameters():
  67. if p is q:
  68. raise ValueError("module with duplicate parameters on distinct devices is not supported")
  69. class BalanceError(ValueError):
  70. pass
  71. def _retrieve_device(module: nn.Module) -> torch.device:
  72. """Validates all parameters in the Module have the same device and returns
  73. the appropriate device.
  74. Args:
  75. An ``nn.Module`` to process.
  76. Returns:
  77. ``torch.Device`` for the entire module.
  78. Raises:
  79. ValueError:
  80. If devices for ``nn.Module`` parameters are not all same.
  81. """
  82. device = None
  83. for parameter in module.parameters():
  84. if device is None:
  85. device = parameter.device
  86. elif device != parameter.device:
  87. raise ValueError(
  88. 'nn.Module: {}, should have all parameters on a single device,'
  89. ' please use .to() to place the module on a single device'.format(module))
  90. return device if device is not None else torch.device("cpu")
  91. class PipeSequential(nn.Sequential):
  92. """
  93. Pipe variant of ``nn.Sequential`` which supports multiple inputs.
  94. """
  95. def forward(self, *inputs):
  96. for module in self:
  97. if isinstance(inputs, Tuple): # type: ignore[arg-type]
  98. inputs = module(*inputs)
  99. else:
  100. # Don't expand single variables (ex: lists/Tensor)
  101. inputs = module(inputs)
  102. return inputs
  103. class WithDevice(nn.Module):
  104. """
  105. Wraps an ``nn.Module`` which is part of ``nn.Sequential`` passed into :class:`Pipe`
  106. that overrides the device for that module. In cases where :class:`Pipe`
  107. can't implicitly determine the device for the module and places it on CPU,
  108. this wrapper can be used to override the implicit behavior and explicitly
  109. specify which device a module should run on.
  110. The provided module is also moved to the given device via ``.to(device)``
  111. by :class:`Pipe`
  112. Args:
  113. module(:class:`torch.nn.Module`): The module to be wrapped.
  114. device(:class:`torch.device`): The device to run the module on.
  115. Example::
  116. >>> # xdoctest: +SKIP("distributed")
  117. >>> fc1 = nn.Linear(16, 8).cuda(0)
  118. >>> fc2 = nn.Linear(8, 4).cuda(1)
  119. >>> dropout = nn.Dropout()
  120. >>>
  121. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
  122. >>> # Dropout does not have any parameters/buffers, but we want to
  123. >>> # run it on cuda:1 to avoid any GPU to CPU transfers.
  124. >>> model = nn.Sequential(fc1, fc2, WithDevice(dropout, 'cuda:1'))
  125. >>> # xdoctest: +SKIP("Needs RPC framework init")
  126. >>> model = Pipe(model, chunks=8)
  127. """
  128. def __init__(self, module: nn.Module, device: torch.device):
  129. super().__init__()
  130. self._module = module
  131. self._device = torch.device(device)
  132. def forward(self, *args, **kwargs):
  133. return self._module(*args, **kwargs)
  134. @property
  135. def module(self):
  136. return self._module
  137. @property
  138. def device(self):
  139. return self._device
  140. def _assemble_partition(modules: List[nn.Module]):
  141. modules_list: List[nn.Module] = []
  142. for module in modules:
  143. if isinstance(module, nn.Sequential):
  144. modules_list.extend(module.children())
  145. else:
  146. modules_list.append(module)
  147. return PipeSequential(*modules_list)
  148. def _split_module(modules: nn.Sequential) -> Tuple[List[nn.Sequential], List[torch.device]]:
  149. partitions = []
  150. devices = []
  151. current_partition = []
  152. current_device = None
  153. for name, module in modules.named_children():
  154. if isinstance(module, WithDevice):
  155. # Process device override and move module to appropriate device.
  156. device = module.device
  157. module = module.module
  158. module.to(device)
  159. else:
  160. device = _retrieve_device(module)
  161. if current_device is not None and (current_device != device or device.type == 'cpu'):
  162. partitions.append(_assemble_partition(current_partition))
  163. devices.append(current_device)
  164. current_partition = []
  165. current_device = device
  166. current_partition.append(module)
  167. if current_device is not None:
  168. partitions.append(_assemble_partition(current_partition))
  169. devices.append(current_device)
  170. partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
  171. return partitions, devices
  172. MOVING_DENIED = TypeError("denied to move parameters and buffers, " "because Pipe should manage device placement")
  173. class Pipe(Module):
  174. """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
  175. to train on using synchronous pipeline parallelism. If the module requires
  176. lots of memory and doesn't fit on a single GPU, pipeline parallelism is a
  177. useful technique to employ for training.
  178. The implementation is based on the torchgpipe_ paper.
  179. .. _torchgpipe: https://arxiv.org/abs/2004.09910
  180. Pipe combines pipeline parallelism with checkpointing to reduce peak
  181. memory required to train while minimizing device under-utilization.
  182. You should place all the modules on the appropriate devices and wrap them
  183. into an :class:`nn.Sequential <torch.nn.Sequential>` module defining the
  184. desired order of execution. If a module does not contain any
  185. parameters/buffers, it is assumed this module should be executed on CPU
  186. and appropriate input tensors to the module are moved to CPU before
  187. execution. This behavior can be overridden by the :class:`WithDevice`
  188. wrapper which can be used to explicitly specify which device a module
  189. should run on.
  190. Args:
  191. module (:class:`nn.Sequential <torch.nn.Sequential>`):
  192. sequential module to be parallelized using pipelining. Each module
  193. in the sequence has to have all of its parameters on a single
  194. device. Each module in the sequence has to either be an nn.Module
  195. or :class:`nn.Sequential <torch.nn.Sequential>` (to combine multiple
  196. sequential modules on a single device)
  197. chunks (int):
  198. number of micro-batches (default: ``1``)
  199. checkpoint (str):
  200. when to enable checkpointing, one of ``'always'``,
  201. ``'except_last'``, or ``'never'`` (default: ``'except_last'``).
  202. ``'never'`` disables checkpointing completely, ``'except_last'``
  203. enables checkpointing for all micro-batches except the last one
  204. and ``'always'`` enables checkpointing for all micro-batches.
  205. deferred_batch_norm (bool):
  206. whether to use deferred ``BatchNorm`` moving statistics (default:
  207. :data:`False`). If set to :data:`True`, we track statistics across
  208. multiple micro-batches to update the running statistics per
  209. mini-batch.
  210. Raises:
  211. TypeError:
  212. the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
  213. ValueError:
  214. invalid arguments
  215. Example::
  216. Pipeline of two FC layers across GPUs 0 and 1.
  217. >>> # Need to initialize RPC framework first.
  218. >>> # xdoctest: +SKIP
  219. >>> os.environ['MASTER_ADDR'] = 'localhost'
  220. >>> os.environ['MASTER_PORT'] = '29500'
  221. >>> torch.distributed.rpc.init_rpc('worker', rank=0, world_size=1)
  222. >>>
  223. >>> # Build pipe.
  224. >>> fc1 = nn.Linear(16, 8).cuda(0)
  225. >>> fc2 = nn.Linear(8, 4).cuda(1)
  226. >>> model = nn.Sequential(fc1, fc2)
  227. >>> model = Pipe(model, chunks=8)
  228. >>> input = torch.rand(16, 16).cuda(0)
  229. >>> output_rref = model(input)
  230. .. note::
  231. You can wrap a :class:`Pipe` model with
  232. :class:`torch.nn.parallel.DistributedDataParallel` only when the
  233. checkpoint parameter of :class:`Pipe` is ``'never'``.
  234. .. note::
  235. :class:`Pipe` only supports intra-node pipelining currently, but
  236. will be expanded to support inter-node pipelining in the future.
  237. The forward function returns an :class:`~torch.distributed.rpc.RRef`
  238. to allow for inter-node pipelining in the future, where the output
  239. might be on a remote host. For intra-node pipelinining you can use
  240. :meth:`~torch.distributed.rpc.RRef.local_value` to retrieve the
  241. output locally.
  242. .. warning::
  243. :class:`Pipe` is experimental and subject to change.
  244. """
  245. def __init__(
  246. self,
  247. module: nn.Sequential,
  248. chunks: int = 1,
  249. checkpoint: str = "except_last",
  250. deferred_batch_norm: bool = False,
  251. ) -> None:
  252. super().__init__()
  253. # Check if RPC framework is initialized.
  254. if not torch.distributed.rpc._is_current_rpc_agent_set():
  255. raise RuntimeError(
  256. 'Please initialize RPC framework for Pipe using '
  257. 'torch.distributed.rpc.init_rpc')
  258. chunks = int(chunks)
  259. checkpoint = str(checkpoint)
  260. if chunks <= 0:
  261. raise ValueError("number of chunks must be positive integer")
  262. if checkpoint not in ["always", "except_last", "never"]:
  263. raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
  264. _verify_module(module)
  265. # Verify if the underlying skippable modules satisfy integrity. The
  266. # integrity can be verified before forward() because it is static.
  267. verify_skippables(module)
  268. self.chunks = chunks
  269. self.checkpoint = checkpoint
  270. if deferred_batch_norm:
  271. module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)
  272. self.partitions, self.devices = _split_module(module)
  273. _verify_splitting(module, self.partitions, self.devices)
  274. self._copy_streams: List[List[AbstractStream]] = []
  275. self._skip_layout = inspect_skip_layout(self.partitions)
  276. # Separate CUDA streams for copy.
  277. copy_streams = self._ensure_copy_streams()
  278. # The micro-batch index where the checkpointing stops.
  279. checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
  280. self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
  281. def __len__(self) -> int:
  282. """Counts the length of the underlying sequential module."""
  283. return sum(len(p) for p in self.partitions)
  284. def __getitem__(self, index: int) -> nn.Module:
  285. """Gets a layer in the underlying sequential module."""
  286. partitions = self.partitions
  287. if index < 0:
  288. partitions = partitions[::-1]
  289. for partition in partitions:
  290. try:
  291. return partition[index]
  292. except IndexError:
  293. pass
  294. shift = len(partition)
  295. if index < 0:
  296. index += shift
  297. else:
  298. index -= shift
  299. raise IndexError
  300. def __iter__(self) -> Iterable[nn.Module]:
  301. """Iterates over children of the underlying sequential module."""
  302. for partition in self.partitions:
  303. yield from partition
  304. # Pipe should manage the device of each partition.
  305. # Deny cuda(), cpu(), and to() with device, by TypeError.
  306. def cuda(self, device: Optional[Device] = None) -> "Pipe":
  307. raise MOVING_DENIED
  308. def cpu(self) -> "Pipe":
  309. raise MOVING_DENIED
  310. def to(self, *args: Any, **kwargs: Any) -> "Pipe":
  311. # Deny these usages:
  312. #
  313. # - to(device[, dtype, non_blocking])
  314. # - to(tensor[, non_blocking])
  315. #
  316. # But allow this:
  317. #
  318. # - to(dtype[, non_blocking])
  319. #
  320. if "device" in kwargs or "tensor" in kwargs:
  321. raise MOVING_DENIED
  322. if args:
  323. if isinstance(args[0], (torch.device, int, str)):
  324. raise MOVING_DENIED
  325. if torch.is_tensor(args[0]):
  326. raise MOVING_DENIED
  327. return super().to(*args, **kwargs)
  328. def _ensure_copy_streams(self) -> List[List[AbstractStream]]:
  329. """Ensures that :class:`Pipe` caches CUDA streams for copy.
  330. It's worth to cache CUDA streams although PyTorch already manages a
  331. pool of pre-allocated CUDA streams, because it may reduce GPU memory
  332. fragementation when the number of micro-batches is small.
  333. """
  334. if not self._copy_streams:
  335. for device in self.devices:
  336. self._copy_streams.append([new_stream(device) for _ in range(self.chunks)])
  337. return self._copy_streams
  338. def forward(self, *inputs) -> RRef:
  339. """
  340. Processes a single input mini-batch through the pipe and returns an
  341. :class:`~torch.distributed.rpc.RRef` pointing to the output.
  342. :class:`Pipe` is a fairly transparent module wrapper. It doesn't
  343. modify the input and output signature of the underlying module. But
  344. there's type restriction. Input and output have to contain at least one
  345. tensor. This restriction is applied at partition boundaries too.
  346. The sequence of inputs are fed into the first stage of the pipeline as
  347. ``*inputs``. As a result the positional args for this function should
  348. match the positional args for the first stage of the pipeline. The same
  349. condition applies for output of one stage of the pipeline which is the
  350. input for the next stage.
  351. The input tensor is split into multiple micro-batches based on the
  352. ``chunks`` parameter used to initialize :class:`Pipe`. The batch size
  353. is assumed to be the first dimension of the tensor and if the batch
  354. size is less than ``chunks``, the number of micro-batches is equal to
  355. the batch size.
  356. Only tensors are split into multiple micro-batches, non-Tensor inputs
  357. are just replicated as-is in each micro-batch. For non-Tensor outputs
  358. in the last stage of the pipeline, they are aggregated as a ``List``
  359. and returned the user. For example, if you have 2 micro-batches
  360. returning the integer 5, the user would receive the consolidated
  361. output of `[5, 5]`
  362. All the input tensors need to be on the same device as the first
  363. partition of the pipeline.
  364. If a tensor is wrapped with the :class:`NoChunk` wrapper, the tensor
  365. is not split across micro-batches and is replicated as-is similar to
  366. non-tensors.
  367. Args:
  368. inputs: input mini-batch
  369. Returns:
  370. :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
  371. Raises:
  372. TypeError: input doesn't contain at least one tensor
  373. """
  374. first_partition_device = self.devices[0] if len(self.devices) != 0 else torch.device("cpu")
  375. microbatch.check(first_partition_device, *inputs)
  376. if not self.devices:
  377. # Empty sequential module is not illegal.
  378. return RRef(*inputs)
  379. # Divide a mini-batch into micro-batches.
  380. batches = microbatch.scatter(*inputs, chunks=self.chunks)
  381. # Run pipeline parallelism.
  382. self.pipeline.run(batches)
  383. # Merge the micro-batches into one mini-batch.
  384. output = microbatch.gather(batches)
  385. return RRef(output)