decode_image.cpp 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. #include "decode_image.h"
  2. #include "decode_jpeg.h"
  3. #include "decode_png.h"
  4. namespace vision {
  5. namespace image {
  6. torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) {
  7. // Check that tensor is a CPU tensor
  8. TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor");
  9. // Check that the input tensor dtype is uint8
  10. TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  11. // Check that the input tensor is 1-dimensional
  12. TORCH_CHECK(
  13. data.dim() == 1 && data.numel() > 0,
  14. "Expected a non empty 1-dimensional tensor");
  15. auto datap = data.data_ptr<uint8_t>();
  16. const uint8_t jpeg_signature[3] = {255, 216, 255}; // == "\xFF\xD8\xFF"
  17. const uint8_t png_signature[4] = {137, 80, 78, 71}; // == "\211PNG"
  18. if (memcmp(jpeg_signature, datap, 3) == 0) {
  19. return decode_jpeg(data, mode);
  20. } else if (memcmp(png_signature, datap, 4) == 0) {
  21. return decode_png(data, mode);
  22. } else {
  23. TORCH_CHECK(
  24. false,
  25. "Unsupported image file. Only jpeg and png ",
  26. "are currently supported.");
  27. }
  28. }
  29. } // namespace image
  30. } // namespace vision