#pragma once

#include <math.h>

#include <algorithm>
#include <array>

#include <ATen/OpMathType.h>
#include <ATen/TensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>

/**
 * Note [compute_scales_value]
 * Note [area_pixel_compute_scale]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Interpolate with scale_factor can have different behaviors
 * depending on the value of recompute_scale_factor:
 *
 * - With recompute_scale_factor = True (current default behavior):
 *   the scale_factor, when provided by the user, is used to calculate
 *   the output size. The input size and the computed output size
 *   are then used to infer new values for the scales which are
 *   used in the interpolation. Because floating-point math is not exact,
 *   these may differ from the user-supplied scales.
 *
 * - With recompute_scale_factor = False (which will be the default
 *   behavior starting 1.5.0):
 *   the behavior follows OpenCV logic, and the scales provided by
 *   the user are the ones used in the interpolation calculations.
 *
 * If the scales are not provided, or if they are provided but
 * recompute_scale_factor is set to True (default behavior), the scales
 * are computed from the input and the output size.
 *
 * When the scales are inferred from the input and output sizes,
 * we view each pixel as an area, with idx + 0.5 as its center index.
 * Here is an example formula in the 1D case:
 *
 * if align_corners: the centers of the two corner pixel areas are preserved,
 *     (0.5) -> (0.5), (input_size - 0.5) -> (output_size - 0.5)
 *     scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
 *     src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
 *
 * if not align_corners: the whole range is scaled accordingly,
 *     scale = input_size / output_size
 *     src_idx + 0.5 = scale * (dst_index + 0.5)
 */
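
// Worked example of the Note above (illustrative numbers, not part of the
// original header): for input_size = 4 and output_size = 8,
//   align_corners:     scale = (4 - 0.5 - 0.5) / (8 - 0.5 - 0.5) = 3 / 7
//   not align_corners: scale = 4 / 8 = 0.5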

namespace at {
namespace native {
namespace upsample {

TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size, // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    c10::optional<c10::ArrayRef<double>> scale_factors);

inline c10::optional<double> get_scale_value(
    c10::optional<c10::ArrayRef<double>> scales,
    int idx) {
  if (!scales) {
    return c10::nullopt;
  }
  return scales->at(idx);
}
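
// Illustrative usage (example values, not from the original header): for 2D
// upsampling with scale_factors = {2.0, 3.0}, get_scale_value(scales, 0)
// returns 2.0 and get_scale_value(scales, 1) returns 3.0; without
// user-provided scales it returns c10::nullopt.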

} // namespace upsample

using scale_t = c10::optional<double>;

using upsampling_nearest1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using _upsampling_nearest_exact1d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_w);
using upsampling_nearest2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact2d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_h, scale_t scales_w);
using upsampling_nearest3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using _upsampling_nearest_exact3d = void(*)(const Tensor& output, const Tensor& input, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_linear1d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_w);
using upsampling_bilinear2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w);
using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);
using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w);

DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_kernel);
DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact2d, _upsample_nearest_exact2d_backward_kernel);
DECLARE_DISPATCH(upsampling_nearest3d, upsample_nearest3d_backward_kernel);
DECLARE_DISPATCH(_upsampling_nearest_exact3d, _upsample_nearest_exact3d_backward_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_kernel);
DECLARE_DISPATCH(upsampling_linear1d, upsample_linear1d_backward_kernel);
DECLARE_DISPATCH(upsampling_bilinear2d, upsample_bilinear2d_backward_kernel);
DECLARE_DISPATCH(_upsampling_bilinear2d_aa, _upsample_bilinear2d_aa_backward_kernel);
DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel);
DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel);
DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel);

static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 1,
      "Expected output_size to have 1 element, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 3,
      "Expected input_size to have 3 elements, but got size ",
      input_size.size());

  int64_t output_width = output_size[0];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  TORCH_CHECK(
      input_width > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (W: ",
      input_width,
      ") and output (W: ",
      output_width,
      ")");

  return {nbatch, channels, output_width};
}
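
// Example (illustrative shapes, not from the original header): for an input
// of size {2, 3, 10} (N, C, W) and output_size = {20}, the checks pass and
// the function returns {2, 3, 20}.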

static C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 2,
      "Expected output_size to have 2 elements, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 4,
      "Expected input_size to have 4 elements, but got size ",
      input_size.size());

  int64_t output_height = output_size[0];
  int64_t output_width = output_size[1];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_height = input_size[2];
  int64_t input_width = input_size[3];

  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_height, output_width};
}

static C10_UNUSED std::array<int64_t, 5> upsample_3d_common_check(
    IntArrayRef input_size,
    IntArrayRef output_size) {
  TORCH_CHECK(
      output_size.size() == 3,
      "Expected output_size to have 3 elements, but got size ",
      output_size.size());
  TORCH_CHECK(
      input_size.size() == 5,
      "Expected input_size to have 5 elements, but got size ",
      input_size.size());

  int64_t output_depth = output_size[0];
  int64_t output_height = output_size[1];
  int64_t output_width = output_size[2];
  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_depth = input_size[2];
  int64_t input_height = input_size[3];
  int64_t input_width = input_size[4];

  TORCH_CHECK(
      input_depth > 0 && input_height > 0 && input_width > 0 &&
          output_depth > 0 && output_height > 0 && output_width > 0,
      "Input and output sizes should be greater than 0, but got input (D: ",
      input_depth,
      ", H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (D: ",
      output_depth,
      ", H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  return {nbatch, channels, output_depth, output_height, output_width};
}

static inline void upsample_2d_shape_check(
    const Tensor& input,
    const Tensor& grad_output,
    int64_t nbatch,
    int64_t nchannels,
    int64_t input_height,
    int64_t input_width,
    int64_t output_height,
    int64_t output_width) {
  TORCH_CHECK(
      input_height > 0 && input_width > 0 && output_height > 0 &&
          output_width > 0,
      "Input and output sizes should be greater than 0,"
      " but got input (H: ",
      input_height,
      ", W: ",
      input_width,
      ") output (H: ",
      output_height,
      ", W: ",
      output_width,
      ")");

  if (input.defined()) {
    // Allow for empty batch size but not other dimensions
    TORCH_CHECK(
        (input.numel() != 0 ||
         (input.size(1) != 0 && input.size(2) != 0 && input.size(3) != 0)) &&
            input.dim() == 4,
        "Non-empty 4D data tensor expected but got a tensor with sizes ",
        input.sizes());
  } else if (grad_output.defined()) {
    check_dim_size(grad_output, 4, 0, nbatch);
    check_dim_size(grad_output, 4, 1, nchannels);
    check_dim_size(grad_output, 4, 2, output_height);
    check_dim_size(grad_output, 4, 3, output_width);
  }
}

template <typename scalar_t>
static inline scalar_t compute_scales_value(
    const c10::optional<double> scale,
    int64_t input_size,
    int64_t output_size) {
  // see Note [compute_scales_value]
  // FIXME: remove magic > 0 after we ensure no models were serialized
  // with -1 defaults.
  return (scale.has_value() && scale.value() > 0.)
      ? static_cast<scalar_t>(1.0 / scale.value())
      : (static_cast<scalar_t>(input_size) / output_size);
}
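
// Example (illustrative numbers, not from the original header): the returned
// value is the step in input coordinates per output pixel, i.e. the
// reciprocal of the user-facing scale_factor:
//   compute_scales_value<float>(2.0, 4, 8)          == 0.5f  // 1 / 2.0
//   compute_scales_value<float>(c10::nullopt, 4, 8) == 0.5f  // 4 / 8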

template <typename scalar_t>
static inline scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const c10::optional<double> scale) {
  // see Note [area_pixel_compute_scale]
  if (align_corners) {
    if (output_size > 1) {
      return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
    } else {
      return static_cast<scalar_t>(0);
    }
  } else {
    return compute_scales_value<scalar_t>(scale, input_size, output_size);
  }
}
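
// Example (illustrative numbers, not from the original header):
//   area_pixel_compute_scale<float>(4, 8, /*align_corners=*/true, c10::nullopt)
//     == 3.0f / 7.0f  // (input_size - 1) / (output_size - 1)
//   area_pixel_compute_scale<float>(4, 8, /*align_corners=*/false, c10::nullopt)
//     == 0.5f         // falls back to compute_scales_value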

template <typename scalar_t>
static inline scalar_t area_pixel_compute_source_index(
    scalar_t scale,
    int64_t dst_index,
    bool align_corners,
    bool cubic) {
  if (align_corners) {
    return scale * dst_index;
  } else {
    scalar_t src_idx = scale * (dst_index + static_cast<scalar_t>(0.5)) -
        static_cast<scalar_t>(0.5);
    // [Note] Follow OpenCV resize logic:
    // We allow negative src_idx here and later will use
    //   dx = src_idx - floorf(src_idx)
    // to compute the "distance" (which affects the weights).
    // For linear modes, the weight distribution doesn't matter
    // for negative indices as they use 2 pixels to interpolate.
    // For example, for [-1, 0] both taps use the pixel 0 value, so it
    // doesn't matter whether we bound src_idx to 0 or not.
    // TODO: Our linear mode implementations should use unbound indices,
    // after which this cubic flag can be removed.
    // This matters in cubic mode, as we might need [-1, 0, 1, 2]
    // to interpolate and the weights can be affected.
    return (!cubic && src_idx < static_cast<scalar_t>(0)) ? scalar_t(0)
                                                          : src_idx;
  }
}
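
// Example (illustrative numbers, not from the original header): with
// scale = 0.5 and dst_index = 0 (align_corners = false),
// src_idx = 0.5 * 0.5 - 0.5 = -0.25; linear modes (cubic = false) clamp this
// to 0, while cubic mode keeps -0.25 so the kernel weights see the true
// negative offset.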

static inline int64_t nearest_neighbor_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // Index computation matching OpenCV INTER_NEAREST,
  // which is buggy and kept for BC
  const int64_t src_index =
      std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
  return src_index;
}

static inline int64_t nearest_neighbor_exact_compute_source_index(
    const float scale,
    int64_t dst_index,
    int64_t input_size) {
  // index_f32 = (output_index + 0.5) * scale - 0.5
  // input_index = round(index_f32)
  // Same as Pillow and Scikit-Image/Scipy ndi.zoom
  const int64_t src_index =
      std::min(static_cast<int64_t>(floorf((dst_index + 0.5) * scale)), input_size - 1);
  return src_index;
}
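
// Example contrasting the two variants (illustrative numbers, not from the
// original header): upsampling W: 3 -> 5 gives scale = 3 / 5 = 0.6. Over
// output indices 0..4, the legacy variant floor(dst * 0.6) yields source
// indices [0, 0, 1, 1, 2], while the exact variant floor((dst + 0.5) * 0.6)
// yields [0, 0, 1, 2, 2].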

static inline int64_t nearest_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  // This method specifically handles the cases output_size == input_size and
  // output_size == 2 * input_size, which we would like to get rid of.
  // We keep this method for BC and consider it deprecated.
  // See nearest_exact_idx as a replacement.
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    return output_index;
  } else if (output_size == 2 * input_size) {
    // scale_factor = 2, shift input index
    return output_index >> 1;
  } else {
    float scale = compute_scales_value<float>(scales, input_size, output_size);
    return nearest_neighbor_compute_source_index(scale, output_index, input_size);
  }
}
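
// Example (illustrative numbers, not from the original header): for
// input_size = 4 and output_size = 8, the scale_factor = 2 fast path maps
// output_index 5 to 5 >> 1 = 2, which matches floor(5 * 0.5) from the
// general path.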

static inline int64_t nearest_exact_idx(
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    c10::optional<double> scales) {
  float scale = compute_scales_value<float>(scales, input_size, output_size);
  return nearest_neighbor_exact_compute_source_index(scale, output_index, input_size);
}

// Define a typedef to dispatch to nearest_idx or nearest_exact_idx
typedef int64_t (*nearest_idx_fn_t)(int64_t, int64_t, int64_t, c10::optional<double>);

template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  return data[access_y * width + access_x];
}

template <typename scalar_t>
static void upsample_increment_value_bounded(
    scalar_t* data,
    int64_t width,
    int64_t height,
    int64_t x,
    int64_t y,
    scalar_t value) {
  int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
  int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
  data[access_y * width + access_x] += value;
}

// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  return ((A + 2) * x - (A + 3)) * x * x + 1;
}

template <typename scalar_t>
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}

template <typename scalar_t>
static inline void get_cubic_upsample_coefficients(
    scalar_t coeffs[4],
    scalar_t t) {
  scalar_t A = -0.75;

  scalar_t x1 = t;
  coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
  coeffs[1] = cubic_convolution1<scalar_t>(x1, A);

  // opposite coefficients
  scalar_t x2 = 1.0 - t;
  coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
  coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}
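
// Example (illustrative numbers, not from the original header): with
// A = -0.75 the four coefficients always sum to 1:
//   t = 0.0 -> {0, 1, 0, 0}  (the interpolation reproduces x1 exactly)
//   t = 0.5 -> {-0.09375, 0.59375, 0.59375, -0.09375}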

template <typename scalar_t>
static inline scalar_t cubic_interp1d(
    scalar_t x0,
    scalar_t x1,
    scalar_t x2,
    scalar_t x3,
    scalar_t t) {
  scalar_t coeffs[4];
  get_cubic_upsample_coefficients<scalar_t>(coeffs, t);

  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}

template <typename scalar_t>
static inline void compute_source_index_and_lambda(
    int64_t& input_index0,
    int64_t& input_index1,
    scalar_t& lambda0,
    scalar_t& lambda1,
    scalar_t ratio,
    int64_t output_index,
    int64_t input_size,
    int64_t output_size,
    bool align_corners) {
  if (output_size == input_size) {
    // scale_factor = 1, simply copy
    input_index0 = output_index;
    input_index1 = output_index;
    lambda0 = static_cast<scalar_t>(1);
    lambda1 = static_cast<scalar_t>(0);
  } else {
    using opmath_t = at::opmath_type<scalar_t>;
    const auto real_input_index =
        area_pixel_compute_source_index<opmath_t>(
            ratio, output_index, align_corners, /*cubic=*/false);
    // When `real_input_index` becomes larger than the range the floating
    // point type can accurately represent, the cast to `int64_t` might exceed
    // `input_size - 1`, causing overflow. So we guard it with `std::min`
    // below.
    input_index0 = std::min(static_cast<int64_t>(real_input_index), input_size - 1);
    int64_t offset = (input_index0 < input_size - 1) ? 1 : 0;
    input_index1 = input_index0 + offset;
    lambda1 = std::min(
        std::max(real_input_index - input_index0, static_cast<opmath_t>(0)),
        static_cast<opmath_t>(1));
    lambda0 = static_cast<scalar_t>(1.) - lambda1;
  }
}
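
// Worked example (illustrative numbers, not from the original header):
// upsampling W: 4 -> 8 with align_corners = false gives ratio = 0.5; for
// output_index = 3 the real source index is 0.5 * 3.5 - 0.5 = 1.25, so
// input_index0 = 1, input_index1 = 2, lambda1 = 0.25 and lambda0 = 0.75.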

} // namespace native
} // namespace at