_device.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import torch
  2. from torch.overrides import TorchFunctionMode
  3. from torch.utils._contextlib import context_decorator
  4. import functools
  5. @functools.lru_cache(1)
  6. def _device_constructors():
  7. return {
  8. # standard ones
  9. torch.empty,
  10. torch.empty_strided,
  11. torch.empty_quantized,
  12. torch.ones,
  13. torch.arange,
  14. torch.bartlett_window,
  15. torch.blackman_window,
  16. torch.eye,
  17. torch.fft.fftfreq,
  18. torch.fft.rfftfreq,
  19. torch.full,
  20. torch.fill,
  21. torch.hamming_window,
  22. torch.hann_window,
  23. torch.kaiser_window,
  24. torch.linspace,
  25. torch.logspace,
  26. torch.nested.nested_tensor,
  27. # This function doesn't actually take a device argument
  28. # torch.normal,
  29. torch.ones,
  30. torch.rand,
  31. torch.randn,
  32. torch.randint,
  33. torch.randperm,
  34. torch.range,
  35. torch.sparse_coo_tensor,
  36. torch.sparse_compressed_tensor,
  37. torch.sparse_csr_tensor,
  38. torch.sparse_csc_tensor,
  39. torch.sparse_bsr_tensor,
  40. torch.sparse_bsc_tensor,
  41. torch.tril_indices,
  42. torch.triu_indices,
  43. torch.vander,
  44. torch.zeros,
  45. torch.asarray,
  46. # weird ones
  47. torch.tensor,
  48. torch.as_tensor,
  49. torch.scalar_tensor,
  50. }
  51. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  52. class DeviceContext(TorchFunctionMode):
  53. def __init__(self, device):
  54. self.device = torch.device(device)
  55. def __torch_function__(self, func, types, args=(), kwargs=None):
  56. kwargs = kwargs or {}
  57. if func in _device_constructors() and kwargs.get('device') is None:
  58. kwargs['device'] = self.device
  59. return func(*args, **kwargs)
  60. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  61. def device_decorator(device, func):
  62. return context_decorator(lambda: device, func)
  63. def set_device(device):
  64. """
  65. Decorator which sets the default device inside of the wrapped
  66. function. If you would like to use this as a context manager,
  67. use device as a context manager directly, e.g.,
  68. ``with torch.device(device)``.
  69. """
  70. return lambda func: device_decorator(torch.device(device), func)