encode_jpeg.cpp 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #include "encode_jpeg.h"
  2. #include "common_jpeg.h"
  3. namespace vision {
  4. namespace image {
  5. #if !JPEG_FOUND
  6. torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  7. TORCH_CHECK(
  8. false, "encode_jpeg: torchvision not compiled with libjpeg support");
  9. }
  10. #else
  // jpeg_mem_dest() declares its out_size parameter as `unsigned long` in
  // libjpeg versions up to and including 9b, and as `size_t` in later
  // versions. Select the matching type; when the version macros are not
  // defined, fall back to `unsigned long`.
  #if defined(JPEG_LIB_VERSION_MAJOR) && \
      (JPEG_LIB_VERSION_MAJOR > 9 || \
       (JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR > 2))
  using JpegSizeType = size_t;
  #else
  using JpegSizeType = unsigned long;
  #endif
  19. using namespace detail;
  20. torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
  21. C10_LOG_API_USAGE_ONCE(
  22. "torchvision.csrc.io.image.cpu.encode_jpeg.encode_jpeg");
  23. // Define compression structures and error handling
  24. struct jpeg_compress_struct cinfo {};
  25. struct torch_jpeg_error_mgr jerr {};
  26. // Define buffer to write JPEG information to and its size
  27. JpegSizeType jpegSize = 0;
  28. uint8_t* jpegBuf = nullptr;
  29. cinfo.err = jpeg_std_error(&jerr.pub);
  30. jerr.pub.error_exit = torch_jpeg_error_exit;
  31. /* Establish the setjmp return context for my_error_exit to use. */
  32. if (setjmp(jerr.setjmp_buffer)) {
  33. /* If we get here, the JPEG code has signaled an error.
  34. * We need to clean up the JPEG object and the buffer.
  35. */
  36. jpeg_destroy_compress(&cinfo);
  37. if (jpegBuf != nullptr) {
  38. free(jpegBuf);
  39. }
  40. TORCH_CHECK(false, (const char*)jerr.jpegLastErrorMsg);
  41. }
  42. // Check that the input tensor is on CPU
  43. TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
  44. // Check that the input tensor dtype is uint8
  45. TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
  46. // Check that the input tensor is 3-dimensional
  47. TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");
  48. // Get image info
  49. int channels = data.size(0);
  50. int height = data.size(1);
  51. int width = data.size(2);
  52. auto input = data.permute({1, 2, 0}).contiguous();
  53. TORCH_CHECK(
  54. channels == 1 || channels == 3,
  55. "The number of channels should be 1 or 3, got: ",
  56. channels);
  57. // Initialize JPEG structure
  58. jpeg_create_compress(&cinfo);
  59. // Set output image information
  60. cinfo.image_width = width;
  61. cinfo.image_height = height;
  62. cinfo.input_components = channels;
  63. cinfo.in_color_space = channels == 1 ? JCS_GRAYSCALE : JCS_RGB;
  64. jpeg_set_defaults(&cinfo);
  65. jpeg_set_quality(&cinfo, quality, TRUE);
  66. // Save JPEG output to a buffer
  67. jpeg_mem_dest(&cinfo, &jpegBuf, &jpegSize);
  68. // Start JPEG compression
  69. jpeg_start_compress(&cinfo, TRUE);
  70. auto stride = width * channels;
  71. auto ptr = input.data_ptr<uint8_t>();
  72. // Encode JPEG file
  73. while (cinfo.next_scanline < cinfo.image_height) {
  74. jpeg_write_scanlines(&cinfo, &ptr, 1);
  75. ptr += stride;
  76. }
  77. jpeg_finish_compress(&cinfo);
  78. jpeg_destroy_compress(&cinfo);
  79. torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  80. auto out_tensor =
  81. torch::from_blob(jpegBuf, {(long)jpegSize}, ::free, options);
  82. jpegBuf = nullptr;
  83. return out_tensor;
  84. }
  85. #endif
  86. } // namespace image
  87. } // namespace vision