dependency.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright 2019 Kakao Brain
  2. #
  3. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
  4. #
  5. # This source code is licensed under the BSD license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. """Arbitrary dependency between two autograd lanes."""
  8. from typing import List, Tuple
  9. import torch
  10. from torch import Tensor
  11. from .phony import get_phony
  12. __all__: List[str] = ["fork", "Fork", "join", "Join"]
  13. def fork(input: Tensor) -> Tuple[Tensor, Tensor]:
  14. """Branches out from an autograd lane of the given tensor."""
  15. if torch.is_grad_enabled() and input.requires_grad:
  16. input, phony = Fork.apply(input)
  17. else:
  18. phony = get_phony(input.device, requires_grad=False)
  19. return input, phony
  20. class Fork(torch.autograd.Function):
  21. @staticmethod
  22. def forward(ctx: "Fork", input: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore[override]
  23. phony = get_phony(input.device, requires_grad=False)
  24. return input.detach(), phony.detach()
  25. @staticmethod
  26. def backward(ctx: "Fork", grad_input: Tensor, grad_grad: Tensor) -> Tensor: # type: ignore[override]
  27. return grad_input
  28. def join(input: Tensor, phony: Tensor) -> Tensor:
  29. """Merges two autograd lanes."""
  30. if torch.is_grad_enabled() and (input.requires_grad or phony.requires_grad):
  31. input = Join.apply(input, phony)
  32. return input
  33. class Join(torch.autograd.Function):
  34. @staticmethod
  35. def forward(ctx: "Join", input: Tensor, phony: Tensor) -> Tensor: # type: ignore[override]
  36. return input.detach()
  37. @staticmethod
  38. def backward(ctx: "Join", grad_input: Tensor) -> Tuple[Tensor, None]: # type: ignore[override]
  39. return grad_input, None