CUDAUtils.h 428 B

1234567891011121314151617181920
  1. #pragma once
  2. #include <ATen/cuda/CUDAContext.h>
  3. namespace at { namespace cuda {
  4. // Check if every tensor in a list of tensors matches the current
  5. // device.
  6. inline bool check_device(ArrayRef<Tensor> ts) {
  7. if (ts.empty()) {
  8. return true;
  9. }
  10. Device curDevice = Device(kCUDA, current_device());
  11. for (const Tensor& t : ts) {
  12. if (t.device() != curDevice) return false;
  13. }
  14. return true;
  15. }
  16. }} // namespace at::cuda