# parallel_apply.py
  1. import threading
  2. import torch
  3. from torch.cuda._utils import _get_device_index
  4. from torch.cuda.amp import autocast
  5. from torch._utils import ExceptionWrapper
  6. def get_a_var(obj):
  7. if isinstance(obj, torch.Tensor):
  8. return obj
  9. if isinstance(obj, (list, tuple)):
  10. for result in map(get_a_var, obj):
  11. if isinstance(result, torch.Tensor):
  12. return result
  13. if isinstance(obj, dict):
  14. for result in map(get_a_var, obj.items()):
  15. if isinstance(result, torch.Tensor):
  16. return result
  17. return None
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        # No keyword args supplied: use an empty kwargs dict per module.
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        # Device unknown for every replica; each worker will infer it from
        # the first tensor found in its input (see _worker below).
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    # Protects concurrent writes to `results` from the worker threads.
    lock = threading.Lock()
    results = {}
    # Grad mode and autocast state are thread-local in PyTorch, so freshly
    # spawned worker threads would not inherit them from this (caller)
    # thread — capture them here and re-apply inside each worker.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(i, module, input, kwargs, device=None, stream=None):
        # Re-apply the caller thread's grad mode (see capture above).
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            # Infer the replica's device from the first tensor in its input.
            device = get_a_var(input).get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Capture the exception (with traceback) so it can be re-raised
            # in the caller thread by the gather loop below.
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single replica: run inline to avoid thread start/join overhead.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    # Gather outputs in replica order; re-raise any worker exception here,
    # in the caller's thread, with the original traceback attached.
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs