decoder.h 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
#pragma once

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cuviddec.h>
#include <nvcuvid.h>
#include <torch/torch.h>

#include <cstdint>
#include <queue>
#include <string>
  8. static auto check_for_cuda_errors =
  9. [](CUresult result, int line_num, std::string file_name) {
  10. if (CUDA_SUCCESS != result) {
  11. const char* error_name = nullptr;
  12. TORCH_CHECK(
  13. CUDA_SUCCESS != cuGetErrorName(result, &error_name),
  14. "CUDA error: ",
  15. error_name,
  16. " in ",
  17. file_name,
  18. " at line ",
  19. line_num)
  20. TORCH_CHECK(
  21. false, "Error: ", result, " in ", file_name, " at line ", line_num);
  22. }
  23. };
  24. struct Rect {
  25. int left, top, right, bottom;
  26. };
  27. class Decoder {
  28. public:
  29. Decoder() {}
  30. ~Decoder();
  31. void init(CUcontext, cudaVideoCodec);
  32. void release();
  33. void decode(const uint8_t*, unsigned long);
  34. torch::Tensor fetch_frame();
  35. int get_height() const {
  36. return luma_height;
  37. }
  38. private:
  39. unsigned int width = 0, luma_height = 0, chroma_height = 0;
  40. unsigned int surface_height = 0, surface_width = 0;
  41. unsigned int max_width = 0, max_height = 0;
  42. unsigned int num_chroma_planes = 0;
  43. int bit_depth_minus8 = 0, bytes_per_pixel = 1;
  44. int decode_pic_count = 0, pic_num_in_decode_order[32];
  45. std::queue<torch::Tensor> decoded_frames;
  46. CUcontext cu_context = NULL;
  47. CUvideoctxlock ctx_lock;
  48. CUvideoparser parser = NULL;
  49. CUvideodecoder decoder = NULL;
  50. CUstream cuvidStream = 0;
  51. cudaVideoCodec video_codec = cudaVideoCodec_NumCodecs;
  52. cudaVideoChromaFormat video_chroma_format = cudaVideoChromaFormat_420;
  53. cudaVideoSurfaceFormat video_output_format = cudaVideoSurfaceFormat_NV12;
  54. CUVIDEOFORMAT cu_video_format = {};
  55. Rect display_rect = {};
  56. static int video_sequence_handler(
  57. void* user_data,
  58. CUVIDEOFORMAT* video_format) {
  59. return ((Decoder*)user_data)->handle_video_sequence(video_format);
  60. }
  61. static int picture_decode_handler(
  62. void* user_data,
  63. CUVIDPICPARAMS* pic_params) {
  64. return ((Decoder*)user_data)->handle_picture_decode(pic_params);
  65. }
  66. static int picture_display_handler(
  67. void* user_data,
  68. CUVIDPARSERDISPINFO* disp_info) {
  69. return ((Decoder*)user_data)->handle_picture_display(disp_info);
  70. }
  71. static int operating_point_handler(
  72. void* user_data,
  73. CUVIDOPERATINGPOINTINFO* operating_info) {
  74. return ((Decoder*)user_data)->get_operating_point(operating_info);
  75. }
  76. void query_hardware(CUVIDEOFORMAT*);
  77. int reconfigure_decoder(CUVIDEOFORMAT*);
  78. int handle_video_sequence(CUVIDEOFORMAT*);
  79. int handle_picture_decode(CUVIDPICPARAMS*);
  80. int handle_picture_display(CUVIDPARSERDISPINFO*);
  81. int get_operating_point(CUVIDOPERATINGPOINTINFO*);
  82. };