123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <c10/util/irange.h>
- namespace at {
- namespace native {
- //input tensors are non-zero dim and non-empty
- template<typename T1, typename T2, typename Function>
- void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) {
- int ndims = self.dim();
- int tensor_dim_apply_has_finished = 0;
- std::vector<int64_t> counter(ndims, 0);
- T1* self_data = self.data_ptr<T1>();
- T1* values_data = values.data_ptr<T1>();
- T2* indices_data = indices.data_ptr<T2>();
- int64_t self_stride = self.stride(dim);
- int64_t values_stride = values.stride(dim);
- int64_t indices_stride = indices.stride(dim);
- int self_dim_size = self.size(dim);
- while(!tensor_dim_apply_has_finished) {
- func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride);
- if(ndims == 1)
- break;
- for (const auto dim_i : c10::irange(ndims)) {
- if(dim_i == dim) {
- if(dim_i == (ndims - 1)) {
- tensor_dim_apply_has_finished = 1;
- break;
- }
- continue;
- }
- counter[dim_i]++;
- self_data += self.stride(dim_i);
- values_data += values.stride(dim_i);
- indices_data += indices.stride(dim_i);
- if(counter[dim_i] == self.size(dim_i)) {
- if(dim_i == ndims-1) {
- tensor_dim_apply_has_finished = 1;
- break;
- } else {
- self_data -= counter[dim_i]*self.stride(dim_i);
- values_data -= counter[dim_i]*values.stride(dim_i);
- indices_data -= counter[dim_i]*indices.stride(dim_i);
- counter[dim_i] = 0;
- }
- } else {
- break;
- }
- }
- }
- }
- }}
|