OffsetCalculator.cuh

#pragma once

#include <array>
#include <cstdint>
#include <type_traits>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/cuda/detail/IntegerDivider.cuh>
// OffsetCalculator calculates the per-operand offset of a linear index for
// NARGS operands that share the same shape, but may have different strides.
// OffsetCalculator iterates the tensor in column-major order.
#if defined(USE_ROCM)
constexpr int MAX_DIMS = 16;
#else
constexpr int MAX_DIMS = 25;
#endif

template <int NARGS, typename index_t = uint32_t, bool signed_strides = false>
struct OffsetCalculator {
  // We allow having negative strides to implement some operations like torch.flip.
  using stride_t = std::conditional_t<signed_strides,
                                      std::make_signed_t<index_t>,
                                      index_t>;
  // The offset for each argument. Wrapper around a fixed-size array.
  // On CUDA, a zero-sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size-1 offset to avoid a compiler failure.
  // This size-1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<stride_t, std::max<int>(NARGS, 1)>;

  // If element_sizes is nullptr, then the strides will be in bytes, otherwise
  // the strides will be in # of elements.
  OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides, const int64_t* element_sizes = nullptr) : dims(dims) {
    TORCH_CHECK(dims <= MAX_DIMS, "tensor has too many (>", MAX_DIMS, ") dims");
    for (int i = 0; i < dims; i++) {
      sizes_[i] = at::cuda::detail::IntDivider<index_t>(sizes[i]);
      for (int arg = 0; arg < NARGS; arg++) {
        int64_t element_size = (element_sizes == nullptr ? 1LL : element_sizes[arg]);
        strides_[i][arg] = strides[arg][i] / element_size;
      }
    }
  }

  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = 0;
    }

    // Peel off one dimension at a time, innermost first. The loop runs over
    // the compile-time bound MAX_DIMS so it can be unrolled; the runtime
    // dimension count is handled by the early break.
    #pragma unroll
    for (int dim = 0; dim < MAX_DIMS; ++dim) {
      if (dim == dims) {
        break;
      }
      auto divmod = sizes_[dim].divmod(linear_idx);
      linear_idx = divmod.div;

      #pragma unroll
      for (int arg = 0; arg < NARGS; arg++) {
        offsets[arg] += divmod.mod * strides_[dim][arg];
      }
    }
    return offsets;
  }

  int dims;
  at::cuda::detail::IntDivider<index_t> sizes_[MAX_DIMS];
  stride_t strides_[MAX_DIMS][std::max<int>(NARGS, 1)];
};
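
// What follows is a minimal host-side sketch, not part of the original header,
// illustrating how get() decomposes a linear index. The shape, strides, and
// the helper name demo_offset_calculator are illustrative assumptions.
inline void demo_offset_calculator() {
  // Shape [3, 4]; one operand with element strides {1, 3} (contiguous).
  int64_t sizes[2] = {3, 4};
  int64_t strides0[2] = {1, 3};
  const int64_t* strides[1] = {strides0};
  int64_t element_sizes[1] = {1};  // strides above are already in elements
  OffsetCalculator<1> calc(2, sizes, strides, element_sizes);
  // get(7) peels dimensions innermost-first: 7 % 3 = 1, then (7 / 3) % 4 = 2,
  // so the offset is 1 * 1 + 2 * 3 = 7 (a contiguous view round-trips).
  auto offsets = calc.get(7);
  (void)offsets;  // offsets[0] == 7
}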

template <int NARGS, typename index_t = uint32_t>
struct TrivialOffsetCalculator {
  // The offset for each argument. Wrapper around a fixed-size array.
  // The offsets are in # of elements, not in bytes.
  // On CUDA, a zero-sized array is not allowed, so when we are handling nullary
  // operators, we need to create a size-1 offset to avoid a compiler failure.
  // This size-1 offset is just a placeholder, and we will not use it.
  using offset_type = at::detail::Array<index_t, std::max<int>(NARGS, 1)>;

  C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
    offset_type offsets;
    #pragma unroll
    for (int arg = 0; arg < NARGS; arg++) {
      offsets[arg] = linear_idx;
    }
    return offsets;
  }
};
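
// A minimal sketch, not part of the original header: when every operand is
// contiguous, the linear index itself is each argument's element offset, so no
// divmod work is needed. The helper name demo_trivial is an assumption.
inline void demo_trivial() {
  TrivialOffsetCalculator<2> calc;
  auto offsets = calc.get(42);
  (void)offsets;  // offsets[0] == 42 and offsets[1] == 42
}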

// Make an OffsetCalculator with byte offsets
template <int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_offset_calculator(const at::TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
  std::array<const int64_t*, N> strides;
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i).data();
  }
  return OffsetCalculator<N, uint32_t, signed_strides>(iter.ndim(), iter.shape().data(), strides.data());
}
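
// A hedged sketch of how a pointwise CUDA kernel might consume the byte-offset
// calculator built above. The kernel name copy_kernel, the char* data
// pointers, and the launch configuration are illustrative assumptions, not
// part of this header.
template <typename scalar_t, typename offset_calc_t>
__global__ void copy_kernel(char* out, const char* in, uint32_t numel,
                            offset_calc_t offset_calc) {
  uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) {
    return;
  }
  // Offsets are in bytes because make_offset_calculator passes no element sizes.
  auto offsets = offset_calc.get(idx);
  *reinterpret_cast<scalar_t*>(out + offsets[0]) =
      *reinterpret_cast<const scalar_t*>(in + offsets[1]);
}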

// Make an OffsetCalculator with element offsets
template <int N, bool signed_strides = false>
static OffsetCalculator<N, uint32_t, signed_strides> make_element_offset_calculator(
    const at::TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(N <= iter.ntensors());
  std::array<const int64_t*, N> strides;
  std::array<int64_t, N> element_sizes;
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i).data();
    element_sizes[i] = iter.element_size(i);
  }
  return OffsetCalculator<N, uint32_t, signed_strides>(
      iter.ndim(), iter.shape().data(), strides.data(), element_sizes.data());
}
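
// A hedged counterpart to the byte-offset sketch above: with element offsets a
// kernel can index typed pointers directly instead of doing byte arithmetic.
// The kernel name element_copy_kernel is an illustrative assumption.
template <typename scalar_t, typename offset_calc_t>
__global__ void element_copy_kernel(scalar_t* out, const scalar_t* in,
                                    uint32_t numel, offset_calc_t offset_calc) {
  uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) {
    return;
  }
  auto offsets = offset_calc.get(idx);  // offsets counted in elements
  out[offsets[0]] = in[offsets[1]];
}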