common_dist_composable.py

# Owner(s): ["oncall: distributed"]

from typing import Tuple

import torch
import torch.nn as nn

class UnitModule(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.seq(self.l1(x)))

class CompositeModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l1 = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.l2 = nn.Linear(100, 100, device=device)

    def forward(self, x):
        return self.l2(self.u2(self.u1(self.l1(x))))

class UnitParamModule(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.seq = nn.Sequential(
            nn.ReLU(),
            nn.Linear(100, 100, device=device),
            nn.ReLU(),
        )
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        return torch.mm(self.seq(self.l(x)), self.p)

class CompositeParamModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.l = nn.Linear(100, 100, device=device)
        self.u1 = UnitModule(device)
        self.u2 = UnitModule(device)
        self.p = nn.Parameter(torch.randn((100, 100), device=device))

    def forward(self, x):
        a = self.u2(self.u1(self.l(x)))
        b = self.p
        return torch.mm(a, b)
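
# The illustrative helper below is not part of the original file; it is a
# minimal sketch showing the expected tensor shapes when the composite models
# above are run eagerly on CPU. The function name `_example_composite_forward`
# is arbitrary and chosen only for documentation purposes.
def _example_composite_forward() -> None:
    device = torch.device("cpu")
    x = torch.randn(8, 100, device=device)
    # Both composite models map (N, 100) inputs to (N, 100) outputs.
    assert CompositeModel(device)(x).shape == (8, 100)
    assert CompositeParamModel(device)(x).shape == (8, 100)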

class FakeSequential(nn.Module):
    # Define this class to achieve a desired nested wrapping using the module
    # wrap policy with `nn.Sequential`
    def __init__(self, *modules: Tuple[nn.Module, ...]) -> None:
        super().__init__()
        self._module_sequence = list(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for module in self._module_sequence:
            x = module(x)
        return x
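
# Illustrative sketch, not part of the original file: `FakeSequential` chains
# its arguments like `nn.Sequential` in `forward`, but it is not an
# `nn.Sequential` subclass and it stores the passed-in modules in a plain
# Python list, so they are not registered as child modules. A module wrap
# policy keyed on `nn.Sequential` therefore would not treat `FakeSequential`
# as a wrapping point. The function name below is arbitrary.
def _example_fake_sequential_behavior() -> None:
    device = torch.device("cpu")
    fake = FakeSequential(
        nn.Linear(4, 4, device=device),
        nn.ReLU(),
    )
    x = torch.randn(2, 4, device=device)
    assert fake(x).shape == (2, 4)  # forwards like a Sequential
    assert not isinstance(fake, nn.Sequential)  # but is not one
    assert len(list(fake.children())) == 0  # inner modules are unregistered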

class NestedSequentialModel(nn.Module):
    def __init__(self, device: torch.device) -> None:
        super().__init__()
        # This nested structure exercises traversal order to catch differences
        # between valid traversals (e.g. BFS and DFS variations).
        self.seq1 = nn.Sequential(
            nn.Linear(1, 1, device=device),
            FakeSequential(
                nn.Linear(1, 1, device=device),
                nn.ReLU(),
                FakeSequential(
                    nn.Linear(1, 1, device=device),
                ),
                nn.ReLU(),
            ),
            nn.Linear(1, 2, device=device),
        )
        self.lin = nn.Linear(2, 2, device=device)
        self.seq2 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(2, 3, device=device),
            FakeSequential(
                nn.Linear(3, 2, bias=False, device=device),
                nn.Linear(2, 4, bias=False, device=device),
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq2(self.lin(self.seq1(x)))
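
# The block below is an added illustrative smoke test, not part of the
# original test helper. It assumes a CPU-only, single-process run and only
# checks eager-mode shapes plus which submodules an `nn.Sequential`-keyed
# module wrap policy would match: the two registered `nn.Sequential`
# instances (`seq1` and `seq2`), but none of the `FakeSequential` children.
if __name__ == "__main__":
    device = torch.device("cpu")
    model = NestedSequentialModel(device)
    out = model(torch.randn(8, 1, device=device))
    assert out.shape == (8, 4)
    sequential_names = [
        name
        for name, module in model.named_modules()
        if isinstance(module, nn.Sequential)
    ]
    assert sequential_names == ["seq1", "seq2"]
    print("NestedSequentialModel output shape:", tuple(out.shape))
    print("nn.Sequential submodules:", sequential_names)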