decode_jpeg.cpp 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #include "decode_jpeg.h"
  2. #include "common_jpeg.h"
  3. namespace vision {
  4. namespace image {
  5. #if !JPEG_FOUND
  6. torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
  7. TORCH_CHECK(
  8. false, "decode_jpeg: torchvision not compiled with libjpeg support");
  9. }
  10. #else
  11. using namespace detail;
  12. namespace {
  13. struct torch_jpeg_mgr {
  14. struct jpeg_source_mgr pub;
  15. const JOCTET* data;
  16. size_t len;
  17. };
  18. static void torch_jpeg_init_source(j_decompress_ptr cinfo) {}
  19. static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
  20. // No more data. Probably an incomplete image; Raise exception.
  21. torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
  22. strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
  23. longjmp(myerr->setjmp_buffer, 1);
  24. }
  25. static void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) {
  26. torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
  27. if (src->pub.bytes_in_buffer < (size_t)num_bytes) {
  28. // Skipping over all of remaining data; output EOI.
  29. src->pub.next_input_byte = EOI_BUFFER;
  30. src->pub.bytes_in_buffer = 1;
  31. } else {
  32. // Skipping over only some of the remaining data.
  33. src->pub.next_input_byte += num_bytes;
  34. src->pub.bytes_in_buffer -= num_bytes;
  35. }
  36. }
  37. static void torch_jpeg_term_source(j_decompress_ptr cinfo) {}
  38. static void torch_jpeg_set_source_mgr(
  39. j_decompress_ptr cinfo,
  40. const unsigned char* data,
  41. size_t len) {
  42. torch_jpeg_mgr* src;
  43. if (cinfo->src == 0) { // if this is first time; allocate memory
  44. cinfo->src = (struct jpeg_source_mgr*)(*cinfo->mem->alloc_small)(
  45. (j_common_ptr)cinfo, JPOOL_PERMANENT, sizeof(torch_jpeg_mgr));
  46. }
  47. src = (torch_jpeg_mgr*)cinfo->src;
  48. src->pub.init_source = torch_jpeg_init_source;
  49. src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer;
  50. src->pub.skip_input_data = torch_jpeg_skip_input_data;
  51. src->pub.resync_to_restart = jpeg_resync_to_restart; // default
  52. src->pub.term_source = torch_jpeg_term_source;
  53. // fill the buffers
  54. src->data = (const JOCTET*)data;
  55. src->len = len;
  56. src->pub.bytes_in_buffer = len;
  57. src->pub.next_input_byte = src->data;
  58. }
  // Computes `k - round(k * cmy / 255)` in integer arithmetic and clamps the
  // result to [0, 255].
  //
  // The division by 255 uses the add-128 / shift trick from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L568-L569
  inline unsigned char clamped_cmyk_rgb_convert(
      unsigned char k,
      unsigned char cmy) {
    const int black = k;
    const int biased = black * cmy + 128;
    const int scaled = (biased + (biased >> 8)) >> 8;
    const int channel = std::clamp(black - scaled, 0, 255);
    return static_cast<unsigned char>(channel);
  }
  68. void convert_line_cmyk_to_rgb(
  69. j_decompress_ptr cinfo,
  70. const unsigned char* cmyk_line,
  71. unsigned char* rgb_line) {
  72. int width = cinfo->output_width;
  73. for (int i = 0; i < width; ++i) {
  74. int c = cmyk_line[i * 4 + 0];
  75. int m = cmyk_line[i * 4 + 1];
  76. int y = cmyk_line[i * 4 + 2];
  77. int k = cmyk_line[i * 4 + 3];
  78. rgb_line[i * 3 + 0] = clamped_cmyk_rgb_convert(k, 255 - c);
  79. rgb_line[i * 3 + 1] = clamped_cmyk_rgb_convert(k, 255 - m);
  80. rgb_line[i * 3 + 2] = clamped_cmyk_rgb_convert(k, 255 - y);
  81. }
  82. }
  // Fixed-point weighted sum of R, G and B producing a single gray value.
  // The three weights add up to 65536 (1.0 in 16.16 fixed point) and 0x8000
  // rounds to nearest before the final shift.
  //
  // Inspired from Pillow:
  // https://github.com/python-pillow/Pillow/blob/07623d1a7cc65206a5355fba2ae256550bfcaba6/src/libImaging/Convert.c#L226
  inline unsigned char rgb_to_gray(int r, int g, int b) {
    constexpr int kRedWeight = 19595;
    constexpr int kGreenWeight = 38470;
    constexpr int kBlueWeight = 7471;
    constexpr int kRoundingBias = 0x8000;
    const int weighted =
        r * kRedWeight + g * kGreenWeight + b * kBlueWeight + kRoundingBias;
    return static_cast<unsigned char>(weighted >> 16);
  }
  88. void convert_line_cmyk_to_gray(
  89. j_decompress_ptr cinfo,
  90. const unsigned char* cmyk_line,
  91. unsigned char* gray_line) {
  92. int width = cinfo->output_width;
  93. for (int i = 0; i < width; ++i) {
  94. int c = cmyk_line[i * 4 + 0];
  95. int m = cmyk_line[i * 4 + 1];
  96. int y = cmyk_line[i * 4 + 2];
  97. int k = cmyk_line[i * 4 + 3];
  98. int r = clamped_cmyk_rgb_convert(k, 255 - c);
  99. int g = clamped_cmyk_rgb_convert(k, 255 - m);
  100. int b = clamped_cmyk_rgb_convert(k, 255 - y);
  101. gray_line[i] = rgb_to_gray(r, g, b);
  102. }
  103. }
  104. } // namespace
  105. torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
  106. C10_LOG_API_USAGE_ONCE(
  107. "torchvision.csrc.io.image.cpu.decode_jpeg.decode_jpeg");
  108. // Check that the input tensor dtype is uint8
  109. TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  110. // Check that the input tensor is 1-dimensional
  111. TORCH_CHECK(
  112. data.dim() == 1 && data.numel() > 0,
  113. "Expected a non empty 1-dimensional tensor");
  114. struct jpeg_decompress_struct cinfo;
  115. struct torch_jpeg_error_mgr jerr;
  116. auto datap = data.data_ptr<uint8_t>();
  117. // Setup decompression structure
  118. cinfo.err = jpeg_std_error(&jerr.pub);
  119. jerr.pub.error_exit = torch_jpeg_error_exit;
  120. /* Establish the setjmp return context for my_error_exit to use. */
  121. if (setjmp(jerr.setjmp_buffer)) {
  122. /* If we get here, the JPEG code has signaled an error.
  123. * We need to clean up the JPEG object.
  124. */
  125. jpeg_destroy_decompress(&cinfo);
  126. TORCH_CHECK(false, jerr.jpegLastErrorMsg);
  127. }
  128. jpeg_create_decompress(&cinfo);
  129. torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());
  130. // read info from header.
  131. jpeg_read_header(&cinfo, TRUE);
  132. int channels = cinfo.num_components;
  133. bool cmyk_to_rgb_or_gray = false;
  134. if (mode != IMAGE_READ_MODE_UNCHANGED) {
  135. switch (mode) {
  136. case IMAGE_READ_MODE_GRAY:
  137. if (cinfo.jpeg_color_space == JCS_CMYK ||
  138. cinfo.jpeg_color_space == JCS_YCCK) {
  139. cinfo.out_color_space = JCS_CMYK;
  140. cmyk_to_rgb_or_gray = true;
  141. } else {
  142. cinfo.out_color_space = JCS_GRAYSCALE;
  143. }
  144. channels = 1;
  145. break;
  146. case IMAGE_READ_MODE_RGB:
  147. if (cinfo.jpeg_color_space == JCS_CMYK ||
  148. cinfo.jpeg_color_space == JCS_YCCK) {
  149. cinfo.out_color_space = JCS_CMYK;
  150. cmyk_to_rgb_or_gray = true;
  151. } else {
  152. cinfo.out_color_space = JCS_RGB;
  153. }
  154. channels = 3;
  155. break;
  156. /*
  157. * Libjpeg does not support converting from CMYK to grayscale etc. There
  158. * is a way to do this but it involves converting it manually to RGB:
  159. * https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313
  160. */
  161. default:
  162. jpeg_destroy_decompress(&cinfo);
  163. TORCH_CHECK(false, "The provided mode is not supported for JPEG files");
  164. }
  165. jpeg_calc_output_dimensions(&cinfo);
  166. }
  167. jpeg_start_decompress(&cinfo);
  168. int height = cinfo.output_height;
  169. int width = cinfo.output_width;
  170. int stride = width * channels;
  171. auto tensor =
  172. torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
  173. auto ptr = tensor.data_ptr<uint8_t>();
  174. torch::Tensor cmyk_line_tensor;
  175. if (cmyk_to_rgb_or_gray) {
  176. cmyk_line_tensor = torch::empty({int64_t(width), 4}, torch::kU8);
  177. }
  178. while (cinfo.output_scanline < cinfo.output_height) {
  179. /* jpeg_read_scanlines expects an array of pointers to scanlines.
  180. * Here the array is only one element long, but you could ask for
  181. * more than one scanline at a time if that's more convenient.
  182. */
  183. if (cmyk_to_rgb_or_gray) {
  184. auto cmyk_line_ptr = cmyk_line_tensor.data_ptr<uint8_t>();
  185. jpeg_read_scanlines(&cinfo, &cmyk_line_ptr, 1);
  186. if (channels == 3) {
  187. convert_line_cmyk_to_rgb(&cinfo, cmyk_line_ptr, ptr);
  188. } else if (channels == 1) {
  189. convert_line_cmyk_to_gray(&cinfo, cmyk_line_ptr, ptr);
  190. }
  191. } else {
  192. jpeg_read_scanlines(&cinfo, &ptr, 1);
  193. }
  194. ptr += stride;
  195. }
  196. jpeg_finish_decompress(&cinfo);
  197. jpeg_destroy_decompress(&cinfo);
  198. return tensor.permute({2, 0, 1});
  199. }
  200. #endif // #if !JPEG_FOUND
  // Returns the JPEG_LIB_VERSION value this file was compiled against, or -1
  // when the build has no libjpeg support (JPEG_FOUND is false).
  int64_t _jpeg_version() {
  #if !JPEG_FOUND
    return -1;
  #else
    return JPEG_LIB_VERSION;
  #endif
  }
  // Reports whether LIBJPEG_TURBO_VERSION was defined when this file was
  // compiled.
  bool _is_compiled_against_turbo() {
  #ifndef LIBJPEG_TURBO_VERSION
    return false;
  #else
    return true;
  #endif
  }
  215. } // namespace image
  216. } // namespace vision