// video.cpp
  1. #include "video.h"
  2. #include <regex>
  3. namespace vision {
  4. namespace video {
  5. namespace {
// Maximum time (ms) a blocking decoder call may take before timing out
// (600000 ms = 10 minutes).
const size_t decoderTimeoutMs = 600000;
// Video frames are requested from the decoder as packed 24-bit RGB.
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
  8. // returns number of written bytes
  9. template <typename T>
  10. size_t fillTensorList(DecoderOutputMessage& msgs, torch::Tensor& frame) {
  11. const auto& msg = msgs;
  12. T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
  13. if (frameData) {
  14. auto sizeInBytes = msg.payload->length();
  15. memcpy(frameData, msg.payload->data(), sizeInBytes);
  16. }
  17. return sizeof(T);
  18. }
// Copies a decoded video frame payload into a uint8 (HWC) tensor.
size_t fillVideoTensor(DecoderOutputMessage& msgs, torch::Tensor& videoFrame) {
  return fillTensorList<uint8_t>(msgs, videoFrame);
}
// Copies a decoded audio frame payload into a float32 tensor.
size_t fillAudioTensor(DecoderOutputMessage& msgs, torch::Tensor& audioFrame) {
  return fillTensorList<float>(msgs, audioFrame);
}
  25. std::array<std::pair<std::string, ffmpeg::MediaType>, 4>::const_iterator
  26. _parse_type(const std::string& stream_string) {
  27. static const std::array<std::pair<std::string, MediaType>, 4> types = {{
  28. {"video", TYPE_VIDEO},
  29. {"audio", TYPE_AUDIO},
  30. {"subtitle", TYPE_SUBTITLE},
  31. {"cc", TYPE_CC},
  32. }};
  33. auto device = std::find_if(
  34. types.begin(),
  35. types.end(),
  36. [stream_string](const std::pair<std::string, MediaType>& p) {
  37. return p.first == stream_string;
  38. });
  39. if (device != types.end()) {
  40. return device;
  41. }
  42. TORCH_CHECK(
  43. false, "Expected one of [audio, video, subtitle, cc] ", stream_string);
  44. }
  45. std::string parse_type_to_string(const std::string& stream_string) {
  46. auto device = _parse_type(stream_string);
  47. return device->first;
  48. }
  49. MediaType parse_type_to_mt(const std::string& stream_string) {
  50. auto device = _parse_type(stream_string);
  51. return device->second;
  52. }
  53. std::tuple<std::string, long> _parseStream(const std::string& streamString) {
  54. TORCH_CHECK(!streamString.empty(), "Stream string must not be empty");
  55. static const std::regex regex("([a-zA-Z_]+)(?::([1-9]\\d*|0))?");
  56. std::smatch match;
  57. TORCH_CHECK(
  58. std::regex_match(streamString, match, regex),
  59. "Invalid stream string: '",
  60. streamString,
  61. "'");
  62. std::string type_ = "video";
  63. type_ = parse_type_to_string(match[1].str());
  64. long index_ = -1;
  65. if (match[2].matched) {
  66. try {
  67. index_ = c10::stoi(match[2].str());
  68. } catch (const std::exception&) {
  69. TORCH_CHECK(
  70. false,
  71. "Could not parse device index '",
  72. match[2].str(),
  73. "' in device string '",
  74. streamString,
  75. "'");
  76. }
  77. }
  78. return std::make_tuple(type_, index_);
  79. }
  80. } // namespace
// Populates the member `params` with everything the decoder needs for the
// next (re)initialization.
//   videoStartS: start position in seconds (converted to us below).
//   getPtsOnly: accepted but not used in this body.
//   stream: stream-type string ("video", "audio", "subtitle", "cc");
//     ignored when all_streams is true.
//   stream_id: stream index; -1 means automatic selection.
//   fastSeek: forwarded to params.fastSeek.
//   all_streams: when true, request every stream type from the decoder.
//   num_threads: decoder thread count.
//   seekFrameMarginUs: assigned to params.seekAccuracy unscaled — name
//     suggests microseconds; TODO confirm units against the decoder API.
// NOTE(review): default arguments on an out-of-line member definition are
// only legal if the in-class declaration omits them — presumably the case
// in video.h; verify.
void Video::_getDecoderParams(
    double videoStartS,
    int64_t getPtsOnly,
    std::string stream,
    long stream_id = -1,
    bool fastSeek = true,
    bool all_streams = false,
    int64_t num_threads = 1,
    double seekFrameMarginUs = 10) {
  int64_t videoStartUs = int64_t(videoStartS * 1e6);
  params.timeoutMs = decoderTimeoutMs;
  params.startOffset = videoStartUs;
  params.seekAccuracy = seekFrameMarginUs;
  params.fastSeek = fastSeek;
  params.headerOnly = false;
  params.numThreads = num_threads;
  params.preventStaleness = false; // not sure what this is about
  if (all_streams == true) {
    // Request every stream type. stream = -2 presumably means "all streams
    // of this type" — TODO confirm against the decoder's MediaFormat docs.
    // NOTE(review): `format` is reused across inserts, so the subtitle/cc
    // entries inserted after the video one still carry the video union
    // fields set below; preserved as-is since MediaFormat comparison and
    // decoder behavior may depend on it.
    MediaFormat format;
    format.stream = -2;
    format.type = TYPE_AUDIO;
    params.formats.insert(format);
    format.type = TYPE_VIDEO;
    format.stream = -2;
    format.format.video.width = 0;
    format.format.video.height = 0;
    format.format.video.cropImage = 0;
    format.format.video.format = defaultVideoPixelFormat;
    params.formats.insert(format);
    format.type = TYPE_SUBTITLE;
    format.stream = -2;
    params.formats.insert(format);
    format.type = TYPE_CC;
    format.stream = -2;
    params.formats.insert(format);
  } else {
    // parse stream type
    MediaType stream_type = parse_type_to_mt(stream);
    // TODO: reset params.formats
    std::set<MediaFormat> formats;
    params.formats = formats;
    // Define new format for the single requested stream.
    MediaFormat format;
    format.type = stream_type;
    format.stream = stream_id;
    if (stream_type == TYPE_VIDEO) {
      // width/height 0 = keep source dimensions; decode to default RGB24.
      format.format.video.width = 0;
      format.format.video.height = 0;
      format.format.video.cropImage = 0;
      format.format.video.format = defaultVideoPixelFormat;
    }
    params.formats.insert(format);
  }
} // _get decoder params
  135. void Video::initFromFile(
  136. std::string videoPath,
  137. std::string stream,
  138. int64_t numThreads) {
  139. TORCH_CHECK(!initialized, "Video object can only be initialized once");
  140. initialized = true;
  141. params.uri = videoPath;
  142. _init(stream, numThreads);
  143. }
// Sets up decoding from an in-memory, already-encoded video buffer.
// videoTensor: expected to be a 1-D uint8 tensor of container bytes —
//   data_ptr<uint8_t>() throws on a different dtype; assumes contiguous
//   storage — TODO confirm at call sites.
// NOTE(review): getCallback receives a raw pointer into the tensor, so the
// tensor presumably must outlive the decoder — verify MemoryBuffer's
// ownership semantics.
void Video::initFromMemory(
    torch::Tensor videoTensor,
    std::string stream,
    int64_t numThreads) {
  TORCH_CHECK(!initialized, "Video object can only be initialized once");
  initialized = true;
  // Decoder reads through this callback instead of a file URI.
  callback = MemoryBuffer::getCallback(
      videoTensor.data_ptr<uint8_t>(), videoTensor.size(0));
  _init(stream, numThreads);
}
// Shared initialization path for file- and memory-backed construction.
// Parses the requested stream, opens the decoder once over *all* streams to
// harvest per-stream metadata, then re-opens it restricted to the requested
// stream via setCurrentStream.
void Video::_init(std::string stream, int64_t numThreads) {
  // set number of threads global
  numThreads_ = numThreads;
  // parse stream information
  current_stream = _parseStream(stream);
  // note that in the initial call we want to get all streams
  _getDecoderParams(
      0, // video start
      0, // headerOnly
      std::get<0>(current_stream), // stream info - remove that
      long(-1), // stream_id parsed from info above change to -2
      false, // fastseek: we're using the default param here
      true, // read all streams
      numThreads_ // global number of Threads for decoding
  );
  std::string logMessage, logType;
  // locals (the *TB vectors are declared but never filled here)
  std::vector<double> audioFPS, videoFPS;
  std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
  std::vector<double> audioTB, videoTB, ccTB, subsTB;
  c10::Dict<std::string, std::vector<double>> audioMetadata;
  c10::Dict<std::string, std::vector<double>> videoMetadata;
  c10::Dict<std::string, std::vector<double>> ccMetadata;
  c10::Dict<std::string, std::vector<double>> subsMetadata;
  // callback and metadata defined in struct
  DecoderInCallback tmp_callback = callback;
  succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
  if (succeeded) {
    // Collect fps/duration per stream type. header.duration is scaled by
    // 1e-6, i.e. presumably microseconds — TODO confirm in decoder headers.
    for (const auto& header : metadata) {
      double fps = double(header.fps);
      double duration = double(header.duration) * 1e-6; // * timeBase;
      if (header.format.type == TYPE_VIDEO) {
        videoFPS.push_back(fps);
        videoDuration.push_back(duration);
      } else if (header.format.type == TYPE_AUDIO) {
        audioFPS.push_back(fps);
        audioDuration.push_back(duration);
      } else if (header.format.type == TYPE_CC) {
        ccDuration.push_back(duration);
      } else if (header.format.type == TYPE_SUBTITLE) {
        subsDuration.push_back(duration);
      }; // NOTE(review): stray ';' — harmless empty statement
    }
  }
  // The metadata dicts are populated even when init failed; they then just
  // hold empty vectors.
  // audio
  audioMetadata.insert("duration", audioDuration);
  audioMetadata.insert("framerate", audioFPS);
  // video
  videoMetadata.insert("duration", videoDuration);
  videoMetadata.insert("fps", videoFPS);
  // subs
  subsMetadata.insert("duration", subsDuration);
  // cc
  ccMetadata.insert("duration", ccDuration);
  // put all to a data
  streamsMetadata.insert("video", videoMetadata);
  streamsMetadata.insert("audio", audioMetadata);
  streamsMetadata.insert("subtitles", subsMetadata);
  streamsMetadata.insert("cc", ccMetadata);
  // Re-open the decoder restricted to the requested stream.
  succeeded = setCurrentStream(stream);
  LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
  if (std::get<1>(current_stream) != -1) {
    LOG(INFO)
        << "Stream index set to " << std::get<1>(current_stream)
        << ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
  }
}
// Constructs a Video. When videoPath is empty the object is left
// uninitialized and must be set up later via init_from_file /
// init_from_memory (every other method TORCH_CHECKs `initialized`).
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
  if (!videoPath.empty()) {
    initFromFile(videoPath, stream, numThreads);
  }
} // video
  227. bool Video::setCurrentStream(std::string stream = "video") {
  228. TORCH_CHECK(initialized, "Video object has to be initialized first");
  229. if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
  230. current_stream = _parseStream(stream);
  231. }
  232. double ts = 0;
  233. if (seekTS > 0) {
  234. ts = seekTS;
  235. }
  236. _getDecoderParams(
  237. ts, // video start
  238. 0, // headerOnly
  239. std::get<0>(current_stream), // stream
  240. long(std::get<1>(
  241. current_stream)), // stream_id parsed from info above change to -2
  242. false, // fastseek param set to 0 false by default (changed in seek)
  243. false, // read all streams
  244. numThreads_ // global number of threads
  245. );
  246. // callback and metadata defined in Video.h
  247. DecoderInCallback tmp_callback = callback;
  248. return (decoder.init(params, std::move(tmp_callback), &metadata));
  249. }
// Returns the (type, index) pair of the stream currently selected for
// decoding; index is -1 when the stream was chosen automatically.
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  return current_stream;
}
// Returns the per-stream metadata gathered during _init, keyed by stream
// type: {"video": {"duration": [...], "fps": [...]}, "audio": {...},
// "subtitles": {...}, "cc": {...}} — one vector entry per stream.
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
    getStreamMetadata() const {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  return streamsMetadata;
}
// Moves the decode position of the current stream to timestamp `ts` (in
// seconds) by re-initializing the decoder with a new start offset.
// fastSeek = true seeks to the nearest keyframe instead of the exact pts.
// NOTE(review): default argument on an out-of-line member definition is only
// legal if the declaration in video.h omits it — verify.
void Video::Seek(double ts, bool fastSeek = false) {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  // initialize the class variables used for seeking and return
  _getDecoderParams(
      ts, // video start
      0, // headerOnly
      std::get<0>(current_stream), // stream
      long(std::get<1>(
          current_stream)), // stream_id parsed from info above change to -2
      fastSeek, // fastseek
      false, // read all streams
      numThreads_ // global number of threads
  );
  // callback and metadata defined in Video.h
  DecoderInCallback tmp_callback = callback;
  succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
  LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}
  277. std::tuple<torch::Tensor, double> Video::Next() {
  278. TORCH_CHECK(initialized, "Video object has to be initialized first");
  279. // if failing to decode simply return a null tensor (note, should we
  280. // raise an exception?)
  281. double frame_pts_s;
  282. torch::Tensor outFrame = torch::zeros({0}, torch::kByte);
  283. // decode single frame
  284. DecoderOutputMessage out;
  285. int64_t res = decoder.decode(&out, decoderTimeoutMs);
  286. // if successful
  287. if (res == 0) {
  288. frame_pts_s = double(double(out.header.pts) * 1e-6);
  289. auto header = out.header;
  290. const auto& format = header.format;
  291. // initialize the output variables based on type
  292. if (format.type == TYPE_VIDEO) {
  293. // note: this can potentially be optimized
  294. // by having the global tensor that we fill at decode time
  295. // (would avoid allocations)
  296. int outHeight = format.format.video.height;
  297. int outWidth = format.format.video.width;
  298. int numChannels = 3;
  299. outFrame = torch::zeros({outHeight, outWidth, numChannels}, torch::kByte);
  300. fillVideoTensor(out, outFrame);
  301. outFrame = outFrame.permute({2, 0, 1});
  302. } else if (format.type == TYPE_AUDIO) {
  303. int outAudioChannels = format.format.audio.channels;
  304. int bytesPerSample = av_get_bytes_per_sample(
  305. static_cast<AVSampleFormat>(format.format.audio.format));
  306. int frameSizeTotal = out.payload->length();
  307. TORCH_CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
  308. int numAudioSamples =
  309. frameSizeTotal / (outAudioChannels * bytesPerSample);
  310. outFrame =
  311. torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
  312. fillAudioTensor(out, outFrame);
  313. }
  314. // currently not supporting other formats (will do soon)
  315. out.payload.reset();
  316. } else if (res == ENODATA) {
  317. LOG(INFO) << "Decoder ran out of frames (ENODATA)\n";
  318. } else {
  319. LOG(ERROR) << "Decoder failed with ERROR_CODE " << res;
  320. }
  321. return std::make_tuple(outFrame, frame_pts_s);
  322. }
// TorchScript custom-class registration: exposes Video and its methods to
// Python as torch.classes.torchvision.Video.
static auto registerVideo =
    torch::class_<Video>("torchvision", "Video")
        .def(torch::init<std::string, std::string, int64_t>())
        .def("init_from_file", &Video::initFromFile)
        .def("init_from_memory", &Video::initFromMemory)
        .def("get_current_stream", &Video::getCurrentStream)
        .def("set_current_stream", &Video::setCurrentStream)
        .def("get_metadata", &Video::getStreamMetadata)
        .def("seek", &Video::Seek)
        .def("next", &Video::Next);
  333. } // namespace video
  334. } // namespace vision