#pragma once #include #include namespace at { namespace native { //input tensors are non-zero dim and non-empty template void tensor_dim_apply3(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim, Function func) { int ndims = self.dim(); int tensor_dim_apply_has_finished = 0; std::vector counter(ndims, 0); T1* self_data = self.data_ptr(); T1* values_data = values.data_ptr(); T2* indices_data = indices.data_ptr(); int64_t self_stride = self.stride(dim); int64_t values_stride = values.stride(dim); int64_t indices_stride = indices.stride(dim); int self_dim_size = self.size(dim); while(!tensor_dim_apply_has_finished) { func(self_data, values_data, indices_data, self_dim_size, self_stride, values_stride, indices_stride); if(ndims == 1) break; for (const auto dim_i : c10::irange(ndims)) { if(dim_i == dim) { if(dim_i == (ndims - 1)) { tensor_dim_apply_has_finished = 1; break; } continue; } counter[dim_i]++; self_data += self.stride(dim_i); values_data += values.stride(dim_i); indices_data += indices.stride(dim_i); if(counter[dim_i] == self.size(dim_i)) { if(dim_i == ndims-1) { tensor_dim_apply_has_finished = 1; break; } else { self_data -= counter[dim_i]*self.stride(dim_i); values_data -= counter[dim_i]*values.stride(dim_i); indices_data -= counter[dim_i]*indices.stride(dim_i); counter[dim_i] = 0; } } else { break; } } } } }}