# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for eliminating boilerplate code to handle abstract streams with
CPU device.
"""
from contextlib import contextmanager
from typing import Generator, List, Union, cast

import torch

__all__: List[str] = ["CPUStreamType", "new_stream", "current_stream", "default_stream",
                      "use_device", "use_stream", "get_device", "wait_stream", "record_stream",
                      "is_cuda", "as_cuda"]
class CPUStreamType:
    """Marker type whose single module-level instance stands in for a stream
    on the CPU device, where CUDA streams do not exist.
    """


# Singleton placeholder used wherever a stream is expected but the device is CPU.
CPUStream = CPUStreamType()

# Either a real CUDA stream or the CPU placeholder above.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
def new_stream(device: torch.device) -> AbstractStream:
    """Create a fresh stream on *device*.

    Returns a new :class:`torch.cuda.Stream` for a CUDA device, or the
    :data:`CPUStream` placeholder for any other device type.
    """
    if device.type == "cuda":
        return torch.cuda.Stream(device)
    return CPUStream
def current_stream(device: torch.device) -> AbstractStream:
    """CPU-aware counterpart of :func:`torch.cuda.current_stream`.

    For a CUDA device, returns its currently selected stream; for any other
    device type, returns the :data:`CPUStream` placeholder.
    """
    if device.type == "cuda":
        return torch.cuda.current_stream(device)
    return CPUStream
def default_stream(device: torch.device) -> AbstractStream:
    """CPU-aware counterpart of :func:`torch.cuda.default_stream`.

    For a CUDA device, returns its default stream; for any other device
    type, returns the :data:`CPUStream` placeholder.
    """
    if device.type == "cuda":
        return torch.cuda.default_stream(device)
    return CPUStream
@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
    """CPU-aware counterpart of the :func:`torch.cuda.device` context manager.

    Selects *device* for the duration of the ``with`` block when it is a
    CUDA device; for any other device type this is a no-op context.
    """
    if device.type == "cuda":
        with torch.cuda.device(device):
            yield
    else:
        yield
@contextmanager
def use_stream(stream: AbstractStream) -> Generator[None, None, None]:
    """CPU-aware counterpart of the :func:`torch.cuda.stream` context manager.

    Makes *stream* current for the duration of the ``with`` block when it is
    a CUDA stream; for the CPU placeholder this is a no-op context.
    """
    if is_cuda(stream):
        with torch.cuda.stream(as_cuda(stream)):
            yield
    else:
        yield
def get_device(stream: AbstractStream) -> torch.device:
    """Return the device that *stream* belongs to.

    A CUDA stream reports its own device; the CPU placeholder maps to
    ``torch.device("cpu")``.
    """
    return as_cuda(stream).device if is_cuda(stream) else torch.device("cpu")
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """Make *source* wait until all work currently queued on *target* is done.

    CPU-aware counterpart of :meth:`torch.cuda.Stream.wait_stream`. When the
    target is the CPU stream no synchronization is needed, because CPU work
    is executed eagerly.
    """
    if not is_cuda(target):
        return

    if is_cuda(source):
        # CUDA-to-CUDA: enqueue a wait on the source stream (non-blocking).
        as_cuda(source).wait_stream(as_cuda(target))
    else:
        # CPU source: block the host until the CUDA stream drains.
        as_cuda(target).synchronize()
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """CPU-aware counterpart of :meth:`torch.Tensor.record_stream`.

    Marks *tensor* as in use by *stream* so its memory is not reused too
    early. Does nothing for the CPU placeholder stream.
    """
    if not is_cuda(stream):
        return

    # NOTE(sublee): record_stream() on a shifted view tensor throws
    # RuntimeError in PyTorch 1.1.0, and does nothing in 1.2.0. To safely
    # protect the tensor against unexpected reallocation, here we use a
    # temporal tensor associated with the same storage without shifting as
    # a workaround.
    #
    # Issue: https://github.com/pytorch/pytorch/issues/27366
    #
    tensor = tensor.new_empty([0]).set_(tensor._typed_storage())

    # Typechecking: torch.cuda.Stream is incompatible with torch._C.Stream
    tensor.record_stream(as_cuda(stream))  # type: ignore[arg-type]
def is_cuda(stream: AbstractStream) -> bool:
    """Return ``True`` iff *stream* is an actual CUDA stream rather than the
    CPU placeholder singleton.
    """
    # Identity check against the module-level singleton is sufficient.
    return stream is not CPUStream
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    """Narrow *stream* to :class:`torch.cuda.Stream` for the type checker.

    This is a pure :func:`typing.cast` — no runtime check is performed, so
    callers must have verified :func:`is_cuda` first.
    """
    return cast(torch.cuda.Stream, stream)
|