MaxPooling.h 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/Parallel.h>
  4. #include <ATen/native/DispatchStub.h>
  5. namespace at {
  6. namespace native {
  7. // TODO(Heitor) Template by dimension
  8. struct PoolingParams1D {
  9. int64_t NB; // Number of batches
  10. int64_t NC; // Number of channels
  11. int64_t IW; // Input width
  12. int64_t OW; // Output width
  13. int64_t KW; // Kernel width
  14. int64_t SJ; // Column stride
  15. int64_t PJ; // Column padding
  16. int64_t DJ; // Column dilation
  17. // Return index of input element for the given kernel and output index
  18. inline int64_t index(int64_t kj, int64_t oj) const {
  19. return oj * SJ + kj * DJ - PJ;
  20. }
  21. // Return index of first output within bounds for this kernel index
  22. inline int64_t valid_output_start(int64_t kj) const {
  23. int64_t ij = index(kj, 0);;
  24. return ij < 0 ? at::divup(-ij, SJ) : 0;
  25. }
  26. // Return index one past last output within bounds for this kernel index
  27. inline int64_t valid_output_end(int64_t kj) const {
  28. int64_t ij = index(kj, OW - 1);
  29. return ij >= IW ? OW - at::divup(ij - (IW - 1), SJ) : OW;
  30. }
  31. };
  32. using pooling_fn = void (*)(Tensor&, const Tensor&, const PoolingParams1D&);
  33. DECLARE_DISPATCH(pooling_fn, max_pool1d_stub);
  34. } // namespace native
  35. } // namespace at