import functools

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._contextlib import context_decorator


@functools.lru_cache(1)
def _device_constructors():
    """Return the set of torch functions that receive a default ``device``.

    The set is built on first use and memoized by ``lru_cache`` so that
    membership tests in ``DeviceContext.__torch_function__`` stay cheap.
    """
    return {
        # standard ones
        torch.empty,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
        torch.arange,
        torch.bartlett_window,
        torch.blackman_window,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.full,
        # torch.fill(input, value) is deliberately NOT listed: it has no
        # ``device`` keyword, so injecting one would make every call to it
        # fail with a TypeError while a DeviceContext is active.
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.nested.nested_tensor,
        # This function doesn't actually take a device argument
        # torch.normal,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch.asarray,
        # weird ones
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
    }


# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    """Torch-function mode that supplies a default ``device`` to constructors.

    While the mode is active, any function listed in
    ``_device_constructors()`` that is called without a ``device`` (or with
    ``device=None``) is invoked with ``device=self.device`` instead.  All
    other calls are forwarded unchanged.
    """

    def __init__(self, device):
        # Accept anything ``torch.device`` accepts (string, torch.device, ...).
        self.device = torch.device(device)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Forward ``func``, injecting the default device when applicable.

        Args:
            func: the torch function being intercepted.
            types: unused here; part of the ``__torch_function__`` protocol.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func`` (may be ``None``).

        Returns:
            Whatever ``func`` returns.
        """
        # Work on a copy so the caller's dict is never mutated when we
        # inject ``device`` below.
        kwargs = dict(kwargs) if kwargs else {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)


# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    """Wrap ``func`` so that each call runs inside ``with device:``.

    ``context_decorator`` is handed a factory returning ``device`` rather
    than ``device`` itself.
    """
    return context_decorator(lambda: device, func)


def set_device(device):
    """
    Decorator which sets the default device inside of the wrapped function.
    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.
    """
    def decorator(func):
        # torch.device(device) is evaluated when the decorator is applied,
        # matching the original lambda-based implementation.
        return device_decorator(torch.device(device), func)

    return decorator