decode_png.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #include "decode_png.h"
  2. #include "common_png.h"
  3. namespace vision {
  4. namespace image {
  5. #if !PNG_FOUND
  6. torch::Tensor decode_png(
  7. const torch::Tensor& data,
  8. ImageReadMode mode,
  9. bool allow_16_bits) {
  10. TORCH_CHECK(
  11. false, "decode_png: torchvision not compiled with libPNG support");
  12. }
  13. #else
  14. bool is_little_endian() {
  15. uint32_t x = 1;
  16. return *(uint8_t*)&x;
  17. }
  18. torch::Tensor decode_png(
  19. const torch::Tensor& data,
  20. ImageReadMode mode,
  21. bool allow_16_bits) {
  22. C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
  23. // Check that the input tensor dtype is uint8
  24. TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  25. // Check that the input tensor is 1-dimensional
  26. TORCH_CHECK(
  27. data.dim() == 1 && data.numel() > 0,
  28. "Expected a non empty 1-dimensional tensor");
  29. auto png_ptr =
  30. png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  31. TORCH_CHECK(png_ptr, "libpng read structure allocation failed!")
  32. auto info_ptr = png_create_info_struct(png_ptr);
  33. if (!info_ptr) {
  34. png_destroy_read_struct(&png_ptr, nullptr, nullptr);
  35. // Seems redundant with the if statement. done here to avoid leaking memory.
  36. TORCH_CHECK(info_ptr, "libpng info structure allocation failed!")
  37. }
  38. auto accessor = data.accessor<unsigned char, 1>();
  39. auto datap = accessor.data();
  40. auto datap_len = accessor.size(0);
  41. if (setjmp(png_jmpbuf(png_ptr)) != 0) {
  42. png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  43. TORCH_CHECK(false, "Internal error.");
  44. }
  45. TORCH_CHECK(datap_len >= 8, "Content is too small for png!")
  46. auto is_png = !png_sig_cmp(datap, 0, 8);
  47. TORCH_CHECK(is_png, "Content is not png!")
  48. struct Reader {
  49. png_const_bytep ptr;
  50. png_size_t count;
  51. } reader;
  52. reader.ptr = png_const_bytep(datap) + 8;
  53. reader.count = datap_len - 8;
  54. auto read_callback = [](png_structp png_ptr,
  55. png_bytep output,
  56. png_size_t bytes) {
  57. auto reader = static_cast<Reader*>(png_get_io_ptr(png_ptr));
  58. TORCH_CHECK(
  59. reader->count >= bytes,
  60. "Out of bound read in decode_png. Probably, the input image is corrupted");
  61. std::copy(reader->ptr, reader->ptr + bytes, output);
  62. reader->ptr += bytes;
  63. reader->count -= bytes;
  64. };
  65. png_set_sig_bytes(png_ptr, 8);
  66. png_set_read_fn(png_ptr, &reader, read_callback);
  67. png_read_info(png_ptr, info_ptr);
  68. png_uint_32 width, height;
  69. int bit_depth, color_type;
  70. int interlace_type;
  71. auto retval = png_get_IHDR(
  72. png_ptr,
  73. info_ptr,
  74. &width,
  75. &height,
  76. &bit_depth,
  77. &color_type,
  78. &interlace_type,
  79. nullptr,
  80. nullptr);
  81. if (retval != 1) {
  82. png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  83. TORCH_CHECK(retval == 1, "Could read image metadata from content.")
  84. }
  85. auto max_bit_depth = allow_16_bits ? 16 : 8;
  86. auto err_msg = "At most " + std::to_string(max_bit_depth) +
  87. "-bit PNG images are supported currently.";
  88. if (bit_depth > max_bit_depth) {
  89. png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  90. TORCH_CHECK(false, err_msg)
  91. }
  92. int channels = png_get_channels(png_ptr, info_ptr);
  93. if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
  94. png_set_expand_gray_1_2_4_to_8(png_ptr);
  95. int number_of_passes;
  96. if (interlace_type == PNG_INTERLACE_ADAM7) {
  97. number_of_passes = png_set_interlace_handling(png_ptr);
  98. } else {
  99. number_of_passes = 1;
  100. }
  101. if (mode != IMAGE_READ_MODE_UNCHANGED) {
  102. // TODO: consider supporting PNG_INFO_tRNS
  103. bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
  104. bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0;
  105. bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0;
  106. switch (mode) {
  107. case IMAGE_READ_MODE_GRAY:
  108. if (color_type != PNG_COLOR_TYPE_GRAY) {
  109. if (is_palette) {
  110. png_set_palette_to_rgb(png_ptr);
  111. has_alpha = true;
  112. }
  113. if (has_alpha) {
  114. png_set_strip_alpha(png_ptr);
  115. }
  116. if (has_color) {
  117. png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
  118. }
  119. channels = 1;
  120. }
  121. break;
  122. case IMAGE_READ_MODE_GRAY_ALPHA:
  123. if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) {
  124. if (is_palette) {
  125. png_set_palette_to_rgb(png_ptr);
  126. has_alpha = true;
  127. }
  128. if (!has_alpha) {
  129. png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
  130. }
  131. if (has_color) {
  132. png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587);
  133. }
  134. channels = 2;
  135. }
  136. break;
  137. case IMAGE_READ_MODE_RGB:
  138. if (color_type != PNG_COLOR_TYPE_RGB) {
  139. if (is_palette) {
  140. png_set_palette_to_rgb(png_ptr);
  141. has_alpha = true;
  142. } else if (!has_color) {
  143. png_set_gray_to_rgb(png_ptr);
  144. }
  145. if (has_alpha) {
  146. png_set_strip_alpha(png_ptr);
  147. }
  148. channels = 3;
  149. }
  150. break;
  151. case IMAGE_READ_MODE_RGB_ALPHA:
  152. if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) {
  153. if (is_palette) {
  154. png_set_palette_to_rgb(png_ptr);
  155. has_alpha = true;
  156. } else if (!has_color) {
  157. png_set_gray_to_rgb(png_ptr);
  158. }
  159. if (!has_alpha) {
  160. png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER);
  161. }
  162. channels = 4;
  163. }
  164. break;
  165. default:
  166. png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  167. TORCH_CHECK(false, "The provided mode is not supported for PNG files");
  168. }
  169. png_read_update_info(png_ptr, info_ptr);
  170. }
  171. auto num_pixels_per_row = width * channels;
  172. auto tensor = torch::empty(
  173. {int64_t(height), int64_t(width), channels},
  174. bit_depth <= 8 ? torch::kU8 : torch::kI32);
  175. if (bit_depth <= 8) {
  176. auto t_ptr = tensor.accessor<uint8_t, 3>().data();
  177. for (int pass = 0; pass < number_of_passes; pass++) {
  178. for (png_uint_32 i = 0; i < height; ++i) {
  179. png_read_row(png_ptr, t_ptr, nullptr);
  180. t_ptr += num_pixels_per_row;
  181. }
  182. t_ptr = tensor.accessor<uint8_t, 3>().data();
  183. }
  184. } else {
  185. // We're reading a 16bits png, but pytorch doesn't support uint16.
  186. // So we read each row in a 16bits tmp_buffer which we then cast into
  187. // a int32 tensor instead.
  188. if (is_little_endian()) {
  189. png_set_swap(png_ptr);
  190. }
  191. int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
  192. // We create a tensor instead of malloc-ing for automatic memory management
  193. auto tmp_buffer_tensor = torch::empty(
  194. {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
  195. uint16_t* tmp_buffer =
  196. (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
  197. for (int pass = 0; pass < number_of_passes; pass++) {
  198. for (png_uint_32 i = 0; i < height; ++i) {
  199. png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
  200. // Now we copy the uint16 values into the int32 tensor.
  201. for (size_t j = 0; j < num_pixels_per_row; ++j) {
  202. t_ptr[j] = (int32_t)tmp_buffer[j];
  203. }
  204. t_ptr += num_pixels_per_row;
  205. }
  206. t_ptr = tensor.accessor<int32_t, 3>().data();
  207. }
  208. }
  209. png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  210. return tensor.permute({2, 0, 1});
  211. }
  212. #endif
  213. } // namespace image
  214. } // namespace vision