#pragma once #include #include #include namespace at { struct DimCounter { DimCounter(IntArrayRef shape, Range range); void increment(const std::array& step); bool is_done() const; std::array max_2d_step() const; IntArrayRef shape; Range range; c10::SmallBuffer values; int64_t offset; }; namespace internal { inline void get_data_ptrs( char** ptrs, ArrayRef base, IntArrayRef strides, IntArrayRef counter) { const int64_t ntensors = base.size(); const int64_t ndim = counter.size(); std::copy(base.begin(), base.end(), ptrs); for (const auto dim : c10::irange(ndim)) { int64_t value = counter[dim]; for (const auto arg : c10::irange(ntensors)) { ptrs[arg] += value * strides[dim * ntensors + arg]; } } } inline void serial_for_each( IntArrayRef shape, IntArrayRef strides, char** base_ptrs, size_t ntensors, typename TensorIteratorBase::loop2d_t loop, Range range) { const auto ndim = shape.size(); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( strides.size() == ntensors * std::max(size_t{2}, ndim)); if (ndim <= 1) { if (range.begin == 0) { loop(base_ptrs, strides.data(), range.size(), 1); } else { c10::SmallBuffer ptrs(ntensors); get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin}); loop(ptrs.data(), strides.data(), range.size(), 1); } } else { c10::SmallBuffer ptrs(ntensors); auto counter = DimCounter(shape, range); while (!counter.is_done()) { get_data_ptrs( ptrs.data(), {base_ptrs, ntensors}, strides, counter.values); auto step = counter.max_2d_step(); loop(ptrs.data(), strides.data(), step[0], step[1]); counter.increment(step); } } } } // namespace internal } // namespace at