123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336 |
- #include <ATen/core/Tensor.h>
- #include <ATen/div_rtn.h>
- #include <ATen/TensorUtils.h>
- #include <ATen/native/DispatchStub.h>
- #include <c10/util/irange.h>
- #include <utility>
- #pragma once
- namespace at {
- namespace native {
- using max_pool2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input,
- int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH);
- using max_pool2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
- DECLARE_DISPATCH(max_pool2d_fn, max_pool2d_kernel);
- DECLARE_DISPATCH(max_pool2d_backward_fn, max_pool2d_backward_kernel);
- // averge pooling has same signature for forward and backward
- using avg_pool2d_fn = void(*)(const Tensor& output, const Tensor& input, int64_t kW, int64_t kH,
- int64_t dW, int64_t dH, int64_t padW, int64_t padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
- using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input, int kW, int kH,
- int dW, int dH, int padW, int padH, bool count_include_pad, c10::optional<int64_t> divisor_override);
- DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
- DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
- namespace {
// Converts `v` to `dest_t`, raising via TORCH_CHECK ("integer out of range")
// when `v` lies outside [numeric_limits<dest_t>::min(), numeric_limits<dest_t>::max()].
//
// NOTE(review): the bounds are compared against `v` with the built-in
// relational operators, so the usual arithmetic conversions apply. Mixing a
// signed dest_t with an unsigned src_t (or vice versa) may therefore not
// behave as a pure range check -- confirm before using such a combination.
template <typename dest_t, typename src_t>
static inline dest_t safe_downcast(src_t v) {
  TORCH_CHECK(
      std::numeric_limits<dest_t>::min() <= v && v <= std::numeric_limits<dest_t>::max(),
      "integer out of range");

  return static_cast<dest_t>(v);
}
- template<typename T>
- static inline T pooling_output_shape_pad_lr(
- T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
- bool ceil_mode) {
- T outputSize = div_rtn<T>(
- inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
- (ceil_mode ? stride - 1 : 0), stride) + 1;
- if (ceil_mode) {
- // ensure that the last pooling starts inside the image
- // needed to avoid problems in ceil mode
- if ((outputSize - 1) * stride >= inputSize + pad_l) {
- --outputSize;
- }
- }
- return outputSize;
- }
- template<typename T>
- static inline T pooling_output_shape(
- T inputSize, T kernelSize, T pad, T stride, T dilation, bool ceil_mode) {
- TORCH_CHECK(stride != 0, "stride should not be zero");
- TORCH_CHECK(pad >= 0,
- "pad must be non-negative, but got pad: ", pad);
- TORCH_CHECK(pad <= kernelSize / 2,
- "pad should be at most half of kernel size, but got pad=",
- pad, " and kernel_size=", kernelSize)
- return pooling_output_shape_pad_lr(
- inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode);
- }
// Splits the total "same"-mode padding for one dimension into a
// (left, right) pair.
//
// The total is dilation * (kernelSize - 1). The left side receives
// total / 2 and the right side the remainder, so an odd total puts the extra
// element on the right. When stride > 2, the total is odd, and
// inputSize % stride exceeds 1, the total is first reduced by one so that the
// padding comes out symmetric.
template <typename T>
std::pair<T, T> _pooling_same_mode_padding_lr(
    T inputSize, T kernelSize, int64_t stride, int64_t dilation) {
  // NOTE: with strides, the output shape is ceil(inputSize/stride)
  auto padding_sum = T(dilation) * (kernelSize - 1);

  // Prefer symmetric padding if possible: the floor in the output size
  // calculation leaves (inputSize % stride - 1) elements of slack, and any
  // positive slack lets us drop the odd padding element.
  if (stride > 2 && (padding_sum % 2 == 1) &&
      ((inputSize % stride) - 1 > 0)) {
    padding_sum = padding_sum - 1;
  }

  auto before = padding_sum / 2;
  return {before, padding_sum - before};
}
- inline std::pair<int64_t, int64_t> pooling_same_mode_padding_lr(
- int64_t inputSize, int64_t kernelSize, int64_t stride, int64_t dilation) {
- return _pooling_same_mode_padding_lr(inputSize, kernelSize, stride, dilation);
- }
- inline std::pair<c10::SymInt, c10::SymInt> pooling_same_mode_padding_lr(
- c10::SymInt inputSize, c10::SymInt kernelSize, int64_t stride, int64_t dilation) {
- return _pooling_same_mode_padding_lr(std::move(inputSize), std::move(kernelSize), stride, dilation);
- }
- // AveragePool2d/DilatedMaxPool2d (forward)
- static inline void
- pool2d_shape_check(
- const Tensor& input,
- int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
- int64_t nInputPlane,
- int64_t inputHeight, int64_t inputWidth,
- int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
- {
- const int64_t ndim = input.ndimension();
- const int64_t nOutputPlane = nInputPlane;
- TORCH_CHECK(kW > 0 && kH > 0,
- "kernel size should be greater than zero, but got ",
- "kH: ", kH, " kW: ", kW);
- TORCH_CHECK(dW > 0 && dH > 0,
- "stride should be greater than zero, but got "
- "dH: ", dH, " dW: ", dW);
- TORCH_CHECK(dilationH > 0 && dilationW > 0,
- "dilation should be greater than zero, but got ",
- "dilationH: ", dilationH, " dilationW: ", dilationW);
- bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
- if (memory_format == at::MemoryFormat::ChannelsLast){
- // Expect tensor in NHWC format and allow 0-dim only for N.
- TORCH_CHECK((ndim == 4 && valid_dims && input.size(3) != 0),
- "Expected 4D (batch mode) tensor expected for input with channels_last layout"
- " with optional 0 dim batch size for input, but got: ", input.sizes());
- } else {
- TORCH_CHECK((ndim == 3 && input.size(0) != 0 && valid_dims) ||
- (ndim == 4 && valid_dims && input.size(3) != 0),
- "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
- input.sizes());
- }
- TORCH_CHECK(kW/2 >= padW && kH/2 >= padH,
- "pad should be smaller than or equal to half of kernel size, but got ",
- "padW = ", padW, ", padH = ", padH, ", kW = ", kW, ", kH = ", kH);
- TORCH_CHECK(outputWidth >= 1 && outputHeight >= 1,
- "Given input size: (",
- nInputPlane, "x", inputHeight, "x", inputWidth, "). ",
- "Calculated output size: (",
- nOutputPlane, "x", outputHeight, "x", outputWidth, "). ",
- "Output size is too small");
- }
- // DilatedMaxPool2d (backward)
- static inline void
- max_pool2d_backward_shape_check(
- const Tensor& input,
- const Tensor& gradOutput,
- const Tensor& indices,
- int kH, int kW, int dH, int dW, int padH, int padW, int dilationH, int dilationW,
- int64_t nInputPlane,
- int64_t inputHeight, int64_t inputWidth,
- int64_t outputHeight, int64_t outputWidth, MemoryFormat memory_format)
- {
- pool2d_shape_check(
- input,
- kH, kW, dH, dW, padH, padW, dilationH, dilationW,
- nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth, memory_format);
- const int64_t ndim = input.ndimension();
- const int64_t nOutputPlane = nInputPlane;
- check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
- check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
- check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
- check_dim_size(indices, ndim, ndim-3, nOutputPlane);
- check_dim_size(indices, ndim, ndim-2, outputHeight);
- check_dim_size(indices, ndim, ndim-1, outputWidth);
- }
- // AveragePool2d (backward)
- static inline void
- avg_pool2d_backward_shape_check(
- const Tensor& input,
- const Tensor& gradOutput,
- int64_t /*nbatch*/,
- int kH, int kW, int dH, int dW, int padH, int padW,
- int64_t nInputPlane,
- int64_t inputHeight, int64_t inputWidth,
- int64_t outputHeight, int64_t outputWidth,
- MemoryFormat memory_format)
- {
- pool2d_shape_check(
- input,
- kH, kW, dH, dW, padH, padW, 1, 1,
- nInputPlane, inputHeight, inputWidth, outputHeight, outputWidth,
- memory_format);
- const int64_t ndim = input.ndimension();
- const int64_t nOutputPlane = nInputPlane;
- check_dim_size(gradOutput, ndim, ndim-3, nOutputPlane);
- check_dim_size(gradOutput, ndim, ndim-2, outputHeight);
- check_dim_size(gradOutput, ndim, ndim-1, outputWidth);
- }
- // AveragePool3d/DilatedMaxPool3d (forward)
- static inline void
- pool3d_shape_check(
- const Tensor& input,
- int64_t nslices,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int pT, int pH, int pW,
- int dilationT, int dilationH, int dilationW,
- int64_t itime, int64_t iheight, int64_t iwidth,
- int64_t otime, int64_t oheight, int64_t owidth,
- const char *fn_name,
- bool check_input_size=false)
- {
- const int64_t ndim = input.ndimension();
- TORCH_CHECK(kT > 0 && kW > 0 && kH > 0,
- "kernel size should be greater than zero, but got ",
- "kT: ", kT, " kH: ", kH, " kW: ", kW);
- TORCH_CHECK(dT > 0 && dW > 0 && dH > 0,
- "stride should be greater than zero, but got ",
- "dT: ", dT, " dH: ", dH, " dW: ", dW);
- TORCH_CHECK(dilationT > 0 && dilationW > 0 && dilationH > 0,
- "dilation should be greater than zero, but got ",
- "dilationT: ", dilationT, " dilationH: ", dilationH, " dilationW: ", dilationW);
- TORCH_CHECK(ndim == 4 || ndim == 5,
- fn_name, ": Expected 4D or 5D tensor for input, but got: ", input.sizes());
- for (const auto i : c10::irange(ndim)) {
- if (ndim == 5 && i == 0) {
- // size of batch-dim can be 0.
- continue;
- }
- TORCH_CHECK(
- input.size(i) > 0,
- fn_name,
- ": Expected input's non-batch dimensions to have positive length,"
- " but input has a shape of ",
- input.sizes(),
- " and non-batch dimension ",
- input.size(i),
- " has length zero!")
- }
- if (check_input_size) { // AveragePool3d
- TORCH_CHECK(itime >= kT && iheight >= kH && iwidth >= kW,
- "input image ", "(T: ", itime, " H: ", iheight, " W: ", iwidth, ") smaller than ",
- "kernel size ", "(kT: ", kT, " kH: ", kH, " kW: ", kW, ")");
- }
- TORCH_CHECK(kT/2 >= pT && kW/2 >= pW && kH/2 >= pH,
- "pad should be smaller than or equal to half of kernel size, but got "
- "kT: ", kT, " kW: ", kW, " kH: ", kH, " padT: ", pT, " padW: ", pW, " padH: ", pH);
- TORCH_CHECK(otime >= 1 && owidth >= 1 && oheight >= 1,
- "Given input size: (",
- nslices,"x", itime, "x", iheight, "x", iwidth, "). ",
- "Calculated output size: (",
- nslices, "x", otime, "x", oheight, "x", owidth, "). ",
- "Output size is too small");
- }
- static inline void
- max_pool3d_backward_shape_check(
- const Tensor& input,
- const Tensor& gradOutput,
- const Tensor& indices,
- int64_t nslices,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int pT, int pH, int pW,
- int dilationT, int dilationH, int dilationW,
- int64_t itime, int64_t iheight, int64_t iwidth,
- int64_t otime, int64_t oheight, int64_t owidth,
- const char* fn_name)
- {
- const int64_t ndim = input.ndimension();
- pool3d_shape_check(
- input,
- nslices,
- kT, kH, kW,
- dT, dH, dW,
- pT, pH, pW,
- dilationT, dilationH, dilationW,
- itime, iheight, iwidth,
- otime, oheight, owidth, fn_name);
- check_dim_size(gradOutput, ndim, ndim-4, nslices);
- check_dim_size(gradOutput, ndim, ndim-3, otime);
- check_dim_size(gradOutput, ndim, ndim-2, oheight);
- check_dim_size(gradOutput, ndim, ndim-1, owidth);
- check_dim_size(indices, ndim, ndim-4, nslices);
- check_dim_size(indices, ndim, ndim-3, otime);
- check_dim_size(indices, ndim, ndim-2, oheight);
- check_dim_size(indices, ndim, ndim-1, owidth);
- }
- static inline void
- avg_pool3d_backward_shape_check(
- const Tensor& input,
- const Tensor& gradOutput,
- int64_t nslices,
- int kT, int kH, int kW,
- int dT, int dH, int dW,
- int pT, int pH, int pW,
- int64_t itime, int64_t iheight, int64_t iwidth,
- int64_t otime, int64_t oheight, int64_t owidth,
- const char *fn_name)
- {
- const int64_t ndim = input.ndimension();
- pool3d_shape_check(
- input,
- nslices,
- kT, kH, kW,
- dT, dH, dW,
- pT, pH, pW,
- 1, 1, 1,
- itime, iheight, iwidth,
- otime, oheight, owidth,
- fn_name, true);
- check_dim_size(gradOutput, ndim, ndim-4, nslices);
- check_dim_size(gradOutput, ndim, ndim-3, otime);
- check_dim_size(gradOutput, ndim, ndim-2, oheight);
- check_dim_size(gradOutput, ndim, ndim-1, owidth);
- }
- } // namespace
- } // at::native
- } // at
|