IndexKernelUtils.h

#pragma once
#include <ATen/native/TensorIterator.h>
#include <c10/util/irange.h>

namespace at {
namespace native {

namespace {
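// Operand layout of the TensorIterator used with cpu_index_kernel:
//   operand 0 = destination, operand 1 = source, operands 2.. = index tensors.
// is_constant_index() reports whether every index operand has zero stride in the
// current inner loop, i.e. every element of the chunk uses the same index.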
static bool is_constant_index(int ntensor, const int64_t* strides) {
  AT_ASSERT(ntensor >= 3);
  for (const auto arg : c10::irange(2, ntensor)) {
    if (strides[arg] != 0) {
      return false;
    }
  }
  return true;
}
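
// Indexer turns the int64 values read from the index operands into a byte offset
// into the original (source) tensor: each index is bounds-checked against the
// original size, negative indices are wrapped, and value * stride is accumulated.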
struct Indexer {
  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
          IntArrayRef original_sizes, IntArrayRef original_strides)
    : num_indexers(num_indexers)
    , indexers(indexers)
    , indexer_strides(indexer_strides)
    , original_strides(original_strides.data())
    , original_sizes(original_sizes.data()) {
    AT_ASSERT(static_cast<int64_t>(original_strides.size()) == num_indexers);
    AT_ASSERT(static_cast<int64_t>(original_sizes.size()) == num_indexers);
  }

  int64_t num_indexers;
  char** indexers;
  const int64_t* indexer_strides;
  const int64_t* original_strides;
  const int64_t* original_sizes;

  int64_t get(int64_t idx) {
    int64_t offset = 0;
    for (const auto j : c10::irange(num_indexers)) {
      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
      int64_t size = original_sizes[j];
      TORCH_CHECK_INDEX(value >= -size && value < size,
                        "index ", value, " is out of bounds for dimension ", j, " with size ", size);
      if (value < 0) {
        value += size;
      }
      offset += value * original_strides[j];
    }
    return offset;
  }
};
} // anonymous namespace
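
// cpu_index_kernel drives the TensorIterator loop: for each element it computes the
// gathered byte offset via Indexer and calls f(dst, src, offset), so func_t must be
// callable as f(char* dst, char* src, int64_t offset).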
template <typename scalar_t, typename func_t>
void cpu_index_kernel(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride,
                      const func_t& f, bool serial_execution=false)
{
  int ntensor = iter.ntensors();
  // When launching the parallel version, use a grain size smaller than internal::GRAIN_SIZE
  // so that the available threads get a more balanced workload and better cache locality.
  // The value here was chosen via op benchmarks to amortize the thread-launch overhead.
  const int index_parallel_grain_size = 3000;
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
    char* dst = data[0];
    char* src = data[1];
    if (is_constant_index(ntensor, strides)) {
      // specialization for when every element uses the same index
      int64_t offset = indexer.get(0);
      if (strides[0] == sizeof(scalar_t) && strides[1] == sizeof(scalar_t)) {
        for (const auto i : c10::irange(n)) {
          f(dst + strides[0] * i, src + strides[1] * i, offset);
        }
      } else {
        for (const auto i : c10::irange(n)) {
          f(dst + strides[0] * i, src + strides[1] * i, offset);
        }
      }
    } else {
      for (const auto i : c10::irange(n)) {
        int64_t offset = indexer.get(i);
        f(dst + strides[0] * i, src + strides[1] * i, offset);
      }
    }
  };
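  // serial_for_each runs the loop on the calling thread over the full range;
  // for_each may split the range across the intra-op thread pool using the
  // grain size chosen above.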
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, index_parallel_grain_size);
  }
}
} // namespace native
} // namespace at
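
// Usage sketch (illustrative only, not part of this header): a dtype-dispatched index
// kernel built on cpu_index_kernel would pass a lambda that copies one element from the
// gathered source offset into the destination. The function name below is hypothetical;
// AT_DISPATCH_ALL_TYPES is ATen's standard dtype-dispatch macro.
//
//   void index_kernel_sketch(at::TensorIteratorBase& iter,
//                            at::IntArrayRef index_size,
//                            at::IntArrayRef index_stride) {
//     AT_DISPATCH_ALL_TYPES(iter.dtype(), "index_cpu", [&] {
//       at::native::cpu_index_kernel<scalar_t>(iter, index_size, index_stride,
//         [](char* dst, char* src, int64_t offset) {
//           *(scalar_t*)dst = *(scalar_t*)(src + offset);
//         });
//     });
//   }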