# replicate.py

from . import comm
from torch._utils import _get_device_index

from collections import OrderedDict


def _is_script_module(module):
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module):
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module():
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled():
    import torch.jit
    return torch.jit._state._enabled


# Check if we can safely replicate the module.
# There are two types of modules:
# 1. Python modules
# 2. ScriptModule
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a Python module (type 1 above).
def _replicatable_module(module, memo=None):

    # module.modules() yields the module itself as the first element
    def descendant_modules(module):
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
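
# Illustration (added; not part of the upstream source): the check above can be
# called directly on a module tree. With the JIT enabled, a plain Python tree
# and a fully scripted module both pass; the rejected shape is a plain Python
# nn.Module sitting below a ScriptModule, which recursive scripting does not
# normally produce. A minimal sketch, assuming torch is importable:
#
#     import torch
#     import torch.nn as nn
#
#     plain = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
#     _replicatable_module(plain)       # True: no ScriptModule involved
#
#     scripted = torch.jit.script(nn.Linear(2, 2))
#     _replicatable_module(scripted)    # True: every descendant is a ScriptModule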


def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function to broadcast when not detaching, so the
        # copies stay connected to the originals in the autograd graph.
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []
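
# Shape note (added; a sketch assuming at least two visible CUDA devices): the
# helper returns one sequence per device, each holding copies of the input
# tensors in their original order, i.e. result[j][k] is tensors[k] placed on
# devices[j].
#
#     import torch
#     ts = [torch.randn(3, device="cuda:0"), torch.randn(2, 2, device="cuda:0")]
#     copies = _broadcast_coalesced_reshape(ts, [0, 1], detach=True)
#     assert copies[1][0].device == torch.device("cuda:1")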


def replicate(network, devices, detach=False):
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    # Broadcast every parameter of the original network to all target devices.
    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Buffers that require grad are broadcast through autograd (unless
    # detaching); the rest are broadcast detached.
    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    # Create one bare replica of every module in the tree for every device.
    modules = list(network.modules())
    module_copies = [[] for device in devices]
    module_indices = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)

    # Wire each replica back together: reattach child modules, parameters and
    # buffers so every replica mirrors the structure of the original network.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    # modules() yields the network itself first, so element 0 of each
    # per-device list is the root module of that replica.
    return [module_copies[j][0] for j in range(num_replicas)]
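
# Usage sketch (added; not part of the upstream source): `replicate` is the
# primitive torch.nn.DataParallel uses to clone a module onto several GPUs
# before scattering its input. A minimal illustration, assuming two visible
# CUDA devices; `net` is a hypothetical example model, not defined in this file:
#
#     import torch
#     import torch.nn as nn
#     from torch.nn.parallel import replicate
#
#     net = nn.Linear(4, 2).cuda(0)
#     replicas = replicate(net, [0, 1])   # one replica per device index
#     out0 = replicas[0](torch.randn(8, 4, device="cuda:0"))
#     out1 = replicas[1](torch.randn(8, 4, device="cuda:1"))
#
#     # With detach=False (the default) the replicas stay connected to `net`
#     # in the autograd graph, so gradients flow back to the original
#     # parameters; detach=True produces detached copies instead.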