// TensorInfo.cuh

#pragma once

#include <ATen/CollapseDims.h>
#include <c10/util/Exception.h> // TORCH_CHECK

namespace at {
namespace cuda {
namespace detail {

#define MAX_TENSORINFO_DIMS 25
// CUDA kernel argument that defines tensor layout
template <typename T, typename IndexType>
struct TensorInfo {
  TensorInfo();
  TensorInfo(T* p,
             int dim,
             IndexType sz[MAX_TENSORINFO_DIMS],
             IndexType st[MAX_TENSORINFO_DIMS]);

  // Set the size of the given dimension to 1, as if it were a
  // reduction dim (allows you to calculate offsets of the reduction
  // slice)
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // Contiguous tensors of more than one dimension are collapsed down
  // to a single dimension
  __host__ __device__ inline bool isContiguous() const {
    return (dims == 1 && strides[0] == 1);
  }

  T* data;
  IndexType sizes[MAX_TENSORINFO_DIMS];
  IndexType strides[MAX_TENSORINFO_DIMS];
  int dims;
};
template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo() {
  data = nullptr;
  dims = 0;
}

template <typename T, typename IndexType>
TensorInfo<T, IndexType>::TensorInfo(T* p,
                                     int dim,
                                     IndexType sz[MAX_TENSORINFO_DIMS],
                                     IndexType st[MAX_TENSORINFO_DIMS]) {
  data = p;
  dims = dim;
  TORCH_CHECK(dims < MAX_TENSORINFO_DIMS, "CUDA Tensors cannot have more than 25 dimensions");

  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}
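
// Example (hypothetical values, for illustration): a contiguous 2x3 float
// tensor in row-major order has sizes {2, 3} and strides {3, 1}, so its
// layout can be described as
//
//   int sizes[MAX_TENSORINFO_DIMS] = {2, 3};
//   int strides[MAX_TENSORINFO_DIMS] = {3, 1};
//   // devPtr: device allocation holding 6 floats
//   TensorInfo<float, int> info(devPtr, 2, sizes, strides);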
template <typename T, typename IndexType>
void
TensorInfo<T, IndexType>::reduceDim(int dim) {
  TORCH_CHECK(dim < dims && dim >= 0, "expected dim between 0 and dims - 1");
  sizes[dim] = 1;
}

template <typename T, typename IndexType>
int
TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
  // collapse_dims returns a pair of (new index of excludeDim, new number of dims)
  auto result = at::collapse_dims(sizes, strides, dims, excludeDim);
  dims = std::get<1>(result);
  return std::get<0>(result);
}
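
// Example (hypothetical values, assuming the collapse_dims semantics in
// ATen/CollapseDims.h): for a contiguous 4x5x6 tensor, i.e. sizes {4, 5, 6}
// and strides {30, 6, 1},
//
//   info.collapseDims();              // now 1-d: sizes {120}, strides {1}
//
// while, starting again from the 3-d layout, excluding a dimension keeps it
// from being merged with its neighbors:
//
//   int kept = info.collapseDims(0);  // sizes {4, 30}, strides {30, 1}; kept == 0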
// Translate a linear index for the apply to a T* offset;
// specialized on `Dims` to reduce nvcc compilation time
template <typename T, typename IndexType, int Dims>
struct IndexToOffset {
  static __host__ __device__ IndexType get(
      IndexType linearId,
      const TensorInfo<T, IndexType>& info) {
    IndexType offset = 0;

    // Uses static dims
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    // The outermost dimension needs no modulus: whatever remains of
    // linearId after the loop is its index
    return offset + linearId * info.strides[0];
  }
};
// Uses dynamic (runtime) instead of static (compile-time) dims
template <typename T, typename IndexType>
struct IndexToOffset<T, IndexType, -1> {
  static inline __host__ __device__ IndexType get(
      IndexType linearId,
      const TensorInfo<T, IndexType>& info) {
    IndexType offset = 0;

    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }

    return offset + linearId * info.strides[0];
  }
};
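
// Example (hypothetical kernel, for illustration): a grid-stride fill that
// maps each linear index to a memory offset through IndexToOffset. Dims = -1
// picks the runtime-dims specialization above; a concrete Dims value lets
// nvcc unroll the loop.
//
//   template <typename T, typename IndexType, int Dims>
//   __global__ void fill_kernel(TensorInfo<T, IndexType> info, T value, IndexType n) {
//     for (IndexType linearId = blockIdx.x * blockDim.x + threadIdx.x;
//          linearId < n;
//          linearId += gridDim.x * blockDim.x) {
//       IndexType offset = IndexToOffset<T, IndexType, Dims>::get(linearId, info);
//       info.data[offset] = value;
//     }
//   }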
} // namespace detail
} // namespace cuda
} // namespace at