DeviceGuard.h 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. #pragma once
  2. #include <ATen/core/IListRef.h>
  3. #include <ATen/core/Tensor.h>
  4. #include <c10/core/DeviceGuard.h>
  5. #include <c10/core/ScalarType.h> // TensorList whyyyyy
  6. namespace at {
  7. // Are you here because you're wondering why DeviceGuard(tensor) no
  8. // longer works? For code organization reasons, we have temporarily(?)
  9. // removed this constructor from DeviceGuard. The new way to
  10. // spell it is:
  11. //
  12. // OptionalDeviceGuard guard(device_of(tensor));
  13. /// Return the Device of a Tensor, if the Tensor is defined.
  14. inline c10::optional<Device> device_of(const Tensor& t) {
  15. if (t.defined()) {
  16. return c10::make_optional(t.device());
  17. } else {
  18. return c10::nullopt;
  19. }
  20. }
  21. inline c10::optional<Device> device_of(const c10::optional<Tensor>& t) {
  22. return t.has_value() ? device_of(t.value()) : nullopt;
  23. }
  24. /// Return the Device of a TensorList, if the list is non-empty and
  25. /// the first Tensor is defined. (This function implicitly assumes
  26. /// that all tensors in the list have the same device.)
  27. inline c10::optional<Device> device_of(ITensorListRef t) {
  28. if (!t.empty()) {
  29. return device_of(t.front());
  30. } else {
  31. return c10::nullopt;
  32. }
  33. }
  34. } // namespace at