utils.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
from typing import List, Optional

from torch import nn
  3. __all__ = ["partition_model"]
  4. def partition_model(
  5. module: nn.Sequential,
  6. balance: List[int],
  7. devices: List[int] = None):
  8. """
  9. Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
  10. the model across multiple GPU devices according the provided ``balance``
  11. and ``devices``.
  12. Args:
  13. module (:class:`nn.Sequential <torch.nn.Sequential>`):
  14. Sequential model representing the pipe.
  15. balance (List[int]):
  16. List indicating the number of layers in each partition.
  17. devices (List[int], optional):
  18. List indicating the device to use for each partition. Defaults to
  19. ``range(len(balance))``
  20. """
  21. device_idx = 0
  22. pipe_idx = 0
  23. balanced_pipe = []
  24. for num_layers in balance:
  25. layers = []
  26. for i in range(num_layers):
  27. layers.append(module[pipe_idx])
  28. pipe_idx += 1
  29. device = device_idx if devices is None else devices[device_idx]
  30. balanced_pipe.append(nn.Sequential(*layers).to(device))
  31. device_idx += 1
  32. return nn.Sequential(*balanced_pipe)