123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- #pragma once
- #include <ATen/native/TensorIterator.h>
- #include <c10/util/SmallBuffer.h>
- #include <c10/util/irange.h>
- namespace at {
- struct DimCounter {
- DimCounter(IntArrayRef shape, Range range);
- void increment(const std::array<int64_t, 2>& step);
- bool is_done() const;
- std::array<int64_t, 2> max_2d_step() const;
- IntArrayRef shape;
- Range range;
- c10::SmallBuffer<int64_t, 4> values;
- int64_t offset;
- };
- namespace internal {
- inline void get_data_ptrs(
- char** ptrs,
- ArrayRef<char*> base,
- IntArrayRef strides,
- IntArrayRef counter) {
- const int64_t ntensors = base.size();
- const int64_t ndim = counter.size();
- std::copy(base.begin(), base.end(), ptrs);
- for (const auto dim : c10::irange(ndim)) {
- int64_t value = counter[dim];
- for (const auto arg : c10::irange(ntensors)) {
- ptrs[arg] += value * strides[dim * ntensors + arg];
- }
- }
- }
- inline void serial_for_each(
- IntArrayRef shape,
- IntArrayRef strides,
- char** base_ptrs,
- size_t ntensors,
- typename TensorIteratorBase::loop2d_t loop,
- Range range) {
- const auto ndim = shape.size();
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
- strides.size() == ntensors * std::max(size_t{2}, ndim));
- if (ndim <= 1) {
- if (range.begin == 0) {
- loop(base_ptrs, strides.data(), range.size(), 1);
- } else {
- c10::SmallBuffer<char*, 4> ptrs(ntensors);
- get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin});
- loop(ptrs.data(), strides.data(), range.size(), 1);
- }
- } else {
- c10::SmallBuffer<char*, 4> ptrs(ntensors);
- auto counter = DimCounter(shape, range);
- while (!counter.is_done()) {
- get_data_ptrs(
- ptrs.data(), {base_ptrs, ntensors}, strides, counter.values);
- auto step = counter.max_2d_step();
- loop(ptrs.data(), strides.data(), step[0], step[1]);
- counter.increment(step);
- }
- }
- }
- } // namespace internal
- } // namespace at
|