1234567891011121314151617181920212223242526272829303132 |
- #pragma once
- #include <ATen/core/TensorBase.h>
- #include <ATen/cuda/detail/TensorInfo.cuh>
- #include <ATen/native/CanUse32BitIndexMath.h>
- namespace at {
- namespace cuda {
- namespace detail {
- TORCH_CUDA_CU_API bool maybeOverlappingIndices(const at::TensorBase &t);
- using at::native::canUse32BitIndexMath;
- template <typename scalar, typename IndexType>
- TensorInfo<scalar, IndexType>
- getTensorInfo(const at::TensorBase &t) {
- IndexType sz[MAX_TENSORINFO_DIMS];
- IndexType st[MAX_TENSORINFO_DIMS];
- int dims = t.dim();
- for (int i = 0; i < dims; ++i) {
- sz[i] = t.size(i);
- st[i] = t.stride(i);
- }
- return TensorInfo<scalar, IndexType>(
- t.data_ptr<scalar>(), dims, sz, st);
- }
- } // detail
- } // cuda
- } // at
|