phony.py

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
  7. """Provides phony for arbitrary dependency in a autograd graph."""
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from .stream import default_stream, use_stream

__all__: List[str] = ["get_phony"]

_phonies: Dict[Tuple[torch.device, bool], Tensor] = {}


def get_phony(device: torch.device, *, requires_grad: bool) -> Tensor:
  15. """Gets a phony. Phony is tensor without space. It is useful to make
  16. arbitrary dependency in a autograd graph because it doesn't require any
  17. gradient accumulation.
  18. .. note::
  19. Phonies for each device are cached. If an autograd function gets a phony
  20. internally, the phony must be detached to be returned. Otherwise, the
  21. autograd engine will mutate the cached phony in-place::
  22. class Phonify(torch.autograd.Function):
  23. @staticmethod
  24. def forward(ctx, input):
  25. phony = get_phony(input.device, requires_grad=False)
  26. return phony.detach() # detach() is necessary.
  27. """
    key = (device, requires_grad)

    try:
        phony = _phonies[key]
    except KeyError:
        # Create the zero-sized phony on the device's default stream and
        # cache it per (device, requires_grad) key.
        with use_stream(default_stream(device)):
            phony = torch.empty(0, device=device, requires_grad=requires_grad)

        _phonies[key] = phony

    return phony
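

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): ``Depend`` below is a
# hypothetical autograd function showing how a phony is typically consumed.
# Passing the phony as an extra input splices an artificial edge into the
# autograd graph, so the output depends on whatever produced the phony while
# no data is copied and no gradient accumulates into the zero-sized tensor.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class Depend(torch.autograd.Function):
        @staticmethod
        def forward(ctx, phony: Tensor, tensor: Tensor) -> Tensor:
            # autograd re-wraps the returned tensor, so its history now also
            # includes the phony's producers.
            return tensor.detach()

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            # No gradient flows back into the zero-sized phony.
            return None, grad_output

    phony = get_phony(torch.device("cpu"), requires_grad=True)
    x = torch.randn(3, requires_grad=True)
    y = Depend.apply(phony, x)  # y now depends on both x and the phony
    y.sum().backward()
    assert x.grad is not None  # gradient reached x through the artificial edge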