gpu_decoder.cpp 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #include "gpu_decoder.h"
  2. #include <c10/cuda/CUDAGuard.h>
  3. /* Set cuda device, create cuda context and initialise the demuxer and decoder.
  4. */
  5. GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
  6. : demuxer(src_file.c_str()) {
  7. at::cuda::CUDAGuard device_guard(dev);
  8. device = device_guard.current_device().index();
  9. check_for_cuda_errors(
  10. cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
  11. decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
  12. initialised = true;
  13. }
  14. GPUDecoder::~GPUDecoder() {
  15. at::cuda::CUDAGuard device_guard(device);
  16. decoder.release();
  17. if (initialised) {
  18. check_for_cuda_errors(
  19. cuDevicePrimaryCtxRelease(device), __LINE__, __FILE__);
  20. }
  21. }
  22. /* Fetch a decoded frame tensor after demuxing and decoding.
  23. */
  24. torch::Tensor GPUDecoder::decode() {
  25. torch::Tensor frameTensor;
  26. unsigned long videoBytes = 0;
  27. uint8_t* video = nullptr;
  28. at::cuda::CUDAGuard device_guard(device);
  29. torch::Tensor frame;
  30. do {
  31. demuxer.demux(&video, &videoBytes);
  32. decoder.decode(video, videoBytes);
  33. frame = decoder.fetch_frame();
  34. } while (frame.numel() == 0 && videoBytes > 0);
  35. return frame;
  36. }
  37. /* Seek to a passed timestamp. The second argument controls whether to seek to a
  38. * keyframe.
  39. */
  40. void GPUDecoder::seek(double timestamp, bool keyframes_only) {
  41. int flag = keyframes_only ? 0 : AVSEEK_FLAG_ANY;
  42. demuxer.seek(timestamp, flag);
  43. }
  44. c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
  45. get_metadata() const {
  46. c10::Dict<std::string, c10::Dict<std::string, double>> metadata;
  47. c10::Dict<std::string, double> video_metadata;
  48. video_metadata.insert("duration", demuxer.get_duration());
  49. video_metadata.insert("fps", demuxer.get_fps());
  50. metadata.insert("video", video_metadata);
  51. return metadata;
  52. }
  53. TORCH_LIBRARY(torchvision, m) {
  54. m.class_<GPUDecoder>("GPUDecoder")
  55. .def(torch::init<std::string, torch::Device>())
  56. .def("seek", &GPUDecoder::seek)
  57. .def("get_metadata", &GPUDecoder::get_metadata)
  58. .def("next", &GPUDecoder::decode);
  59. }