TensorIteratorInternal.h 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #pragma once
  2. #include <ATen/native/TensorIterator.h>
  3. #include <c10/util/SmallBuffer.h>
  4. #include <c10/util/irange.h>
  5. namespace at {
  6. struct DimCounter {
  7. DimCounter(IntArrayRef shape, Range range);
  8. void increment(const std::array<int64_t, 2>& step);
  9. bool is_done() const;
  10. std::array<int64_t, 2> max_2d_step() const;
  11. IntArrayRef shape;
  12. Range range;
  13. c10::SmallBuffer<int64_t, 4> values;
  14. int64_t offset;
  15. };
  16. namespace internal {
  17. inline void get_data_ptrs(
  18. char** ptrs,
  19. ArrayRef<char*> base,
  20. IntArrayRef strides,
  21. IntArrayRef counter) {
  22. const int64_t ntensors = base.size();
  23. const int64_t ndim = counter.size();
  24. std::copy(base.begin(), base.end(), ptrs);
  25. for (const auto dim : c10::irange(ndim)) {
  26. int64_t value = counter[dim];
  27. for (const auto arg : c10::irange(ntensors)) {
  28. ptrs[arg] += value * strides[dim * ntensors + arg];
  29. }
  30. }
  31. }
  32. inline void serial_for_each(
  33. IntArrayRef shape,
  34. IntArrayRef strides,
  35. char** base_ptrs,
  36. size_t ntensors,
  37. typename TensorIteratorBase::loop2d_t loop,
  38. Range range) {
  39. const auto ndim = shape.size();
  40. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  41. strides.size() == ntensors * std::max(size_t{2}, ndim));
  42. if (ndim <= 1) {
  43. if (range.begin == 0) {
  44. loop(base_ptrs, strides.data(), range.size(), 1);
  45. } else {
  46. c10::SmallBuffer<char*, 4> ptrs(ntensors);
  47. get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
  48. loop(ptrs.data(), strides.data(), range.size(), 1);
  49. }
  50. } else {
  51. c10::SmallBuffer<char*, 4> ptrs(ntensors);
  52. auto counter = DimCounter(shape, range);
  53. while (!counter.is_done()) {
  54. get_data_ptrs(
  55. ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
  56. auto step = counter.max_2d_step();
  57. loop(ptrs.data(), strides.data(), step[0], step[1]);
  58. counter.increment(step);
  59. }
  60. }
  61. }
  62. } // namespace internal
  63. } // namespace at