im2col_shape_check.h 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/TensorUtils.h>
  4. #include <ATen/div_rtn.h>
  5. namespace at {
  6. namespace native {
  7. static inline void col2im_shape_check(
  8. const Tensor& input,
  9. const Tensor& grad_output,
  10. int64_t output_height,
  11. int64_t output_width,
  12. int64_t kernel_height,
  13. int64_t kernel_width,
  14. int64_t dilation_height,
  15. int64_t dilation_width,
  16. int64_t pad_height,
  17. int64_t pad_width,
  18. int64_t stride_height,
  19. int64_t stride_width) {
  20. TORCH_CHECK(
  21. kernel_width > 0 && kernel_height > 0,
  22. "kernel size should be greater than zero, but got kernel_height: ",
  23. kernel_height,
  24. " kernel_width: ",
  25. kernel_width);
  26. TORCH_CHECK(
  27. stride_width > 0 && stride_height > 0,
  28. "stride should be greater than zero, but got stride_height: ",
  29. stride_height,
  30. " stride_width: ",
  31. stride_width);
  32. TORCH_CHECK(
  33. dilation_width > 0 && dilation_height > 0,
  34. "dilation should be greater than zero, but got dilation_height: ",
  35. dilation_height,
  36. " dilation_width: ",
  37. dilation_width);
  38. TORCH_CHECK(
  39. pad_width >= 0 && pad_height >= 0,
  40. "padding should be non-negative, but got pad_height: ",
  41. pad_height,
  42. " pad_width: ",
  43. pad_width);
  44. int64_t ndim = input.ndimension();
  45. // allow dim=0 only the batch dimension.
  46. TORCH_CHECK(
  47. (ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
  48. (ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
  49. "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
  50. input.sizes());
  51. int64_t batch_dim = (ndim == 3) ? 0 : -1;
  52. int64_t n_input_plane = input.size(batch_dim + 1);
  53. if (n_input_plane % (kernel_width * kernel_height) != 0) {
  54. AT_ERROR(
  55. "Expected size of input's dimension 1 to be divisible by the "
  56. "product of kernel_size, but got input.size(1)=",
  57. n_input_plane,
  58. " and kernel_size=(",
  59. kernel_height,
  60. ", ",
  61. kernel_width,
  62. ").");
  63. }
  64. int64_t input_length = input.size(batch_dim + 2);
  65. int64_t n_blocks_height =
  66. div_rtn<int64_t>(
  67. output_height + 2 * pad_height -
  68. dilation_height * (kernel_height - 1) - 1,
  69. stride_height) +
  70. 1;
  71. int64_t n_blocks_width = div_rtn<int64_t>(
  72. output_width + 2 * pad_width -
  73. dilation_width * (kernel_width - 1) - 1,
  74. stride_width) +
  75. 1;
  76. if (input_length != (n_blocks_height * n_blocks_width)) {
  77. AT_ERROR(
  78. "Given output_size=(",
  79. output_height,
  80. ", ",
  81. output_width,
  82. "), kernel_size=(",
  83. kernel_height,
  84. ", ",
  85. kernel_width,
  86. "), dilation=(",
  87. dilation_height,
  88. ", ",
  89. dilation_width,
  90. "), padding=(",
  91. pad_height,
  92. ", ",
  93. pad_width,
  94. "), stride=(",
  95. stride_height,
  96. ", ",
  97. stride_width,
  98. "), expected size of input's dimension 2 to match the calculated number of ",
  99. "sliding blocks ",
  100. n_blocks_height,
  101. " * ",
  102. n_blocks_width,
  103. " = ",
  104. (n_blocks_height * n_blocks_width),
  105. ", but got input.size(2)=",
  106. input_length,
  107. ".");
  108. }
  109. TORCH_CHECK(
  110. n_blocks_height >= 1 && n_blocks_width >= 1,
  111. "Given output_size=(", output_height, ", ", output_width, "), ",
  112. "kernel_size=(", kernel_height, ", ", kernel_width, "), ",
  113. "dilation=(", dilation_height, ", ", dilation_width, "), ",
  114. "padding=(", pad_height, ", ", pad_width, "), ",
  115. "stride=(", stride_height, ", ", stride_width, "), ",
  116. "calculated shape of the array of sliding blocks as ",
  117. "(", n_blocks_height, ", ", n_blocks_width, "), ",
  118. "which is too small (non-positive)");
  119. if (output_width < 1 || output_height < 1) {
  120. AT_ERROR(
  121. "Expected output spatial size to be positive, but got: output_size=(",
  122. output_height,
  123. ", ",
  124. output_width,
  125. ").");
  126. }
  127. }
  128. static inline void im2col_shape_check(
  129. const Tensor& input,
  130. const Tensor& grad_output,
  131. int64_t kernel_height,
  132. int64_t kernel_width,
  133. int64_t dilation_height,
  134. int64_t dilation_width,
  135. int64_t pad_height,
  136. int64_t pad_width,
  137. int64_t stride_height,
  138. int64_t stride_width) {
  139. TORCH_CHECK(
  140. kernel_width > 0 && kernel_height > 0,
  141. "kernel size should be greater than zero, but got kernel_height: ",
  142. kernel_height,
  143. " kernel_width: ",
  144. kernel_width);
  145. TORCH_CHECK(
  146. dilation_width > 0 && dilation_height > 0,
  147. "dilation should be greater than zero, but got dilation_height: ",
  148. dilation_height,
  149. " dilation_width: ",
  150. dilation_width);
  151. TORCH_CHECK(
  152. pad_width >= 0 && pad_height >= 0,
  153. "padding should be non-negative, but got pad_height: ",
  154. pad_height,
  155. " pad_width: ",
  156. pad_width);
  157. TORCH_CHECK(
  158. stride_width > 0 && stride_height > 0,
  159. "stride should be greater than zero, but got stride_height: ",
  160. stride_height,
  161. " stride_width: ",
  162. stride_width);
  163. int64_t ndim = input.ndimension();
  164. // allow dim=0 only the batch dimension.
  165. bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
  166. TORCH_CHECK(
  167. (ndim == 3 && input.size(0) && valid_dims) ||
  168. (ndim == 4 && valid_dims && input.size(3) != 0),
  169. "Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
  170. input.sizes());
  171. int64_t dim_batch = 0;
  172. if (ndim == 3) {
  173. dim_batch = -1;
  174. }
  175. int64_t input_height = input.size(dim_batch + 2);
  176. int64_t input_width = input.size(dim_batch + 3);
  177. int64_t output_height = div_rtn<int64_t>(
  178. input_height + 2 * pad_height -
  179. (dilation_height * (kernel_height - 1) + 1),
  180. stride_height) +
  181. 1;
  182. int64_t output_width = div_rtn<int64_t>(
  183. input_width + 2 * pad_width -
  184. (dilation_width * (kernel_width - 1) + 1),
  185. stride_width) +
  186. 1;
  187. if (output_height < 1 || output_width < 1) {
  188. AT_ERROR(
  189. "Given input with spatial size (",
  190. input_height,
  191. ", ",
  192. input_height,
  193. "), kernel_size=(",
  194. kernel_height,
  195. ", ",
  196. kernel_width,
  197. "), dilation=(",
  198. dilation_height,
  199. ", ",
  200. dilation_width,
  201. "), padding=(",
  202. pad_height,
  203. ", ",
  204. pad_width,
  205. "), calculated shape of the array of sliding blocks as (",
  206. output_height,
  207. ", ",
  208. output_width,
  209. "), but its components must be at least one.");
  210. }
  211. }
  212. } // namespace native
  213. } // namespace at