| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 | #pragma once#include <ATen/core/Tensor.h>#include <ATen/TensorOperators.h>#ifndef AT_PER_OPERATOR_HEADERS#include <ATen/Functions.h>#else#include <ATen/ops/empty.h>#include <ATen/ops/empty_like.h>#endifnamespace at {namespace native {template <    typename index_t,    void compute(index_t*, int64_t*, index_t*, int64_t, int64_t)>static inline Tensor repeat_interleave_common(    const Tensor& repeats,    c10::optional<int64_t> output_size) {  TORCH_CHECK(      repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");  TORCH_CHECK(      repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,      "repeats has to be Long or Int tensor");  if (repeats.size(0) == 0) {    return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);  }  Tensor repeats_ = repeats.contiguous();  Tensor cumsum = repeats.cumsum(0);  int64_t total;  if (output_size.has_value()) {    total = output_size.value();  } else {    total = cumsum[-1].item<int64_t>();    TORCH_CHECK(        (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");  }  Tensor result = at::empty({total}, repeats.options());  index_t* repeat_ptr = repeats_.data_ptr<index_t>();  int64_t* cumsum_ptr = cumsum.data_ptr<int64_t>();  index_t* result_ptr = result.data_ptr<index_t>();  compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);  return result;}} // namespace native} // namespace at
 |