// gpu_decoder.h

#pragma once

#include <cstdint>
#include <string>

#include <torch/custom_class.h>
#include <torch/torch.h>

#include "decoder.h"
#include "demuxer.h"
  5. class GPUDecoder : public torch::CustomClassHolder {
  6. public:
  7. GPUDecoder(std::string, torch::Device);
  8. ~GPUDecoder();
  9. torch::Tensor decode();
  10. void seek(double, bool);
  11. c10::Dict<std::string, c10::Dict<std::string, double>> get_metadata() const;
  12. private:
  13. Demuxer demuxer;
  14. CUcontext ctx;
  15. Decoder decoder;
  16. int64_t device;
  17. bool initialised = false;
  18. };