FractionalMaxPooling.h 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/TensorUtils.h>
  4. #include <c10/util/irange.h>
  5. namespace at { namespace native {
  /// Builds the `outputSize` interval positions used by fractional max
  /// pooling along one dimension of length `inputSize` with windows of width
  /// `poolSize`, shifted by the pseudo-random `sample`.
  ///
  /// With step = (inputSize - poolSize) / (outputSize - 1), entries
  /// 0 .. outputSize - 2 are
  ///     int((idx + sample) * step) - int(sample * step)
  /// and the final entry is fixed to inputSize - poolSize (presumably the
  /// start index of the last window -- confirm against callers).
  /// An outputSize of 0 produces an empty vector.
  template<typename scalar_t>
  static inline std::vector<int> generate_intervals(
      scalar_t sample,
      int64_t inputSize,
      int64_t outputSize,
      int64_t poolSize) {
    std::vector<int> intervals(outputSize);
    if (outputSize <= 0) {
      return intervals;
    }

    const int64_t last = outputSize - 1;
    if (last > 0) {
      const scalar_t step = static_cast<scalar_t>(inputSize - poolSize) /
          static_cast<scalar_t>(last);
      // The subtracted term does not depend on idx, so compute it once.
      const int base = static_cast<int>(sample * step);
      for (int64_t idx = 0; idx < last; ++idx) {
        intervals[idx] = static_cast<int>((idx + sample) * step) - base;
      }
    }

    // The last interval is pinned so it ends exactly at the input boundary.
    intervals[last] = inputSize - poolSize;
    return intervals;
  }
  26. template <int64_t ndim>
  27. static inline void fractional_max_pool_check_shape(
  28. const Tensor& input,
  29. const Tensor& randomSamples) {
  30. TORCH_CHECK(
  31. input.scalar_type() == randomSamples.scalar_type(),
  32. "Expect _random_samples to have the same dtype as input");
  33. int64_t ndimension = randomSamples.ndimension();
  34. TORCH_CHECK(
  35. ndimension == 3,
  36. "Expect _random_samples to have 3 dimensions, got ", ndimension);
  37. int64_t N = randomSamples.size(0);
  38. int64_t C = randomSamples.size(1);
  39. int64_t D = randomSamples.size(2);
  40. int64_t input_batch, input_channel;
  41. if (ndim == 2) {
  42. // fractional_max_pool2d
  43. if (input.ndimension() == 3) {
  44. input_batch = 1;
  45. input_channel = input.size(0);
  46. } else {
  47. input_batch = input.size(0);
  48. input_channel = input.size(1);
  49. }
  50. } else {
  51. // factional_max_pool3d
  52. if (input.ndimension() == 4) {
  53. input_batch = 1;
  54. input_channel = input.size(0);
  55. } else {
  56. input_batch = input.size(0);
  57. input_channel = input.size(1);
  58. }
  59. }
  60. TORCH_CHECK(
  61. N >= input_batch,
  62. "Expect _random_samples.size(0) no less then input batch size.");
  63. TORCH_CHECK(
  64. C == input_channel,
  65. "Expect _random_samples.size(1) equals to input channel size.");
  66. TORCH_CHECK(
  67. D == ndim,
  68. "Expect _random_samples.size(2) equals to ", ndim, "; got ", D, ".");
  69. }
  70. }} // at::native