123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import torch
- from torch.overrides import TorchFunctionMode
- from torch.utils._contextlib import context_decorator
- import functools
@functools.lru_cache(maxsize=1)
def _device_constructors():
    """Return the set of torch factory functions that get a default device.

    The set is built on the first call and memoized by ``lru_cache``, so
    every later call returns the very same set object.
    """
    constructors = {
        # Constructors from existing data ("weird ones").
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
        torch.asarray,
        torch.nested.nested_tensor,
        # Empty / constant-filled tensors.
        torch.empty,
        torch.empty_strided,
        torch.empty_quantized,
        torch.zeros,
        torch.ones,
        torch.full,
        torch.fill,
        torch.eye,
        # Sequences and grids.
        torch.arange,
        torch.range,
        torch.linspace,
        torch.logspace,
        torch.vander,
        torch.tril_indices,
        torch.triu_indices,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        # Window functions.
        torch.bartlett_window,
        torch.blackman_window,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        # Random tensors.
        # NOTE: torch.normal is deliberately absent -- that function doesn't
        # actually take a device argument.
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        # Sparse tensors.
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
    }
    return constructors
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    """Torch-function mode that supplies a default ``device`` keyword.

    While the mode is active, any intercepted call to a function listed in
    ``_device_constructors()`` whose ``device`` keyword is missing or ``None``
    is re-issued with ``device=self.device``. All other calls are forwarded
    unchanged.
    """

    def __init__(self, device):
        # Normalize whatever was passed (string, torch.device, ...) once.
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # A missing or empty kwargs mapping is replaced by a fresh dict; a
        # non-empty one is reused as-is (and updated in place below).
        call_kwargs = kwargs if kwargs else {}
        if func in _device_constructors():
            if call_kwargs.get('device') is None:
                call_kwargs['device'] = self.device
        return func(*args, **call_kwargs)
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    """Wrap ``func`` with ``context_decorator`` using ``device`` as the context.

    ``context_decorator`` is handed a zero-argument factory that always
    returns the same ``device`` object, rather than ``device`` itself.
    """
    def _device_factory():
        return device

    return context_decorator(_device_factory, func)
def set_device(device):
    """
    Build a decorator that makes ``device`` the default device inside the
    function it wraps.

    ``device`` is converted with ``torch.device(device)`` at the moment the
    returned decorator is applied to a function, not when ``set_device``
    itself is called.

    To get the same effect as a context manager rather than a decorator,
    use the device object directly, e.g. ``with torch.device(device)``.
    """
    def _decorate(func):
        return device_decorator(torch.device(device), func)

    return _decorate
|