skippable.py 14 KB

  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 Kakao Brain
  3. #
  4. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  5. #
  6. # This source code is licensed under the BSD license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. """The user interface to define skip connections."""
  9. from typing import (
  10. TYPE_CHECKING,
  11. Any,
  12. Callable,
  13. ClassVar,
  14. Dict,
  15. FrozenSet,
  16. Generator,
  17. Iterable,
  18. List,
  19. Optional,
  20. Set,
  21. Sequence,
  22. Tuple,
  23. Type,
  24. TypeVar,
  25. Union,
  26. cast,
  27. )
  28. from torch import Tensor, nn
  29. from ..microbatch import Batch
  30. from .namespace import Namespace
  31. from .tracker import current_skip_tracker
__all__ = ["skippable", "stash", "pop", "verify_skippables"]


# Type aliases for tensors and for the generator protocol used by a
# skippable module's forward(): it yields stash/pop commands, receives an
# optional tensor back for each pop, and finally returns the real output.
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

StashPop = Union["stash", "pop"]
StashPopGenerator = Generator[StashPop, Optional[Tensor], TensorOrTensors]
if TYPE_CHECKING:
    # Typechecking: nn.Module is not a Generic
    SkippableModule = nn.Module[Union[StashPopGenerator, TensorOrTensors]]  # type: ignore[type-arg]
else:
    SkippableModule = nn.Module

# Bound type variable so Skippable.isolate() can return the subclass type
# fluently.
T = TypeVar("T", bound="Skippable")
class Skippable(nn.Module):
    """The base class for skippable modules.

    Do not use this class directly. Define a subclass by :func:`skippable`
    instead; the decorator fills in the three class variables below.
    """

    # Filled in by the @skippable decorator: the wrapped user module class
    # and the skip names this module declared for stashing/popping.
    module_cls: ClassVar[Type[SkippableModule]]
    stashable_names: ClassVar[FrozenSet[str]]
    poppable_names: ClassVar[FrozenSet[str]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Instantiates the wrapped module, forwarding all constructor
        arguments to it.
        """
        super().__init__()
        self.module = self.module_cls(*args, **kwargs)  # type: ignore[call-arg]
        # Maps a skip name to the namespace it was isolated into via
        # isolate(); names never isolated are simply absent.
        self.namespaces: Dict[str, Namespace] = {}

    def __repr__(self) -> str:
        return f"@skippable({self.module})"

    def namespaced(self, name: str) -> Tuple[Namespace, str]:
        """Prepends namespace for the given skip name."""
        ns = self.namespaces.get(name)
        # NOTE(review): a name that was never isolate()d maps to None here;
        # the cast only silences the type checker. Presumably downstream code
        # treats None as the "global" namespace — confirm against the tracker.
        ns = cast(Namespace, ns)
        return (ns, name)

    def stashable(self) -> Iterable[Tuple[Namespace, str]]:
        """Iterates over namespaced skip names to be stashed."""
        for name in self.stashable_names:
            yield self.namespaced(name)

    def poppable(self) -> Iterable[Tuple[Namespace, str]]:
        """Iterates over namespaced skip names to be popped."""
        for name in self.poppable_names:
            yield self.namespaced(name)

    def isolate(self: T, ns: Namespace, *, only: Optional[Iterable[str]] = None) -> T:
        r"""Isolates a specified subset or the whole set of skip tensors into a
        namespace. In a single sequential module, skip tensors with the same
        name are not allowed unless they are isolated by different namespaces.

        Here's an example using the same name for skip tensors twice. Each pair
        of ``Layer1`` and ``Layer2`` is isolated with its own namespace ``ns1``
        and ``ns2``. There is no conflict anymore::

            ns1 = Namespace()
            ns2 = Namespace()

            model = nn.Sequential(
                Layer1().isolate(ns1),
                Layer1().isolate(ns2),
                Layer2(),
                Layer3().isolate(ns2),
                Layer3().isolate(ns1),
            )

        When `only` parameter is omitted, all skip tensors are isolated. You
        can isolate a subset of skip tensors by passing `only` parameter::

            ns_alice = Namespace()
            ns_bob = Namespace()

            model = nn.Sequential(
                ...
                StashStashPop().isolate(ns_alice, only=['alice']) \
                               .isolate(ns_bob, only=['bob']),
                ...
            )

        Args:
            ns (Namespace):
                namespace for isolation

        Keyword Args:
            only (iterable of strs):
                names of specific skip tensors to be isolated (omit this option
                to isolate all skip tensors declared in this module)

        Returns:
            this module itself
        """
        names: Iterable[str]

        if only is None:
            # No subset given: isolate every declared skip name.
            names = self.stashable_names | self.poppable_names
        else:
            names = set(only)

        for name in names:
            self.namespaces[name] = ns

        return self

    def dispatch(
        self,
        input,
        handle_stash: Callable[[str, Optional[Tensor]], None],
        handle_pop: Callable[[str], Optional[Tensor]],
    ):
        """Dispatches :class:`stash` or :class:`pop` commands generated by the
        module's ``forward()``.

        If the wrapped ``forward()`` is a plain function, its return value is
        passed through untouched. If it is a generator, each yielded
        :class:`stash`/:class:`pop` command is routed to the corresponding
        handler until the generator returns.
        """
        generator = self.module(input)

        if not isinstance(generator, Generator):
            # The underlying module returned output without any yield.
            output = generator
            return output

        try:
            op = next(generator)

            while True:
                if isinstance(op, stash):
                    handle_stash(op.name, op.tensor)
                    op = next(generator)
                    continue

                if isinstance(op, pop):
                    # send() delivers the popped tensor as the value of the
                    # ``yield pop(...)`` expression in the user's forward().
                    tensor = handle_pop(op.name)
                    op = generator.send(tensor)
                    continue

                raise TypeError("%r is not a command from @skippable" % op)

        except StopIteration as stop:
            # A generator's ``return x`` surfaces as StopIteration(x);
            # args[0] is that return value.
            output = stop.args[0]
            return output

    def forward(self, input: Union[List[Any], Tensor]) -> TensorOrTensors:
        """Performs the forward propagation. :class:`stash` or :class:`pop`
        commands will be handled by portals silently. The portals won't be
        exposed to users.

        Raises:
            RuntimeError:
                illegal 'stash' or 'pop' is found.
        """
        skip_tracker = current_skip_tracker()
        stashed_tensors: Dict[str, Optional[Tensor]] = {}

        # Load skip tensors that might be popped.
        poppable_tensors = {}
        batch = Batch(input)
        for ns, name in self.poppable():
            try:
                poppable_tensors[name] = skip_tracker.load(batch, ns, name)
            except KeyError as e:
                raise RuntimeError(f"'{name}' has not been stashed") from e
        input = batch.values

        # Handle skip commands: validate each against the declared names.
        def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
            if name not in self.stashable_names:
                raise RuntimeError(f"'{name}' has not been declared as stashable")
            stashed_tensors[name] = tensor

        def handle_pop(name: str) -> Optional[Tensor]:
            if name not in self.poppable_names:
                raise RuntimeError(f"'{name}' has not been declared as poppable")
            # pop() removes the entry so leftovers can be detected below.
            return poppable_tensors.pop(name)

        output = self.dispatch(input, handle_stash, handle_pop)

        # All declared skips must be stashed or popped.
        not_stashed = self.stashable_names - stashed_tensors.keys()
        if not_stashed:
            comma_names = ", ".join("'%s'" % n for n in not_stashed)
            raise RuntimeError(f"{comma_names} must be stashed but have not")

        not_popped = poppable_tensors.keys()
        if not_popped:
            comma_names = ", ".join("'%s'" % n for n in not_popped)
            raise RuntimeError(f"{comma_names} must be popped but have not")

        # Save stashed skip tensors to the tracker.
        batch = Batch(output)
        for ns, name in self.stashable():
            tensor = stashed_tensors[name]
            skip_tracker.save(batch, ns, name, tensor)
        output = batch.values

        return output
  188. # TODO(sublee): Move to above of Skippable class for better read flow.
  189. def skippable(
  190. stash: Iterable[str] = (), pop: Iterable[str] = (),
  191. ) -> Callable[[Type[SkippableModule]], Type[Skippable]]:
  192. """The decorator to define a :class:`nn.Module <torch.nn.Module>` with skip
  193. connections. Decorated modules are called "skippable". This functionality
  194. works perfectly fine even when the module is not wrapped by
  195. :class:`~torch.distributed.pipeline.sync.Pipe`.
  196. Each skip tensor is managed by its name. Before manipulating skip tensors,
  197. a skippable module must statically declare the names for skip tensors by
  198. `stash` and/or `pop` parameters. Skip tensors with pre-declared name can be
  199. stashed by ``yield stash(name, tensor)`` or popped by ``tensor = yield
  200. pop(name)``.
  201. Here is an example with three layers. A skip tensor named "1to3" is stashed
  202. and popped at the first and last layer, respectively::
  203. @skippable(stash=['1to3'])
  204. class Layer1(nn.Module):
  205. def forward(self, input):
  206. yield stash('1to3', input)
  207. return f1(input)
  208. class Layer2(nn.Module):
  209. def forward(self, input):
  210. return f2(input)
  211. @skippable(pop=['1to3'])
  212. class Layer3(nn.Module):
  213. def forward(self, input):
  214. skip_1to3 = yield pop('1to3')
  215. return f3(input) + skip_1to3
  216. model = nn.Sequential(Layer1(), Layer2(), Layer3())
  217. One skippable module can stash or pop multiple skip tensors::
  218. @skippable(stash=['alice', 'bob'], pop=['carol'])
  219. class StashStashPop(nn.Module):
  220. def forward(self, input):
  221. yield stash('alice', f_alice(input))
  222. yield stash('bob', f_bob(input))
  223. carol = yield pop('carol')
  224. return input + carol
  225. Every skip tensor must be associated with exactly one pair of `stash` and
  226. `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this
  227. restriction automatically when wrapping a module. You can also check the
  228. restriction by :func:`verify_skippables`
  229. without :class:`~torch.distributed.pipeline.sync.Pipe`.
  230. """
  231. stashable_names = frozenset(stash)
  232. poppable_names = frozenset(pop)
  233. def extend_skippable(module_cls: Type[SkippableModule]) -> Type[Skippable]:
  234. name = module_cls.__name__
  235. bases = (Skippable,)
  236. attrs = {"module_cls": module_cls, "stashable_names": stashable_names, "poppable_names": poppable_names}
  237. return type(name, bases, attrs)
  238. return extend_skippable
  239. class stash:
  240. """The command to stash a skip tensor.
  241. ::
  242. def forward(self, input):
  243. yield stash('name', input)
  244. return f(input)
  245. Args:
  246. name (str): name of skip tensor
  247. input (torch.Tensor or None): tensor to pass to the skip connection
  248. """
  249. __slots__ = ("name", "tensor")
  250. def __init__(self, name: str, tensor: Optional[Tensor]) -> None:
  251. self.name = name
  252. self.tensor = tensor
  253. class pop:
  254. """The command to pop a skip tensor.
  255. ::
  256. def forward(self, input):
  257. skip = yield pop('name')
  258. return f(input) + skip
  259. Args:
  260. name (str): name of skip tensor
  261. Returns:
  262. the skip tensor previously stashed by another layer under the same name
  263. """
  264. __slots__ = ("name",)
  265. def __init__(self, name: str) -> None:
  266. self.name = name
  267. def verify_skippables(module: nn.Sequential) -> None:
  268. """Verifies if the underlying skippable modules satisfy integrity.
  269. Every skip tensor must have only one pair of `stash` and `pop`. If there
  270. are one or more unmatched pairs, it will raise :exc:`TypeError` with the
  271. detailed messages.
  272. Here are a few failure cases. :func:`verify_skippables` will report failure
  273. for these cases::
  274. # Layer1 stashes "1to3".
  275. # Layer3 pops "1to3".
  276. nn.Sequential(Layer1(), Layer2())
  277. # └──── ?
  278. nn.Sequential(Layer2(), Layer3())
  279. # ? ────┘
  280. nn.Sequential(Layer1(), Layer2(), Layer3(), Layer3())
  281. # └───────────────────┘ ^^^^^^
  282. nn.Sequential(Layer1(), Layer1(), Layer2(), Layer3())
  283. # ^^^^^^ └───────────────────┘
  284. To use the same name for multiple skip tensors, they must be isolated by
  285. different namespaces. See :meth:`isolate()
  286. <torchpipe.skip.skippable.Skippable.isolate>`.
  287. Raises:
  288. TypeError:
  289. one or more pairs of `stash` and `pop` are not matched.
  290. """
  291. stashed: Set[Tuple[Namespace, str]] = set()
  292. popped: Set[Tuple[Namespace, str]] = set()
  293. msgs: List[str] = []
  294. for layer_name, layer in module.named_children():
  295. if not isinstance(layer, Skippable):
  296. continue
  297. for name in layer.stashable_names & layer.poppable_names:
  298. msg = f"'{layer_name}' declared '{name}' both as stashable and as poppable"
  299. msgs.append(msg)
  300. for ns, name in layer.stashable():
  301. if name in layer.poppable_names:
  302. continue
  303. if (ns, name) in stashed:
  304. msg = f"'{layer_name}' redeclared '{name}' as stashable " "but not isolated by namespace"
  305. msgs.append(msg)
  306. continue
  307. stashed.add((ns, name))
  308. for ns, name in layer.poppable():
  309. if name in layer.stashable_names:
  310. continue
  311. if (ns, name) in popped:
  312. msg = f"'{layer_name}' redeclared '{name}' as poppable " "but not isolated by namespace"
  313. msgs.append(msg)
  314. continue
  315. if (ns, name) not in stashed:
  316. msg = f"'{layer_name}' declared '{name}' as poppable but it was not stashed"
  317. msgs.append(msg)
  318. continue
  319. popped.add((ns, name))
  320. for (_, name) in stashed - popped:
  321. msg = f"no module declared '{name}' as poppable but stashed"
  322. msgs.append(msg)
  323. if msgs:
  324. raise TypeError(
  325. "one or more pairs of stash and pop do not match:\n\n%s" "" % "\n".join("* %s" % x for x in msgs)
  326. )