1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/TensorUtils.h>
- #include <c10/util/irange.h>
- namespace at { namespace native {
- template<typename scalar_t>
- static inline std::vector<int> generate_intervals(
- scalar_t sample,
- int64_t inputSize,
- int64_t outputSize,
- int64_t poolSize) {
- std::vector<int> sequence(outputSize);
- if (outputSize > 1) {
- scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
- static_cast<scalar_t>(outputSize - 1);
- for (const auto i : c10::irange(outputSize - 1)) {
- sequence[i] =
- static_cast<int>((i + sample) * alpha) - static_cast<int>(sample * alpha);
- }
- }
- if (outputSize > 0) {
- sequence[outputSize - 1] = inputSize - poolSize;
- }
- return sequence;
- }
- template <int64_t ndim>
- static inline void fractional_max_pool_check_shape(
- const Tensor& input,
- const Tensor& randomSamples) {
- TORCH_CHECK(
- input.scalar_type() == randomSamples.scalar_type(),
- "Expect _random_samples to have the same dtype as input");
- int64_t ndimension = randomSamples.ndimension();
- TORCH_CHECK(
- ndimension == 3,
- "Expect _random_samples to have 3 dimensions, got ", ndimension);
- int64_t N = randomSamples.size(0);
- int64_t C = randomSamples.size(1);
- int64_t D = randomSamples.size(2);
- int64_t input_batch, input_channel;
- if (ndim == 2) {
- // fractional_max_pool2d
- if (input.ndimension() == 3) {
- input_batch = 1;
- input_channel = input.size(0);
- } else {
- input_batch = input.size(0);
- input_channel = input.size(1);
- }
- } else {
- // factional_max_pool3d
- if (input.ndimension() == 4) {
- input_batch = 1;
- input_channel = input.size(0);
- } else {
- input_batch = input.size(0);
- input_channel = input.size(1);
- }
- }
- TORCH_CHECK(
- N >= input_batch,
- "Expect _random_samples.size(0) no less then input batch size.");
- TORCH_CHECK(
- C == input_channel,
- "Expect _random_samples.size(1) equals to input channel size.");
- TORCH_CHECK(
- D == ndim,
- "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
- }
- }} // at::native
|