# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for eliminating boilerplate code to handle abstract streams with
the CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast

import torch

__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream",
                      "use_device", "use_stream", "get_device", "wait_stream", "record_stream",
                      "is_cuda", "as_cuda"]


class CPUStreamType:
    pass


# The placeholder used in place of CUDA streams for the CPU device.
CPUStream = CPUStreamType()

# It represents both CUDA streams and the CPU stream.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]


def new_stream(device: torch.device) -> AbstractStream:
    """Creates a new stream for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.Stream(device)


def current_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.current_stream` for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.current_stream(device)


def default_stream(device: torch.device) -> AbstractStream:
    """:func:`torch.cuda.default_stream` for either CPU or CUDA device."""
    if device.type != "cuda":
        return CPUStream
    return torch.cuda.default_stream(device)
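
# Illustrative usage, not part of the original module: the three factories
# above fall back to the CPUStream placeholder for CPU devices and return real
# torch.cuda.Stream objects for CUDA devices (assuming CUDA is available).
#
#     cpu = torch.device("cpu")
#     assert new_stream(cpu) is CPUStream
#     assert current_stream(cpu) is CPUStream and default_stream(cpu) is CPUStream
#
#     if torch.cuda.is_available():
#         gpu = torch.device("cuda", 0)
#         s = new_stream(gpu)        # a fresh torch.cuda.Stream on gpu
#         d = default_stream(gpu)    # the default stream of gpu
#         assert is_cuda(s) and is_cuda(d)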


@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
    """:func:`torch.cuda.device` for either CPU or CUDA device."""
    if device.type != "cuda":
        yield
        return

    with torch.cuda.device(device):
        yield


@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
    """:func:`torch.cuda.stream` for either CPU or CUDA stream."""
    if not is_cuda(stream):
        yield
        return

    with torch.cuda.stream(as_cuda(stream)):
        yield
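
# Illustrative usage, a minimal sketch assuming a CUDA device is available:
# use_device and use_stream compose to queue work on a non-default stream.
#
#     if torch.cuda.is_available():
#         device = torch.device("cuda", 0)
#         stream = new_stream(device)
#         with use_device(device), use_stream(stream):
#             # Ops launched here are queued on `stream` rather than on the
#             # default stream of `device`.
#             y = torch.ones(4, device=device) * 2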


def get_device(stream: AbstractStream) -> torch.device:
    """Gets the device from a CPU or CUDA stream."""
    if is_cuda(stream):
        return as_cuda(stream).device
    return torch.device("cpu")


def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """:meth:`torch.cuda.Stream.wait_stream` for either CPU or CUDA stream. It
    makes the source stream wait until the target stream completes the work
    already queued on it.
    """
    if is_cuda(target):
        if is_cuda(source):
            # A CUDA stream waits for another CUDA stream.
            as_cuda(source).wait_stream(as_cuda(target))
        else:
            # The CPU waits for a CUDA stream.
            as_cuda(target).synchronize()

    # If the target is the CPU, no synchronization is required.
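
# Illustrative usage, a minimal sketch assuming a CUDA device is available:
# the consumer stream waits for a copy queued on a separate copy stream.
#
#     if torch.cuda.is_available():
#         device = torch.device("cuda", 0)
#         copy_stream = new_stream(device)
#         host = torch.empty(1024, pin_memory=True)
#         with use_stream(copy_stream):
#             staged = host.to(device, non_blocking=True)
#         # The current (consumer) stream must not read `staged` before the
#         # copy stream has finished producing it.
#         wait_stream(current_stream(device), copy_stream)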


def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream."""
    if is_cuda(stream):
        # NOTE(sublee): record_stream() on a shifted view tensor throws
        # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
        # protect the tensor against unexpected reallocation, here we use a
        # temporary tensor associated with the same storage, without the
        # shift, as a workaround.
        #
        # Issue: https://github.com/pytorch/pytorch/issues/27366
        #
        tensor = tensor.new_empty([0]).set_(tensor._typed_storage())

        # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
        tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
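
# Illustrative usage, a minimal sketch assuming a CUDA device is available:
# after lending a tensor to another stream, record_stream keeps the caching
# allocator from reusing its memory until that stream has finished with it.
#
#     if torch.cuda.is_available():
#         device = torch.device("cuda", 0)
#         side_stream = new_stream(device)
#         x = torch.ones(8, device=device)
#         with use_stream(side_stream):
#             y = x * 2
#         record_stream(x, side_stream)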


def is_cuda(stream: AbstractStream) -> bool:
    """Returns ``True`` if the given stream is a valid CUDA stream."""
    return stream is not CPUStream


def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    """Casts the given stream as :class:`torch.cuda.Stream`."""
    return cast(torch.cuda.Stream, stream)