12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- import threading
- import torch
- from torch.cuda._utils import _get_device_index
- from torch.cuda.amp import autocast
- from torch._utils import ExceptionWrapper
def get_a_var(obj):
    """Find a tensor inside ``obj`` by depth-first search.

    ``obj`` itself is returned when it is a ``torch.Tensor``. Lists and tuples
    are searched element by element; dicts are searched over their
    ``(key, value)`` pairs, so both keys and values are inspected. Any other
    object contains nothing to search.

    Returns:
        The first ``torch.Tensor`` encountered, or ``None`` if there is none.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        # Each item is a 2-tuple, so the recursive call looks at the key first,
        # then the value.
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (sequence of dict, optional): keyword arguments for each
            module; defaults to an empty dict per module
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        list: the output of each module, in the same order as :attr:`modules`.
        An empty list is returned when :attr:`modules` is empty.

    Raises:
        AssertionError: if the argument sequences differ in length.
        If a replica raises, the exception of the lowest-indexed failing
        replica is re-raised in the calling thread via
        ``ExceptionWrapper.reraise``, with a ``where`` note naming the replica
        index and device.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    if not modules:
        # Nothing to run; the single-replica path below would otherwise index
        # modules[0] and raise IndexError.
        return []

    # Normalize each device to an index (with optional=True a None entry is
    # presumably resolved to the current CUDA device -- see
    # torch.cuda._utils._get_device_index) and capture that device's current
    # stream so each worker runs on the caller's stream.
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    # Grad mode and autocast state are thread-local, so capture them here and
    # re-apply them inside each worker.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(i, module, input, kwargs, device=None, stream=None):
        # Everything runs inside the try so that any failure, including
        # device/stream resolution, is recorded in results[i]. Otherwise the
        # thread would die silently and the caller would hit a KeyError.
        try:
            torch.set_grad_enabled(grad_enabled)
            if device is None:
                var = get_a_var(input)
                if var is None:
                    raise RuntimeError(
                        "no device was provided and no tensor was found in "
                        "the input; device cannot be resolved")
                device = var.get_device()
            if stream is None:
                stream = torch.cuda.current_stream(device)
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
                with lock:
                    results[i] = output
        except Exception:
            # ExceptionWrapper captures the active exception so it can be
            # re-raised later in the calling thread.
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # A single replica runs inline in the calling thread; no worker
        # thread is spawned.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(modules)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            # Surface the replica's failure in the caller's thread.
            output.reraise()
        outputs.append(output)
    return outputs
|