// DilatedConvolutionUtils.h

#pragma once

#include <algorithm>
#include <vector>

#include <ATen/div_rtn.h>
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
// Check that tensor T has exactly DIM dimensions and that its size along
// dimension DIM_SIZE equals SIZE.
#define TORCH_CHECK_DIM_SIZE(T, DIM, DIM_SIZE, SIZE) \
  TORCH_CHECK(                                       \
      T.dim() == DIM && T.size(DIM_SIZE) == SIZE,    \
      "Need " #T " of dimension ",                   \
      DIM,                                           \
      " and " #T ".size[",                           \
      DIM_SIZE,                                      \
      "] == ",                                       \
      SIZE,                                          \
      " but got " #T " to be of shape ",             \
      T.sizes())
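
// Usage sketch (illustrative only): TORCH_CHECK_DIM_SIZE(bias, 1, 0, n)
// expands to a TORCH_CHECK that bias.dim() == 1 and bias.size(0) == n,
// producing a failure message along the lines of
// "Need bias of dimension 1 and bias.size[0] == n but got bias to be of
// shape ...".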

namespace at {
namespace native {
namespace internal {
namespace {

inline bool all_positive(IntArrayRef& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item > 0; });
}

inline bool all_nonnegative(std::vector<int64_t>& arr) {
  return std::all_of(
      arr.begin(), arr.end(), [](int64_t item) { return item >= 0; });
}

} // namespace

// Calculate the rear (spatial) part of the output tensor sizes.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  std::vector<int64_t> sizes;
  for (const auto index : c10::irange(dim)) {
    sizes.push_back(
        div_rtn<int64_t>(
            input.size(index + input.dim() - dim) + 2 * pad_size[index] -
                (dilation_size[index] * (kernel_size[index] - 1) + 1),
            stride_size[index]) +
        1);
  }
  return sizes;
}
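
// Worked example (illustrative, not part of the original header): with one
// spatial dimension of extent 10, kernel_size {3}, stride {2}, pad {1} and
// dilation {2}, the dilated kernel spans dilation * (kernel - 1) + 1 = 5
// input elements, so the formula above gives
//
//   div_rtn<int64_t>(10 + 2 * 1 - 5, 2) + 1 == 3 + 1 == 4
//
// i.e. get_output_size<1> would return {4} for such an input.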

// Calculate the full sizes of the output tensor.
template <int64_t dim>
std::vector<int64_t> get_output_size(
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  output_size.insert(output_size.begin(), weight.size(0));
  if (input.dim() == dim + 2) {
    output_size.insert(output_size.begin(), input.size(0));
  }
  return output_size;
}
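
// Example (hypothetical shapes, for illustration only): for a batched 2-d
// convolution with input of shape [N, C_in, H, W] and weight of shape
// [C_out, C_in, kH, kW], this overload returns {N, C_out, H_out, W_out},
// where H_out and W_out come from the spatial overload above; for an
// unbatched [C_in, H, W] input the leading N is omitted.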

/*
  slow_conv_dilated_shape_check - check user input to the dilated
  convolution forward and backward functions.
*/
template <int64_t dim>
void slow_conv_dilated_shape_check(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& bias,
    const Tensor& grad_output,
    IntArrayRef kernel_size,
    IntArrayRef stride_size,
    IntArrayRef pad_size,
    IntArrayRef dilation_size) {
  /*
    When the following tensors are defined:

      bias, grad_weight, grad_output

    they are assumed to be contiguous without checking, because these
    tensors are made contiguous by calling the .contiguous() method or by
    resizing zero-sized tensors in the forward/backward functions.

    When grad_weight is defined, it is assumed without checking to have
    the same shape as weight; see the backward functions.
  */
  // Check size arguments:
  TORCH_CHECK(
      kernel_size.size() == dim,
      "kernel sizes length should be ",
      dim,
      ", but got ",
      kernel_size.size());
  TORCH_CHECK(
      stride_size.size() == dim,
      "strides length should be ",
      dim,
      ", but got ",
      stride_size.size());
  TORCH_CHECK(
      dilation_size.size() == dim,
      "dilations length should be ",
      dim,
      ", but got ",
      dilation_size.size());
  TORCH_CHECK(
      pad_size.size() == dim,
      "pads length should be ",
      dim,
      ", but got ",
      pad_size.size());
  TORCH_CHECK(
      all_positive(kernel_size),
      "kernel size should be greater than zero, but got ",
      kernel_size);
  TORCH_CHECK(
      all_positive(stride_size),
      "stride should be greater than zero, but got ",
      stride_size);
  TORCH_CHECK(
      all_positive(dilation_size),
      "dilation should be greater than zero, but got ",
      dilation_size);

  // Check input:
  TORCH_CHECK(input.defined(), "input must be defined");
  bool is_batch = input.dim() == dim + 2;
  int64_t n = (is_batch ? 2 : 1);
  int64_t ndim = n + dim;
  if (!is_batch) {
    // An unbatched input must have exactly dim + 1 dimensions.
    TORCH_CHECK(
        input.dim() == dim + 1,
        "input must be ",
        dim + 1,
        "D or ",
        dim + 2,
        "D tensor but got ",
        input.dim(),
        "D tensor");
  }

  // Check output sizes:
  auto output_size = get_output_size<dim>(
      input, kernel_size, stride_size, pad_size, dilation_size);
  TORCH_CHECK(
      all_nonnegative(output_size),
      "calculated output size ",
      output_size,
      " is too small (all sizes must be non-negative)");

  // Check weight:
  TORCH_CHECK(weight.defined(), "weight must be defined");
  TORCH_CHECK(
      weight.dim() == dim + 2,
      "weight must be ",
      dim + 2,
      "D tensor but got ",
      weight.dim(),
      "D tensor");
  TORCH_CHECK(
      weight.sizes().slice(2) == kernel_size,
      "weight[2:] shape ",
      weight.sizes().slice(2),
      " must be equal to kernel_size ",
      kernel_size);
  // The input channel dimension must match weight.size(1):
  TORCH_CHECK_DIM_SIZE(input, input.dim(), (is_batch ? 1 : 0), weight.size(1));

  // Check bias when present:
  if (bias.defined()) {
    TORCH_CHECK(
        bias.dim() == 1,
        "bias must be 1D tensor but got ",
        bias.dim(),
        "D tensor");
    TORCH_CHECK_DIM_SIZE(bias, 1, 0, weight.size(0));
  }

  // Check grad_output when present:
  if (grad_output.defined()) {
    TORCH_CHECK(
        grad_output.dim() == ndim,
        "grad_output must be ",
        ndim,
        "D tensor but got ",
        grad_output.dim(),
        "D tensor");
    if (is_batch) {
      TORCH_CHECK(
          grad_output.size(0) == input.size(0),
          "grad_output.size(0)=",
          grad_output.size(0),
          " must be input.size(0)=",
          input.size(0));
    }
    TORCH_CHECK(
        grad_output.size(n - 1) == weight.size(0),
        "grad_output.size(",
        n - 1,
        ")=",
        grad_output.size(n - 1),
        " must be weight.size(0)=",
        weight.size(0));
    TORCH_CHECK(
        grad_output.sizes().slice(n) == output_size,
        "grad_output[",
        n,
        ":] shape ",
        grad_output.sizes().slice(n),
        " must be equal to output size ",
        output_size);
  }
}
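
// Usage sketch (hypothetical caller, illustrative only): a forward
// implementation would validate its arguments up front, passing an
// undefined tensor for grad_output since no gradient exists yet:
//
//   slow_conv_dilated_shape_check<2>(
//       input, weight, bias, /*grad_output=*/Tensor(),
//       kernel_size, stride_size, pad_size, dilation_size);
//
// A backward implementation would pass its real grad_output instead.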

} // namespace internal
} // namespace native
} // namespace at