import torch
from ._functions import Scatter, Gather
import warnings

__all__ = ['scatter', 'scatter_kwargs', 'gather']


def is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
    return _is_namedtuple(obj)


def _is_namedtuple(obj):
    # Check if type was created from collections.namedtuple or a typing.NamedTuple.
    return (
        isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
    )


def scatter(inputs, target_gpus, dim=0):
    r"""Slices tensors into approximately equal chunks and distributes them across given GPUs.

    Duplicates references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend(() for _ in range(len(kwargs) - len(inputs)))
    elif len(kwargs) < len(inputs):
        kwargs.extend({} for _ in range(len(inputs) - len(kwargs)))
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


def gather(outputs, target_device, dim=0):
    r"""Gathers tensors from different GPUs on a specified device.

    Use 'cpu' for CPU to avoid a deprecation warning.
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
        if _is_namedtuple(out):
            return type(out)._make(map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res
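
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public API).
# It shows scatter_kwargs chunking a tensor across two devices while
# replicating the non-tensor argument and the kwargs dict, then gather
# reassembling the per-device results. The device ids [0, 1] and the
# two-GPU requirement are assumptions for this demo, and the inline
# arithmetic stands in for a real per-replica module call.
if __name__ == "__main__":
    if torch.cuda.device_count() >= 2:
        x = torch.arange(8.0).reshape(4, 2)
        # x is split along dim 0 and each chunk moved to its device; the
        # string and the kwargs dict are duplicated to every replica.
        inputs, kwargs = scatter_kwargs(
            (x, "metadata"), {"scale": 2.0}, target_gpus=[0, 1], dim=0
        )
        outputs = [
            inp[0] * kw["scale"]  # stand-in for module(*inp, **kw)
            for inp, kw in zip(inputs, kwargs)
        ]
        # gather concatenates the per-device chunks on cuda:0 along dim 0.
        result = gather(outputs, target_device=0, dim=0)
        print(result.shape)  # torch.Size([4, 2])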