import warnings

import torch

from ._functions import Scatter, Gather

__all__ = ['scatter', 'scatter_kwargs', 'gather']

def is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
    return _is_namedtuple(obj)

def _is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )

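# Illustrative sketch (not part of the original module): _is_namedtuple relies
# on the duck-typed attributes that both collections.namedtuple classes and
# typing.NamedTuple subclasses carry, so a plain tuple is rejected:
#
#     Point = collections.namedtuple('Point', ['x', 'y'])
#     _is_namedtuple(Point(1, 2))  # True
#     _is_namedtuple((1, 2))       # False: no _asdict / _fields attributes
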
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res

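# Usage sketch (illustrative, not part of the module): scatter_map walks the
# container recursively, so the per-device results mirror the input structure.
# Assumes at least two CUDA devices are available.
#
#     batch = {'x': torch.randn(8, 3), 'step': 7}
#     shards = scatter(batch, target_gpus=[0, 1], dim=0)
#     # shards[0]['x'] is rows 0..3 on cuda:0 and shards[1]['x'] is rows 4..7
#     # on cuda:1, while the non-tensor 'step' is duplicated into both dicts.
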
def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

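# Illustrative sketch: scatter_kwargs keeps positional and keyword arguments
# aligned per device, padding the shorter side with empty containers so that
# replica i can be invoked as module(*inputs[i], **kwargs[i]). Assumes two
# CUDA devices; the tensor shapes here are made up for the example.
#
#     inputs, kwargs = scatter_kwargs(
#         (torch.randn(4, 2),), {'mask': torch.ones(4, 1)}, target_gpus=[0, 1])
#     # len(inputs) == len(kwargs) == 2, one entry per device.
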
def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device.
    Use 'cpu' for CPU to avoid a deprecation warning.
    """

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(
                (k, gather_map([d[k] for d in outputs])) for k in out
            )
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res

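# Round-trip sketch (illustrative): gather is the inverse of scatter. It walks
# the same container structure and concatenates tensor leaves along `dim` on
# the target device. Assumes two CUDA devices, as above.
#
#     x = torch.randn(8, 3)
#     shards = scatter(x, target_gpus=[0, 1], dim=0)   # two chunks of 4 rows
#     y = gather(shards, target_device='cpu', dim=0)   # per the docstring above
#     assert torch.equal(x, y)                         # chunk order is preserved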