import operator
import torch
import warnings
from itertools import chain
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties
)

__all__ = ['DataParallel', 'data_parallel']


def _check_balance(device_ids):
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable."""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return


class DataParallel(Module):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given :attr:`module` by
    splitting the input across the specified devices by chunking in the batch
    dimension (other objects will be copied once per device). In the forward
    pass, the module is replicated on each device, and each replica handles a
    portion of the input. During the backwards pass, gradients from each replica
    are summed into the original module.

    The batch size should be larger than the number of GPUs used.
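
    For example, with two devices a batch of 64 is split into two chunks of 32,
    one per replica (a minimal sketch; the layer sizes below are only
    illustrative)::

        >>> # xdoctest: +SKIP
        >>> model = torch.nn.Linear(10, 5).to("cuda:0")
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> x = torch.randn(64, 10)
        >>> out = net(x)   # each replica processes a (32, 10) chunk
        >>> out.shape      # gathered back on device_ids[0]
        torch.Size([64, 5])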

    .. warning::
        It is recommended to use :class:`~torch.nn.parallel.DistributedDataParallel`,
        instead of this class, to do multi-GPU training, even if there is only a single
        node. See: :ref:`cuda-nn-ddp-instead` and :ref:`ddp`.

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel, but some types are specially handled. Tensors will be
    **scattered** along the specified dim (default 0). Tuples, lists, and dicts
    will be shallow-copied. Other types will be shared among the different
    threads and can be corrupted if written to in the model's forward pass.
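
    For instance (an illustrative sketch; it assumes a hypothetical module whose
    ``forward`` accepts an extra ``meta`` keyword argument)::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> x = torch.randn(16, 3)   # scattered: each replica sees an (8, 3) chunk
        >>> meta = {"log": True}     # shallow-copied: every replica receives the dict
        >>> out = net(x, meta=meta)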

    The parallelized :attr:`module` must have its parameters and buffers on
    ``device_ids[0]`` before running this :class:`~torch.nn.DataParallel`
    module.

    .. warning::
        In each forward, :attr:`module` is **replicated** on each device, so any
        updates to the running module in ``forward`` will be lost. For example,
        if :attr:`module` has a counter attribute that is incremented in each
        ``forward``, it will always stay at the initial value because the update
        is done on the replicas which are destroyed after ``forward``. However,
        :class:`~torch.nn.DataParallel` guarantees that the replica on
        ``device[0]`` will have its parameters and buffers sharing storage with
        the base parallelized :attr:`module`. So **in-place** updates to the
        parameters or buffers on ``device[0]`` will be recorded. E.g.,
        :class:`~torch.nn.BatchNorm2d` and :func:`~torch.nn.utils.spectral_norm`
        rely on this behavior to update the buffers.
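
    A hedged sketch of the behavior described above, using a hypothetical
    ``calls`` counter attribute::

        >>> # xdoctest: +SKIP
        >>> class Counter(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.calls = 0
        ...     def forward(self, x):
        ...         self.calls += 1  # increments a replica's attribute only
        ...         return x
        >>> net = torch.nn.DataParallel(Counter().to("cuda:0"), device_ids=[0, 1])
        >>> _ = net(torch.randn(4, 2))
        >>> net.module.calls  # still 0: the replicas were discarded after forward
        0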

    .. warning::
        Forward and backward hooks defined on :attr:`module` and its submodules
        will be invoked ``len(device_ids)`` times, each with inputs located on
        a particular device. Particularly, the hooks are only guaranteed to be
        executed in correct order with respect to operations on corresponding
        devices. For example, it is not guaranteed that hooks set via
        :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
        `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
        that each such hook be executed before the corresponding
        :meth:`~torch.nn.Module.forward` call of that device.
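
    For example (a minimal sketch; ``model`` is a hypothetical module already on
    ``cuda:0`` that accepts the input below)::

        >>> # xdoctest: +SKIP
        >>> def pre_hook(module, inputs):
        ...     print(inputs[0].device)  # runs once per replica, on that replica's inputs
        >>> handle = model.register_forward_pre_hook(pre_hook)
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1])
        >>> _ = net(torch.randn(8, 10))  # the hook fires len(device_ids) times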

    .. warning::
        When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
        :func:`forward`, this wrapper will return a vector of length equal to
        the number of devices used in data parallelism, containing the result
        from each device.
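
    A minimal sketch of this, assuming a hypothetical ``loss_module`` that
    returns a 0-dimensional loss tensor::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(loss_module, device_ids=[0, 1])
        >>> out = net(inputs, targets)  # each replica returns a scalar
        >>> out.shape                   # gathered into a vector of length 2
        torch.Size([2])
        >>> out.mean().backward()       # reduce before calling backward()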

    .. note::
        There is a subtlety in using the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
        details.

    Args:
        module (Module): module to be parallelized
        device_ids (list of int or torch.device): CUDA devices (default: all devices)
        output_device (int or torch.device): device location of output (default: device_ids[0])

    Attributes:
        module (Module): the module to be parallelized

    Example::

        >>> # xdoctest: +SKIP
        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)  # input_var can be on any device, including CPU
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super().__init__()
        torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
        device_type = _get_available_device_type()
        if device_type is None:
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs, **kwargs):
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)

            for t in chain(self.module.parameters(), self.module.buffers()):
                if t.device != self.src_device_obj:
                    raise RuntimeError("module must have its parameters and buffers "
                                       "on device {} (device_ids[0]) but found one of "
                                       "them on device: {}".format(self.src_device_obj, t.device))

            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # If forward was called without any inputs, create an empty tuple and
            # dict so the module can still be executed on one device (the first
            # one in device_ids).
            if not inputs and not kwargs:
                inputs = ((),)
                kwargs = ({},)

            if len(self.device_ids) == 1:
                return self.module(*inputs[0], **kwargs[0])
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
            outputs = self.parallel_apply(replicas, inputs, kwargs)
            return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor): inputs to the module
        device_ids (list of int or torch.device): GPU ids on which to replicate module
        output_device (int or torch.device): GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])

    Returns:
        a Tensor containing the result of module(input) located on
        output_device
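
    Example (an illustrative sketch; ``model`` is assumed to be a module whose
    parameters and buffers already live on ``cuda:0``)::

        >>> # xdoctest: +SKIP
        >>> inp = torch.randn(32, 10)
        >>> out = torch.nn.parallel.data_parallel(
        ...     model, inp, device_ids=[0, 1], output_device=0)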
- """
- if not isinstance(inputs, tuple):
- inputs = (inputs,) if inputs is not None else ()
- device_type = _get_available_device_type()
- if device_ids is None:
- device_ids = _get_all_device_indices()
- if output_device is None:
- output_device = device_ids[0]
- device_ids = [_get_device_index(x, True) for x in device_ids]
- output_device = _get_device_index(output_device, True)
- src_device_obj = torch.device(device_type, device_ids[0])
- for t in chain(module.parameters(), module.buffers()):
- if t.device != src_device_obj:
- raise RuntimeError("module must have its parameters and buffers "
- "on device {} (device_ids[0]) but found one of "
- "them on device: {}".format(src_device_obj, t.device))
- inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
- # for module without any inputs, empty list and dict will be created
- # so the module can be executed on one device which is the first one in device_ids
- if not inputs and not module_kwargs:
- inputs = ((),)
- module_kwargs = ({},)
- if len(device_ids) == 1:
- return module(*inputs[0], **module_kwargs[0])
- used_device_ids = device_ids[:len(inputs)]
- replicas = replicate(module, used_device_ids)
- outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
- return gather(outputs, output_device, dim)