layout.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright 2019 Kakao Brain
  2. #
  3. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  4. #
  5. # This source code is licensed under the BSD license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. """Static skip connection layout of ``@skippable`` modules."""
  8. from typing import Dict, Iterable, List, Tuple
  9. from torch import nn
  10. from .namespace import Namespace
  11. __all__: List[str] = []
  12. class SkipLayout:
  13. """Represents a skip connection layout across partitions."""
  14. # Skip routes indexed by 'ns, name': {(ns, name): (prev_j, next_j), ...}
  15. by_ns_name: Dict[Tuple[Namespace, str], Tuple[int, int]]
  16. # Skip routes indexed by partition number 'j': [[next_j]: [(prev_j, ns, name), ...], ...]
  17. by_partition: List[List[Tuple[int, Namespace, str]]]
  18. def __init__(self, num_partitions: int, skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]],) -> None:
  19. # The skip routes are already indexed by 'ns, name'.
  20. self.by_ns_name = skip_routes
  21. # Index skip routes by partition number 'j'.
  22. self.by_partition = [[] for _ in range(num_partitions)]
  23. for (ns, name), (prev_j, next_j) in skip_routes.items():
  24. self.by_partition[next_j].append((prev_j, ns, name))
  25. for p in self.by_partition:
  26. p.sort()
  27. def copy_policy(self, next_j: int) -> Iterable[Tuple[int, Namespace, str]]:
  28. """Generates skip routes for the given destination partition number.
  29. The skip routes are sorted by source partition number in ascending
  30. order.
  31. Yields:
  32. Each tuple of (source partition number, namespace, name).
  33. """
  34. for prev_j, ns, name in self.by_partition[next_j]:
  35. if prev_j == next_j:
  36. # This skip tensor will be popped at the same partition where
  37. # it is stashed. In this case, copy is not required.
  38. continue
  39. yield (prev_j, ns, name)
  40. def requires_copy(self, ns: Namespace, name: str) -> bool:
  41. """Whether the given namespace and name requires partition-to-partition
  42. copy or not.
  43. """
  44. prev_j, next_j = self.by_ns_name.get((ns, name), (-1, -1))
  45. return prev_j != next_j
  46. def inspect_skip_layout(partitions: List[nn.Sequential]) -> SkipLayout:
  47. """Inspects the skip connection layout in the given partitions."""
  48. # NOTE(sublee): Hide circular import inside this subroutine. Circular
  49. # import is not ideal but placing this logic near to SkipLayout may
  50. # increase cohesion of code.
  51. from .skippable import Skippable
  52. skip_routes: Dict[Tuple[Namespace, str], Tuple[int, int]] = {}
  53. stashed_at: Dict[Tuple[Namespace, str], int] = {}
  54. for j, partition in enumerate(partitions):
  55. def inspect_layer(layer):
  56. if not isinstance(layer, Skippable):
  57. return
  58. for ns, name in layer.stashable():
  59. stashed_at[(ns, name)] = j
  60. for ns, name in layer.poppable():
  61. prev_j = stashed_at.pop((ns, name))
  62. skip_routes[(ns, name)] = (prev_j, j)
  63. if isinstance(partition, nn.Sequential):
  64. for layer in partition:
  65. inspect_layer(layer)
  66. else:
  67. inspect_layer(partition)
  68. return SkipLayout(len(partitions), skip_routes)