IndexUtils.cuh 721 B

1234567891011121314151617181920212223242526272829303132
  1. #pragma once
  #include <ATen/core/TensorBase.h>
  #include <ATen/cuda/detail/TensorInfo.cuh>
  #include <ATen/native/CanUse32BitIndexMath.h>
  #include <c10/util/Exception.h>
  5. namespace at {
  6. namespace cuda {
  7. namespace detail {
  8. TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
  9. using at::native::canUse32BitIndexMath;
  10. template <typename scalar, typename IndexType>
  11. TensorInfo<scalar, IndexType>
  12. getTensorInfo(const at::TensorBase &t) {
  13. IndexType sz[MAX_TENSORINFO_DIMS];
  14. IndexType st[MAX_TENSORINFO_DIMS];
  15. int dims = t.dim();
  16. for (int i = 0; i < dims; ++i) {
  17. sz[i] = t.size(i);
  18. st[i] = t.stride(i);
  19. }
  20. return TensorInfo<scalar, IndexType>(
  21. t.data_ptr<scalar>(), dims, sz, st);
  22. }
  23. } // detail
  24. } // cuda
  25. } // at