encode_png.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #include "encode_jpeg.h"
  2. #include "common_png.h"
  3. namespace vision {
  4. namespace image {
  5. #if !PNG_FOUND
  6. torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
  7. TORCH_CHECK(
  8. false, "encode_png: torchvision not compiled with libpng support");
  9. }
  10. #else
  11. namespace {
  /*
   * In-memory sink for the encoded PNG stream. torch_png_write_data appends to
   * it; the memory behind `buffer` comes from malloc/realloc and has to be
   * released with free().
   */
  struct torch_mem_encode {
    /* Start of the accumulated bytes; nullptr until something is written. */
    char* buffer = nullptr;
    /* Number of valid bytes currently stored behind `buffer`. */
    size_t size = 0;
  };
  /*
   * Error context handed to libpng as its error pointer. torch_png_error fills
   * in the message and then longjmps to `setjmp_buffer`.
   */
  struct torch_png_error_mgr {
    /* Last message reported by libpng; nullptr until an error is reported. */
    const char* pngLastErrorMsg = nullptr;
    /* Jump target armed by the caller with setjmp(). */
    jmp_buf setjmp_buffer;
  };
  using torch_png_error_mgr_ptr = torch_png_error_mgr*;
  21. void torch_png_error(png_structp png_ptr, png_const_charp error_msg) {
  22. /* png_ptr->err really points to a torch_png_error_mgr struct, so coerce
  23. * pointer */
  24. auto error_ptr = (torch_png_error_mgr_ptr)png_get_error_ptr(png_ptr);
  25. /* Replace the error message on the error structure */
  26. error_ptr->pngLastErrorMsg = error_msg;
  27. /* Return control to the setjmp point */
  28. longjmp(error_ptr->setjmp_buffer, 1);
  29. }
  30. void torch_png_write_data(
  31. png_structp png_ptr,
  32. png_bytep data,
  33. png_size_t length) {
  34. struct torch_mem_encode* p =
  35. (struct torch_mem_encode*)png_get_io_ptr(png_ptr);
  36. size_t nsize = p->size + length;
  37. /* allocate or grow buffer */
  38. if (p->buffer)
  39. p->buffer = (char*)realloc(p->buffer, nsize);
  40. else
  41. p->buffer = (char*)malloc(nsize);
  42. if (!p->buffer)
  43. png_error(png_ptr, "Write Error");
  44. /* copy new bytes to end of buffer */
  45. memcpy(p->buffer + p->size, data, length);
  46. p->size += length;
  47. }
  48. } // namespace
  49. torch::Tensor encode_png(const torch::Tensor& data, int64_t compression_level) {
  50. C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.encode_png.encode_png");
  51. // Define compression structures and error handling
  52. png_structp png_write;
  53. png_infop info_ptr;
  54. struct torch_png_error_mgr err_ptr;
  55. // Define output buffer
  56. struct torch_mem_encode buf_info;
  57. buf_info.buffer = NULL;
  58. buf_info.size = 0;
  59. /* Establish the setjmp return context for my_error_exit to use. */
  60. if (setjmp(err_ptr.setjmp_buffer)) {
  61. /* If we get here, the PNG code has signaled an error.
  62. * We need to clean up the PNG object and the buffer.
  63. */
  64. if (info_ptr != NULL) {
  65. png_destroy_info_struct(png_write, &info_ptr);
  66. }
  67. if (png_write != NULL) {
  68. png_destroy_write_struct(&png_write, NULL);
  69. }
  70. if (buf_info.buffer != NULL) {
  71. free(buf_info.buffer);
  72. }
  73. TORCH_CHECK(false, err_ptr.pngLastErrorMsg);
  74. }
  75. // Check that the compression level is between 0 and 9
  76. TORCH_CHECK(
  77. compression_level >= 0 && compression_level <= 9,
  78. "Compression level should be between 0 and 9");
  79. // Check that the input tensor is on CPU
  80. TORCH_CHECK(data.device() == torch::kCPU, "Input tensor should be on CPU");
  81. // Check that the input tensor dtype is uint8
  82. TORCH_CHECK(data.dtype() == torch::kU8, "Input tensor dtype should be uint8");
  83. // Check that the input tensor is 3-dimensional
  84. TORCH_CHECK(data.dim() == 3, "Input data should be a 3-dimensional tensor");
  85. // Get image info
  86. int channels = data.size(0);
  87. int height = data.size(1);
  88. int width = data.size(2);
  89. auto input = data.permute({1, 2, 0}).contiguous();
  90. TORCH_CHECK(
  91. channels == 1 || channels == 3,
  92. "The number of channels should be 1 or 3, got: ",
  93. channels);
  94. // Initialize PNG structures
  95. png_write = png_create_write_struct(
  96. PNG_LIBPNG_VER_STRING, &err_ptr, torch_png_error, NULL);
  97. info_ptr = png_create_info_struct(png_write);
  98. // Define custom buffer output
  99. png_set_write_fn(png_write, &buf_info, torch_png_write_data, NULL);
  100. // Set output image information
  101. auto color_type = channels == 1 ? PNG_COLOR_TYPE_GRAY : PNG_COLOR_TYPE_RGB;
  102. png_set_IHDR(
  103. png_write,
  104. info_ptr,
  105. width,
  106. height,
  107. 8,
  108. color_type,
  109. PNG_INTERLACE_NONE,
  110. PNG_COMPRESSION_TYPE_DEFAULT,
  111. PNG_FILTER_TYPE_DEFAULT);
  112. // Set image compression level
  113. png_set_compression_level(png_write, compression_level);
  114. // Write file header
  115. png_write_info(png_write, info_ptr);
  116. auto stride = width * channels;
  117. auto ptr = input.data_ptr<uint8_t>();
  118. // Encode PNG file
  119. for (int y = 0; y < height; ++y) {
  120. png_write_row(png_write, ptr);
  121. ptr += stride;
  122. }
  123. // Write EOF
  124. png_write_end(png_write, info_ptr);
  125. // Destroy structures
  126. png_destroy_write_struct(&png_write, &info_ptr);
  127. torch::TensorOptions options = torch::TensorOptions{torch::kU8};
  128. auto outTensor = torch::empty({(long)buf_info.size}, options);
  129. // Copy memory from png buffer, since torch cannot get ownership of it via
  130. // `from_blob`
  131. auto outPtr = outTensor.data_ptr<uint8_t>();
  132. std::memcpy(outPtr, buf_info.buffer, sizeof(uint8_t) * outTensor.numel());
  133. free(buf_info.buffer);
  134. return outTensor;
  135. }
  136. #endif
  137. } // namespace image
  138. } // namespace vision