// video_reader.cpp
  1. #include "video_reader.h"
  2. #ifdef USE_PYTHON
  3. #include <Python.h>
  4. #endif
  5. #include "../decoder/memory_buffer.h"
  6. #include "../decoder/sync_decoder.h"
  7. #ifdef USE_PYTHON
  8. // If we are in a Windows environment, we need to define
  9. // initialization functions for the _custom_ops extension
  10. #ifdef _WIN32
// Stub module-init for the Windows Python extension loader; the actual ops
// are registered with the dispatcher via TORCH_LIBRARY_FRAGMENT below.
PyMODINIT_FUNC PyInit_video_reader(void) {
  // No need to do anything.
  return NULL;
}
  15. #endif
#endif // USE_PYTHON
  17. using namespace ffmpeg;
  18. namespace vision {
  19. namespace video_reader {
  20. namespace {
// Output formats the decoder is asked to produce: packed RGB24 video frames
// and interleaved 32-bit float audio samples.
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
// Microsecond time base (1 / AV_TIME_BASE) used for all internal timestamps.
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
// Max wait for a single decoded message, in milliseconds (10 minutes).
const size_t decoderTimeoutMs = 600000;
// A jitter can be added to the end of the range to avoid conversion/rounding
// error, small value 100us won't be enough to select the next frame, but enough
// to compensate rounding error due to the multiple conversions.
const size_t timeBaseJitterUs = 100;
  29. DecoderParameters getDecoderParams(
  30. int64_t videoStartUs,
  31. int64_t videoEndUs,
  32. double seekFrameMarginUs,
  33. int64_t getPtsOnly,
  34. int64_t readVideoStream,
  35. int videoWidth,
  36. int videoHeight,
  37. int videoMinDimension,
  38. int videoMaxDimension,
  39. int64_t readAudioStream,
  40. int audioSamples,
  41. int audioChannels) {
  42. DecoderParameters params;
  43. params.headerOnly = getPtsOnly != 0;
  44. params.seekAccuracy = seekFrameMarginUs;
  45. params.startOffset = videoStartUs;
  46. params.endOffset = videoEndUs;
  47. params.timeoutMs = decoderTimeoutMs;
  48. params.preventStaleness = false;
  49. if (readVideoStream == 1) {
  50. MediaFormat videoFormat(0);
  51. videoFormat.type = TYPE_VIDEO;
  52. videoFormat.format.video.format = defaultVideoPixelFormat;
  53. videoFormat.format.video.width = videoWidth;
  54. videoFormat.format.video.height = videoHeight;
  55. videoFormat.format.video.minDimension = videoMinDimension;
  56. videoFormat.format.video.maxDimension = videoMaxDimension;
  57. params.formats.insert(videoFormat);
  58. }
  59. if (readAudioStream == 1) {
  60. MediaFormat audioFormat;
  61. audioFormat.type = TYPE_AUDIO;
  62. audioFormat.format.audio.format = defaultAudioSampleFormat;
  63. audioFormat.format.audio.samples = audioSamples;
  64. audioFormat.format.audio.channels = audioChannels;
  65. params.formats.insert(audioFormat);
  66. }
  67. return params;
  68. }
  69. // returns number of written bytes
// Copies decoded messages into the preallocated `frame` tensor and rescales
// each message's pts (microseconds) back into the stream time base num/den,
// writing it into `framePts`. T is uint8_t for video, float for audio.
// Returns the number of bytes written into `frame` (0 if msgs is empty).
template <typename T>
size_t fillTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& frame,
    torch::Tensor& framePts,
    int64_t num,
    int64_t den) {
  if (msgs.empty()) {
    return 0;
  }
  // frame may be empty (numel == 0) when only pts were requested.
  T* frameData = frame.numel() > 0 ? frame.data_ptr<T>() : nullptr;
  int64_t* framePtsData = framePts.data_ptr<int64_t>();
  TORCH_CHECK_EQ(framePts.size(0), (int64_t)msgs.size());
  // Elements per message, assuming equally-sized messages — true for video
  // frames; for audio only the per-message payload length is used below.
  size_t avgElementsInFrame = frame.numel() / msgs.size();
  size_t offset = 0;
  for (size_t i = 0; i < msgs.size(); ++i) {
    const auto& msg = msgs[i];
    // convert pts into original time_base
    AVRational avr = AVRational{(int)num, (int)den};
    framePtsData[i] = av_rescale_q(msg.header.pts, timeBaseQ, avr);
    VLOG(2) << "PTS type: " << sizeof(T) << ", us: " << msg.header.pts
            << ", original: " << framePtsData[i];
    if (frameData) {
      auto sizeInBytes = msg.payload->length();
      memcpy(frameData + offset, msg.payload->data(), sizeInBytes);
      if (sizeof(T) == sizeof(uint8_t)) {
        // Video - move by allocated frame size
        offset += avgElementsInFrame / sizeof(T);
      } else {
        // Audio - move by number of samples
        offset += sizeInBytes / sizeof(T);
      }
    }
  }
  return offset * sizeof(T);
}
// Fills the uint8 video frame tensor and its pts tensor from decoded video
// messages; num/den is the video stream time base. Returns bytes written.
size_t fillVideoTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& videoFrame,
    torch::Tensor& videoFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<uint8_t>(msgs, videoFrame, videoFramePts, num, den);
}
// Fills the float audio sample tensor and its pts tensor from decoded audio
// messages; num/den is the audio stream time base. Returns bytes written.
size_t fillAudioTensor(
    std::vector<DecoderOutputMessage>& msgs,
    torch::Tensor& audioFrame,
    torch::Tensor& audioFramePts,
    int64_t num,
    int64_t den) {
  return fillTensor<float>(msgs, audioFrame, audioFramePts, num, den);
}
  122. void offsetsToUs(
  123. double& seekFrameMargin,
  124. int64_t readVideoStream,
  125. int64_t videoStartPts,
  126. int64_t videoEndPts,
  127. int64_t videoTimeBaseNum,
  128. int64_t videoTimeBaseDen,
  129. int64_t readAudioStream,
  130. int64_t audioStartPts,
  131. int64_t audioEndPts,
  132. int64_t audioTimeBaseNum,
  133. int64_t audioTimeBaseDen,
  134. int64_t& videoStartUs,
  135. int64_t& videoEndUs) {
  136. seekFrameMargin *= AV_TIME_BASE;
  137. videoStartUs = 0;
  138. videoEndUs = -1;
  139. if (readVideoStream) {
  140. AVRational vr = AVRational{(int)videoTimeBaseNum, (int)videoTimeBaseDen};
  141. if (videoStartPts > 0) {
  142. videoStartUs = av_rescale_q(videoStartPts, vr, timeBaseQ);
  143. }
  144. if (videoEndPts > 0) {
  145. // Add jitter to the end of the range to avoid conversion/rounding error.
  146. // Small value 100us won't be enough to select the next frame, but enough
  147. // to compensate rounding error due to the multiple conversions.
  148. videoEndUs = timeBaseJitterUs + av_rescale_q(videoEndPts, vr, timeBaseQ);
  149. }
  150. } else if (readAudioStream) {
  151. AVRational ar = AVRational{(int)audioTimeBaseNum, (int)audioTimeBaseDen};
  152. if (audioStartPts > 0) {
  153. videoStartUs = av_rescale_q(audioStartPts, ar, timeBaseQ);
  154. }
  155. if (audioEndPts > 0) {
  156. // Add jitter to the end of the range to avoid conversion/rounding error.
  157. // Small value 100us won't be enough to select the next frame, but enough
  158. // to compensate rounding error due to the multiple conversions.
  159. videoEndUs = timeBaseJitterUs + av_rescale_q(audioEndPts, ar, timeBaseQ);
  160. }
  161. }
  162. }
  163. torch::List<torch::Tensor> readVideo(
  164. bool isReadFile,
  165. const torch::Tensor& input_video,
  166. std::string videoPath,
  167. double seekFrameMargin,
  168. int64_t getPtsOnly,
  169. int64_t readVideoStream,
  170. int64_t width,
  171. int64_t height,
  172. int64_t minDimension,
  173. int64_t maxDimension,
  174. int64_t videoStartPts,
  175. int64_t videoEndPts,
  176. int64_t videoTimeBaseNum,
  177. int64_t videoTimeBaseDen,
  178. int64_t readAudioStream,
  179. int64_t audioSamples,
  180. int64_t audioChannels,
  181. int64_t audioStartPts,
  182. int64_t audioEndPts,
  183. int64_t audioTimeBaseNum,
  184. int64_t audioTimeBaseDen) {
  185. int64_t videoStartUs, videoEndUs;
  186. offsetsToUs(
  187. seekFrameMargin,
  188. readVideoStream,
  189. videoStartPts,
  190. videoEndPts,
  191. videoTimeBaseNum,
  192. videoTimeBaseDen,
  193. readAudioStream,
  194. audioStartPts,
  195. audioEndPts,
  196. audioTimeBaseNum,
  197. audioTimeBaseDen,
  198. videoStartUs,
  199. videoEndUs);
  200. DecoderParameters params = getDecoderParams(
  201. videoStartUs, // videoStartPts
  202. videoEndUs, // videoEndPts
  203. seekFrameMargin, // seekFrameMargin
  204. getPtsOnly, // getPtsOnly
  205. readVideoStream, // readVideoStream
  206. width, // width
  207. height, // height
  208. minDimension, // minDimension
  209. maxDimension, // maxDimension
  210. readAudioStream, // readAudioStream
  211. audioSamples, // audioSamples
  212. audioChannels // audioChannels
  213. );
  214. SyncDecoder decoder;
  215. std::vector<DecoderOutputMessage> audioMessages, videoMessages;
  216. DecoderInCallback callback = nullptr;
  217. std::string logMessage, logType;
  218. if (isReadFile) {
  219. params.uri = videoPath;
  220. logType = "file";
  221. logMessage = videoPath;
  222. } else {
  223. callback = MemoryBuffer::getCallback(
  224. input_video.data_ptr<uint8_t>(), input_video.size(0));
  225. logType = "memory";
  226. logMessage = std::to_string(input_video.size(0));
  227. }
  228. VLOG(1) << "Video decoding from " << logType << " [" << logMessage
  229. << "] has started";
  230. const auto now = std::chrono::system_clock::now();
  231. bool succeeded;
  232. DecoderMetadata audioMetadata, videoMetadata;
  233. std::vector<DecoderMetadata> metadata;
  234. if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
  235. for (const auto& header : metadata) {
  236. if (header.format.type == TYPE_VIDEO) {
  237. videoMetadata = header;
  238. } else if (header.format.type == TYPE_AUDIO) {
  239. audioMetadata = header;
  240. }
  241. }
  242. int res;
  243. DecoderOutputMessage msg;
  244. while (0 == (res = decoder.decode(&msg, decoderTimeoutMs))) {
  245. if (msg.header.format.type == TYPE_VIDEO) {
  246. videoMessages.push_back(std::move(msg));
  247. }
  248. if (msg.header.format.type == TYPE_AUDIO) {
  249. audioMessages.push_back(std::move(msg));
  250. }
  251. msg.payload.reset();
  252. }
  253. } else {
  254. LOG(ERROR) << "Decoder initialization has failed";
  255. }
  256. const auto then = std::chrono::system_clock::now();
  257. VLOG(1) << "Video decoding from " << logType << " [" << logMessage
  258. << "] has finished, "
  259. << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
  260. .count()
  261. << " us";
  262. decoder.shutdown();
  263. // video section
  264. torch::Tensor videoFrame = torch::zeros({0}, torch::kByte);
  265. torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
  266. torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  267. torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  268. torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
  269. if (succeeded && readVideoStream == 1) {
  270. if (!videoMessages.empty()) {
  271. const auto& header = videoMetadata;
  272. const auto& format = header.format.format.video;
  273. int numVideoFrames = videoMessages.size();
  274. int outHeight = format.height;
  275. int outWidth = format.width;
  276. int numChannels = 3; // decoder guarantees the default AV_PIX_FMT_RGB24
  277. size_t expectedWrittenBytes = 0;
  278. if (getPtsOnly == 0) {
  279. videoFrame = torch::zeros(
  280. {numVideoFrames, outHeight, outWidth, numChannels}, torch::kByte);
  281. expectedWrittenBytes =
  282. (size_t)numVideoFrames * outHeight * outWidth * numChannels;
  283. }
  284. videoFramePts = torch::zeros({numVideoFrames}, torch::kLong);
  285. VLOG(2) << "video duration: " << header.duration
  286. << ", fps: " << header.fps << ", num: " << header.num
  287. << ", den: " << header.den << ", num frames: " << numVideoFrames;
  288. auto numberWrittenBytes = fillVideoTensor(
  289. videoMessages, videoFrame, videoFramePts, header.num, header.den);
  290. TORCH_CHECK_EQ(numberWrittenBytes, expectedWrittenBytes);
  291. videoTimeBase = torch::zeros({2}, torch::kInt);
  292. int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
  293. videoTimeBaseData[0] = header.num;
  294. videoTimeBaseData[1] = header.den;
  295. videoFps = torch::zeros({1}, torch::kFloat);
  296. float* videoFpsData = videoFps.data_ptr<float>();
  297. videoFpsData[0] = header.fps;
  298. videoDuration = torch::zeros({1}, torch::kLong);
  299. int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
  300. AVRational vr = AVRational{(int)header.num, (int)header.den};
  301. videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, vr);
  302. VLOG(1) << "Video decoding from " << logType << " [" << logMessage
  303. << "] filled video tensors";
  304. } else {
  305. VLOG(1) << "Miss video stream";
  306. }
  307. }
  308. // audio section
  309. torch::Tensor audioFrame = torch::zeros({0}, torch::kFloat);
  310. torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
  311. torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  312. torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  313. torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
  314. if (succeeded && readAudioStream == 1) {
  315. if (!audioMessages.empty()) {
  316. const auto& header = audioMetadata;
  317. const auto& format = header.format.format.audio;
  318. int64_t outAudioChannels = format.channels;
  319. int bytesPerSample =
  320. av_get_bytes_per_sample(static_cast<AVSampleFormat>(format.format));
  321. int numAudioFrames = audioMessages.size();
  322. int64_t numAudioSamples = 0;
  323. if (getPtsOnly == 0) {
  324. int64_t frameSizeTotal = 0;
  325. for (auto const& audioMessage : audioMessages) {
  326. frameSizeTotal += audioMessage.payload->length();
  327. }
  328. TORCH_CHECK_EQ(frameSizeTotal % (outAudioChannels * bytesPerSample), 0);
  329. numAudioSamples = frameSizeTotal / (outAudioChannels * bytesPerSample);
  330. audioFrame =
  331. torch::zeros({numAudioSamples, outAudioChannels}, torch::kFloat);
  332. }
  333. audioFramePts = torch::zeros({numAudioFrames}, torch::kLong);
  334. VLOG(2) << "audio duration: " << header.duration
  335. << ", channels: " << format.channels
  336. << ", sample rate: " << format.samples << ", num: " << header.num
  337. << ", den: " << header.den;
  338. auto numberWrittenBytes = fillAudioTensor(
  339. audioMessages, audioFrame, audioFramePts, header.num, header.den);
  340. TORCH_CHECK_EQ(
  341. numberWrittenBytes,
  342. numAudioSamples * outAudioChannels * sizeof(float));
  343. audioTimeBase = torch::zeros({2}, torch::kInt);
  344. int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
  345. audioTimeBaseData[0] = header.num;
  346. audioTimeBaseData[1] = header.den;
  347. audioSampleRate = torch::zeros({1}, torch::kInt);
  348. int* audioSampleRateData = audioSampleRate.data_ptr<int>();
  349. audioSampleRateData[0] = format.samples;
  350. audioDuration = torch::zeros({1}, torch::kLong);
  351. int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
  352. AVRational ar = AVRational{(int)header.num, (int)header.den};
  353. audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, ar);
  354. VLOG(1) << "Video decoding from " << logType << " [" << logMessage
  355. << "] filled audio tensors";
  356. } else {
  357. VLOG(1) << "Miss audio stream";
  358. }
  359. }
  360. torch::List<torch::Tensor> result;
  361. result.push_back(std::move(videoFrame));
  362. result.push_back(std::move(videoFramePts));
  363. result.push_back(std::move(videoTimeBase));
  364. result.push_back(std::move(videoFps));
  365. result.push_back(std::move(videoDuration));
  366. result.push_back(std::move(audioFrame));
  367. result.push_back(std::move(audioFramePts));
  368. result.push_back(std::move(audioTimeBase));
  369. result.push_back(std::move(audioSampleRate));
  370. result.push_back(std::move(audioDuration));
  371. VLOG(1) << "Video decoding from " << logType << " [" << logMessage
  372. << "] about to return";
  373. return result;
  374. }
// Probes a video (from file or memory) without decoding payloads and returns
// a list of 6 tensors: videoTimeBase {num,den}, videoFps {1}, videoDuration
// {1}, audioTimeBase {num,den}, audioSampleRate {1}, audioDuration {1}.
// Tensors for an absent stream (or on decoder failure) are returned empty.
torch::List<torch::Tensor> probeVideo(
    bool isReadFile,
    const torch::Tensor& input_video,
    std::string videoPath) {
  // Headers only (getPtsOnly=1) for both streams; output sizes irrelevant.
  DecoderParameters params = getDecoderParams(
      0, // videoStartUs
      -1, // videoEndUs
      0, // seekFrameMargin
      1, // getPtsOnly
      1, // readVideoStream
      0, // width
      0, // height
      0, // minDimension
      0, // maxDimension
      1, // readAudioStream
      0, // audioSamples
      0 // audioChannels
  );
  SyncDecoder decoder;
  DecoderInCallback callback = nullptr;
  std::string logMessage, logType;
  if (isReadFile) {
    params.uri = videoPath;
    logType = "file";
    logMessage = videoPath;
  } else {
    // Probe straight out of the tensor's memory buffer.
    callback = MemoryBuffer::getCallback(
        input_video.data_ptr<uint8_t>(), input_video.size(0));
    logType = "memory";
    logMessage = std::to_string(input_video.size(0));
  }
  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] has started";
  const auto now = std::chrono::system_clock::now();
  bool succeeded;
  bool gotAudio = false, gotVideo = false;
  DecoderMetadata audioMetadata, videoMetadata;
  std::vector<DecoderMetadata> metadata;
  if ((succeeded = decoder.init(params, std::move(callback), &metadata))) {
    // init() fills one metadata header per stream; pick out video/audio.
    for (const auto& header : metadata) {
      if (header.format.type == TYPE_VIDEO) {
        gotVideo = true;
        videoMetadata = header;
      } else if (header.format.type == TYPE_AUDIO) {
        gotAudio = true;
        audioMetadata = header;
      }
    }
    const auto then = std::chrono::system_clock::now();
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] has finished, "
            << std::chrono::duration_cast<std::chrono::microseconds>(then - now)
                   .count()
            << " us";
  } else {
    LOG(ERROR) << "Decoder initialization has failed";
  }
  decoder.shutdown();
  // video section
  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && gotVideo) {
    videoTimeBase = torch::zeros({2}, torch::kInt);
    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
    const auto& header = videoMetadata;
    videoTimeBaseData[0] = header.num;
    videoTimeBaseData[1] = header.den;
    videoFps = torch::zeros({1}, torch::kFloat);
    float* videoFpsData = videoFps.data_ptr<float>();
    videoFpsData[0] = header.fps;
    videoDuration = torch::zeros({1}, torch::kLong);
    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
    // Rescale duration from microseconds back to the stream time base.
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    videoDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
    VLOG(2) << "Prob fps: " << header.fps << ", duration: " << header.duration
            << ", num: " << header.num << ", den: " << header.den;
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled video tensors";
  } else {
    LOG(ERROR) << "Miss video stream";
  }
  // audio section
  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
  if (succeeded && gotAudio) {
    audioTimeBase = torch::zeros({2}, torch::kInt);
    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
    const auto& header = audioMetadata;
    const auto& media = header.format;
    const auto& format = media.format.audio;
    audioTimeBaseData[0] = header.num;
    audioTimeBaseData[1] = header.den;
    audioSampleRate = torch::zeros({1}, torch::kInt);
    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
    audioSampleRateData[0] = format.samples;
    audioDuration = torch::zeros({1}, torch::kLong);
    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
    // Rescale duration from microseconds back to the stream time base.
    AVRational avr = AVRational{(int)header.num, (int)header.den};
    audioDurationData[0] = av_rescale_q(header.duration, timeBaseQ, avr);
    VLOG(2) << "Prob sample rate: " << format.samples
            << ", duration: " << header.duration << ", num: " << header.num
            << ", den: " << header.den;
    VLOG(1) << "Video probing from " << logType << " [" << logMessage
            << "] filled audio tensors";
  } else {
    VLOG(1) << "Miss audio stream";
  }
  torch::List<torch::Tensor> result;
  result.push_back(std::move(videoTimeBase));
  result.push_back(std::move(videoFps));
  result.push_back(std::move(videoDuration));
  result.push_back(std::move(audioTimeBase));
  result.push_back(std::move(audioSampleRate));
  result.push_back(std::move(audioDuration));
  VLOG(1) << "Video probing from " << logType << " [" << logMessage
          << "] is about to return";
  return result;
}
  495. } // namespace
// Public op: decode a video held in a 1-D uint8 tensor. Thin forwarder to
// readVideo; see that function for parameter semantics and the 10 returned
// tensors.
torch::List<torch::Tensor> read_video_from_memory(
    torch::Tensor input_video,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_memory");
  return readVideo(
      false,
      input_video,
      "", // videoPath
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}
// Public op: decode a video from a file path. Thin forwarder to readVideo
// (with an empty dummy tensor as the unused memory input); see readVideo for
// parameter semantics and the 10 returned tensors.
torch::List<torch::Tensor> read_video_from_file(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.read_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return readVideo(
      true,
      dummy_input_video,
      videoPath,
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
}
// Public op: probe metadata of a video held in a 1-D uint8 tensor.
// See probeVideo for the 6 returned tensors.
torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_memory");
  return probeVideo(false, input_video, "");
}
// Public op: probe metadata of a video from a file path (empty dummy tensor
// stands in for the unused memory input). See probeVideo for the 6 returned
// tensors.
torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
  C10_LOG_API_USAGE_ONCE(
      "torchvision.csrc.io.video_reader.video_reader.probe_video_from_file");
  torch::Tensor dummy_input_video = torch::ones({0});
  return probeVideo(true, dummy_input_video, videoPath);
}
// Register the ops with the PyTorch dispatcher under the `video_reader`
// namespace (callable from Python as torch.ops.video_reader.*).
TORCH_LIBRARY_FRAGMENT(video_reader, m) {
  m.def("read_video_from_memory", read_video_from_memory);
  m.def("read_video_from_file", read_video_from_file);
  m.def("probe_video_from_memory", probe_video_from_memory);
  m.def("probe_video_from_file", probe_video_from_file);
}
  604. } // namespace video_reader
  605. } // namespace vision