mock_peer_connection_observers.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. /*
  2. * Copyright 2012 The WebRTC project authors. All Rights Reserved.
  3. *
  4. * Use of this source code is governed by a BSD-style license
  5. * that can be found in the LICENSE file in the root of the source
  6. * tree. An additional intellectual property rights grant can be found
  7. * in the file PATENTS. All contributing project authors may
  8. * be found in the AUTHORS file in the root of the source tree.
  9. */
  10. // This file contains mock implementations of observers used in PeerConnection.
  11. // TODO(steveanton): These aren't really mocks and should be renamed.
  12. #ifndef PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
  13. #define PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_
  14. #include <map>
  15. #include <memory>
  16. #include <string>
  17. #include <utility>
  18. #include <vector>
  19. #include "api/data_channel_interface.h"
  20. #include "api/jsep_ice_candidate.h"
  21. #include "pc/stream_collection.h"
  22. #include "rtc_base/checks.h"
  23. namespace webrtc {
  24. class MockPeerConnectionObserver : public PeerConnectionObserver {
  25. public:
  26. struct AddTrackEvent {
  27. explicit AddTrackEvent(
  28. rtc::scoped_refptr<RtpReceiverInterface> event_receiver,
  29. std::vector<rtc::scoped_refptr<MediaStreamInterface>> event_streams)
  30. : receiver(std::move(event_receiver)),
  31. streams(std::move(event_streams)) {
  32. for (auto stream : streams) {
  33. std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>> tracks;
  34. for (auto audio_track : stream->GetAudioTracks()) {
  35. tracks.push_back(audio_track);
  36. }
  37. for (auto video_track : stream->GetVideoTracks()) {
  38. tracks.push_back(video_track);
  39. }
  40. snapshotted_stream_tracks[stream] = tracks;
  41. }
  42. }
  43. rtc::scoped_refptr<RtpReceiverInterface> receiver;
  44. std::vector<rtc::scoped_refptr<MediaStreamInterface>> streams;
  45. // This map records the tracks present in each stream at the time the
  46. // OnAddTrack callback was issued.
  47. std::map<rtc::scoped_refptr<MediaStreamInterface>,
  48. std::vector<rtc::scoped_refptr<MediaStreamTrackInterface>>>
  49. snapshotted_stream_tracks;
  50. };
  51. MockPeerConnectionObserver() : remote_streams_(StreamCollection::Create()) {}
  52. virtual ~MockPeerConnectionObserver() {}
  53. void SetPeerConnectionInterface(PeerConnectionInterface* pc) {
  54. pc_ = pc;
  55. if (pc) {
  56. state_ = pc_->signaling_state();
  57. }
  58. }
  59. void OnSignalingChange(
  60. PeerConnectionInterface::SignalingState new_state) override {
  61. RTC_DCHECK(pc_);
  62. RTC_DCHECK(pc_->signaling_state() == new_state);
  63. state_ = new_state;
  64. }
  65. MediaStreamInterface* RemoteStream(const std::string& label) {
  66. return remote_streams_->find(label);
  67. }
  68. StreamCollectionInterface* remote_streams() const { return remote_streams_; }
  69. void OnAddStream(rtc::scoped_refptr<MediaStreamInterface> stream) override {
  70. last_added_stream_ = stream;
  71. remote_streams_->AddStream(stream);
  72. }
  73. void OnRemoveStream(
  74. rtc::scoped_refptr<MediaStreamInterface> stream) override {
  75. last_removed_stream_ = stream;
  76. remote_streams_->RemoveStream(stream);
  77. }
  78. void OnRenegotiationNeeded() override { renegotiation_needed_ = true; }
  79. void OnNegotiationNeededEvent(uint32_t event_id) override {
  80. latest_negotiation_needed_event_ = event_id;
  81. }
  82. void OnDataChannel(
  83. rtc::scoped_refptr<DataChannelInterface> data_channel) override {
  84. last_datachannel_ = data_channel;
  85. }
  86. void OnIceConnectionChange(
  87. PeerConnectionInterface::IceConnectionState new_state) override {
  88. RTC_DCHECK(pc_);
  89. RTC_DCHECK(pc_->ice_connection_state() == new_state);
  90. // When ICE is finished, the caller will get to a kIceConnectionCompleted
  91. // state, because it has the ICE controlling role, while the callee
  92. // will get to a kIceConnectionConnected state. This means that both ICE
  93. // and DTLS are connected.
  94. ice_connected_ =
  95. (new_state == PeerConnectionInterface::kIceConnectionConnected) ||
  96. (new_state == PeerConnectionInterface::kIceConnectionCompleted);
  97. callback_triggered_ = true;
  98. }
  99. void OnIceGatheringChange(
  100. PeerConnectionInterface::IceGatheringState new_state) override {
  101. RTC_DCHECK(pc_);
  102. RTC_DCHECK(pc_->ice_gathering_state() == new_state);
  103. ice_gathering_complete_ =
  104. new_state == PeerConnectionInterface::kIceGatheringComplete;
  105. callback_triggered_ = true;
  106. }
  107. void OnIceCandidate(const IceCandidateInterface* candidate) override {
  108. RTC_DCHECK(pc_);
  109. RTC_DCHECK(PeerConnectionInterface::kIceGatheringNew !=
  110. pc_->ice_gathering_state());
  111. candidates_.push_back(std::make_unique<JsepIceCandidate>(
  112. candidate->sdp_mid(), candidate->sdp_mline_index(),
  113. candidate->candidate()));
  114. callback_triggered_ = true;
  115. }
  116. void OnIceCandidatesRemoved(
  117. const std::vector<cricket::Candidate>& candidates) override {
  118. num_candidates_removed_++;
  119. callback_triggered_ = true;
  120. }
  121. void OnIceConnectionReceivingChange(bool receiving) override {
  122. callback_triggered_ = true;
  123. }
  124. void OnAddTrack(rtc::scoped_refptr<RtpReceiverInterface> receiver,
  125. const std::vector<rtc::scoped_refptr<MediaStreamInterface>>&
  126. streams) override {
  127. RTC_DCHECK(receiver);
  128. num_added_tracks_++;
  129. last_added_track_label_ = receiver->id();
  130. add_track_events_.push_back(AddTrackEvent(receiver, streams));
  131. }
  132. void OnTrack(
  133. rtc::scoped_refptr<RtpTransceiverInterface> transceiver) override {
  134. on_track_transceivers_.push_back(transceiver);
  135. }
  136. void OnRemoveTrack(
  137. rtc::scoped_refptr<RtpReceiverInterface> receiver) override {
  138. remove_track_events_.push_back(receiver);
  139. }
  140. std::vector<rtc::scoped_refptr<RtpReceiverInterface>> GetAddTrackReceivers() {
  141. std::vector<rtc::scoped_refptr<RtpReceiverInterface>> receivers;
  142. for (const AddTrackEvent& event : add_track_events_) {
  143. receivers.push_back(event.receiver);
  144. }
  145. return receivers;
  146. }
  147. int CountAddTrackEventsForStream(const std::string& stream_id) {
  148. int found_tracks = 0;
  149. for (const AddTrackEvent& event : add_track_events_) {
  150. bool has_stream_id = false;
  151. for (auto stream : event.streams) {
  152. if (stream->id() == stream_id) {
  153. has_stream_id = true;
  154. break;
  155. }
  156. }
  157. if (has_stream_id) {
  158. ++found_tracks;
  159. }
  160. }
  161. return found_tracks;
  162. }
  163. // Returns the id of the last added stream.
  164. // Empty string if no stream have been added.
  165. std::string GetLastAddedStreamId() {
  166. if (last_added_stream_.get())
  167. return last_added_stream_->id();
  168. return "";
  169. }
  170. std::string GetLastRemovedStreamId() {
  171. if (last_removed_stream_.get())
  172. return last_removed_stream_->id();
  173. return "";
  174. }
  175. IceCandidateInterface* last_candidate() {
  176. if (candidates_.empty()) {
  177. return nullptr;
  178. } else {
  179. return candidates_.back().get();
  180. }
  181. }
  182. std::vector<const IceCandidateInterface*> GetAllCandidates() {
  183. std::vector<const IceCandidateInterface*> candidates;
  184. for (const auto& candidate : candidates_) {
  185. candidates.push_back(candidate.get());
  186. }
  187. return candidates;
  188. }
  189. std::vector<IceCandidateInterface*> GetCandidatesByMline(int mline_index) {
  190. std::vector<IceCandidateInterface*> candidates;
  191. for (const auto& candidate : candidates_) {
  192. if (candidate->sdp_mline_index() == mline_index) {
  193. candidates.push_back(candidate.get());
  194. }
  195. }
  196. return candidates;
  197. }
  198. bool legacy_renegotiation_needed() const { return renegotiation_needed_; }
  199. void clear_legacy_renegotiation_needed() { renegotiation_needed_ = false; }
  200. bool has_negotiation_needed_event() {
  201. return latest_negotiation_needed_event_.has_value();
  202. }
  203. uint32_t latest_negotiation_needed_event() {
  204. return latest_negotiation_needed_event_.value_or(0u);
  205. }
  206. void clear_latest_negotiation_needed_event() {
  207. latest_negotiation_needed_event_ = absl::nullopt;
  208. }
  209. rtc::scoped_refptr<PeerConnectionInterface> pc_;
  210. PeerConnectionInterface::SignalingState state_;
  211. std::vector<std::unique_ptr<IceCandidateInterface>> candidates_;
  212. rtc::scoped_refptr<DataChannelInterface> last_datachannel_;
  213. rtc::scoped_refptr<StreamCollection> remote_streams_;
  214. bool renegotiation_needed_ = false;
  215. absl::optional<uint32_t> latest_negotiation_needed_event_;
  216. bool ice_gathering_complete_ = false;
  217. bool ice_connected_ = false;
  218. bool callback_triggered_ = false;
  219. int num_added_tracks_ = 0;
  220. std::string last_added_track_label_;
  221. std::vector<AddTrackEvent> add_track_events_;
  222. std::vector<rtc::scoped_refptr<RtpReceiverInterface>> remove_track_events_;
  223. std::vector<rtc::scoped_refptr<RtpTransceiverInterface>>
  224. on_track_transceivers_;
  225. int num_candidates_removed_ = 0;
  226. private:
  227. rtc::scoped_refptr<MediaStreamInterface> last_added_stream_;
  228. rtc::scoped_refptr<MediaStreamInterface> last_removed_stream_;
  229. };
  230. class MockCreateSessionDescriptionObserver
  231. : public webrtc::CreateSessionDescriptionObserver {
  232. public:
  233. MockCreateSessionDescriptionObserver()
  234. : called_(false),
  235. error_("MockCreateSessionDescriptionObserver not called") {}
  236. virtual ~MockCreateSessionDescriptionObserver() {}
  237. void OnSuccess(SessionDescriptionInterface* desc) override {
  238. called_ = true;
  239. error_ = "";
  240. desc_.reset(desc);
  241. }
  242. void OnFailure(webrtc::RTCError error) override {
  243. called_ = true;
  244. error_ = error.message();
  245. }
  246. bool called() const { return called_; }
  247. bool result() const { return error_.empty(); }
  248. const std::string& error() const { return error_; }
  249. std::unique_ptr<SessionDescriptionInterface> MoveDescription() {
  250. return std::move(desc_);
  251. }
  252. private:
  253. bool called_;
  254. std::string error_;
  255. std::unique_ptr<SessionDescriptionInterface> desc_;
  256. };
  257. class MockSetSessionDescriptionObserver
  258. : public webrtc::SetSessionDescriptionObserver {
  259. public:
  260. static rtc::scoped_refptr<MockSetSessionDescriptionObserver> Create() {
  261. return new rtc::RefCountedObject<MockSetSessionDescriptionObserver>();
  262. }
  263. MockSetSessionDescriptionObserver()
  264. : called_(false),
  265. error_("MockSetSessionDescriptionObserver not called") {}
  266. ~MockSetSessionDescriptionObserver() override {}
  267. void OnSuccess() override {
  268. called_ = true;
  269. error_ = "";
  270. }
  271. void OnFailure(webrtc::RTCError error) override {
  272. called_ = true;
  273. error_ = error.message();
  274. }
  275. bool called() const { return called_; }
  276. bool result() const { return error_.empty(); }
  277. const std::string& error() const { return error_; }
  278. private:
  279. bool called_;
  280. std::string error_;
  281. };
  282. class FakeSetLocalDescriptionObserver
  283. : public rtc::RefCountedObject<SetLocalDescriptionObserverInterface> {
  284. public:
  285. bool called() const { return error_.has_value(); }
  286. RTCError& error() {
  287. RTC_DCHECK(error_.has_value());
  288. return *error_;
  289. }
  290. // SetLocalDescriptionObserverInterface implementation.
  291. void OnSetLocalDescriptionComplete(RTCError error) override {
  292. error_ = std::move(error);
  293. }
  294. private:
  295. // Set on complete, on success this is set to an RTCError::OK() error.
  296. absl::optional<RTCError> error_;
  297. };
  298. class FakeSetRemoteDescriptionObserver
  299. : public rtc::RefCountedObject<SetRemoteDescriptionObserverInterface> {
  300. public:
  301. bool called() const { return error_.has_value(); }
  302. RTCError& error() {
  303. RTC_DCHECK(error_.has_value());
  304. return *error_;
  305. }
  306. // SetRemoteDescriptionObserverInterface implementation.
  307. void OnSetRemoteDescriptionComplete(RTCError error) override {
  308. error_ = std::move(error);
  309. }
  310. private:
  311. // Set on complete, on success this is set to an RTCError::OK() error.
  312. absl::optional<RTCError> error_;
  313. };
  314. class MockDataChannelObserver : public webrtc::DataChannelObserver {
  315. public:
  316. explicit MockDataChannelObserver(webrtc::DataChannelInterface* channel)
  317. : channel_(channel) {
  318. channel_->RegisterObserver(this);
  319. state_ = channel_->state();
  320. }
  321. virtual ~MockDataChannelObserver() { channel_->UnregisterObserver(); }
  322. void OnBufferedAmountChange(uint64_t previous_amount) override {}
  323. void OnStateChange() override { state_ = channel_->state(); }
  324. void OnMessage(const DataBuffer& buffer) override {
  325. messages_.push_back(
  326. std::string(buffer.data.data<char>(), buffer.data.size()));
  327. }
  328. bool IsOpen() const { return state_ == DataChannelInterface::kOpen; }
  329. std::vector<std::string> messages() const { return messages_; }
  330. std::string last_message() const {
  331. return messages_.empty() ? std::string() : messages_.back();
  332. }
  333. size_t received_message_count() const { return messages_.size(); }
  334. private:
  335. rtc::scoped_refptr<webrtc::DataChannelInterface> channel_;
  336. DataChannelInterface::DataState state_;
  337. std::vector<std::string> messages_;
  338. };
  339. class MockStatsObserver : public webrtc::StatsObserver {
  340. public:
  341. MockStatsObserver() : called_(false), stats_() {}
  342. virtual ~MockStatsObserver() {}
  343. virtual void OnComplete(const StatsReports& reports) {
  344. RTC_CHECK(!called_);
  345. called_ = true;
  346. stats_.Clear();
  347. stats_.number_of_reports = reports.size();
  348. for (const auto* r : reports) {
  349. if (r->type() == StatsReport::kStatsReportTypeSsrc) {
  350. stats_.timestamp = r->timestamp();
  351. GetIntValue(r, StatsReport::kStatsValueNameAudioOutputLevel,
  352. &stats_.audio_output_level);
  353. GetIntValue(r, StatsReport::kStatsValueNameAudioInputLevel,
  354. &stats_.audio_input_level);
  355. GetIntValue(r, StatsReport::kStatsValueNameBytesReceived,
  356. &stats_.bytes_received);
  357. GetIntValue(r, StatsReport::kStatsValueNameBytesSent,
  358. &stats_.bytes_sent);
  359. GetInt64Value(r, StatsReport::kStatsValueNameCaptureStartNtpTimeMs,
  360. &stats_.capture_start_ntp_time);
  361. stats_.track_ids.emplace_back();
  362. GetStringValue(r, StatsReport::kStatsValueNameTrackId,
  363. &stats_.track_ids.back());
  364. } else if (r->type() == StatsReport::kStatsReportTypeBwe) {
  365. stats_.timestamp = r->timestamp();
  366. GetIntValue(r, StatsReport::kStatsValueNameAvailableReceiveBandwidth,
  367. &stats_.available_receive_bandwidth);
  368. } else if (r->type() == StatsReport::kStatsReportTypeComponent) {
  369. stats_.timestamp = r->timestamp();
  370. GetStringValue(r, StatsReport::kStatsValueNameDtlsCipher,
  371. &stats_.dtls_cipher);
  372. GetStringValue(r, StatsReport::kStatsValueNameSrtpCipher,
  373. &stats_.srtp_cipher);
  374. }
  375. }
  376. }
  377. bool called() const { return called_; }
  378. size_t number_of_reports() const { return stats_.number_of_reports; }
  379. double timestamp() const { return stats_.timestamp; }
  380. int AudioOutputLevel() const {
  381. RTC_CHECK(called_);
  382. return stats_.audio_output_level;
  383. }
  384. int AudioInputLevel() const {
  385. RTC_CHECK(called_);
  386. return stats_.audio_input_level;
  387. }
  388. int BytesReceived() const {
  389. RTC_CHECK(called_);
  390. return stats_.bytes_received;
  391. }
  392. int BytesSent() const {
  393. RTC_CHECK(called_);
  394. return stats_.bytes_sent;
  395. }
  396. int64_t CaptureStartNtpTime() const {
  397. RTC_CHECK(called_);
  398. return stats_.capture_start_ntp_time;
  399. }
  400. int AvailableReceiveBandwidth() const {
  401. RTC_CHECK(called_);
  402. return stats_.available_receive_bandwidth;
  403. }
  404. std::string DtlsCipher() const {
  405. RTC_CHECK(called_);
  406. return stats_.dtls_cipher;
  407. }
  408. std::string SrtpCipher() const {
  409. RTC_CHECK(called_);
  410. return stats_.srtp_cipher;
  411. }
  412. std::vector<std::string> TrackIds() const {
  413. RTC_CHECK(called_);
  414. return stats_.track_ids;
  415. }
  416. private:
  417. bool GetIntValue(const StatsReport* report,
  418. StatsReport::StatsValueName name,
  419. int* value) {
  420. const StatsReport::Value* v = report->FindValue(name);
  421. if (v) {
  422. // TODO(tommi): We should really just be using an int here :-/
  423. *value = rtc::FromString<int>(v->ToString());
  424. }
  425. return v != nullptr;
  426. }
  427. bool GetInt64Value(const StatsReport* report,
  428. StatsReport::StatsValueName name,
  429. int64_t* value) {
  430. const StatsReport::Value* v = report->FindValue(name);
  431. if (v) {
  432. // TODO(tommi): We should really just be using an int here :-/
  433. *value = rtc::FromString<int64_t>(v->ToString());
  434. }
  435. return v != nullptr;
  436. }
  437. bool GetStringValue(const StatsReport* report,
  438. StatsReport::StatsValueName name,
  439. std::string* value) {
  440. const StatsReport::Value* v = report->FindValue(name);
  441. if (v)
  442. *value = v->ToString();
  443. return v != nullptr;
  444. }
  445. bool called_;
  446. struct {
  447. void Clear() {
  448. number_of_reports = 0;
  449. timestamp = 0;
  450. audio_output_level = 0;
  451. audio_input_level = 0;
  452. bytes_received = 0;
  453. bytes_sent = 0;
  454. capture_start_ntp_time = 0;
  455. available_receive_bandwidth = 0;
  456. dtls_cipher.clear();
  457. srtp_cipher.clear();
  458. track_ids.clear();
  459. }
  460. size_t number_of_reports;
  461. double timestamp;
  462. int audio_output_level;
  463. int audio_input_level;
  464. int bytes_received;
  465. int bytes_sent;
  466. int64_t capture_start_ntp_time;
  467. int available_receive_bandwidth;
  468. std::string dtls_cipher;
  469. std::string srtp_cipher;
  470. std::vector<std::string> track_ids;
  471. } stats_;
  472. };
  473. // Helper class that just stores the report from the callback.
  474. class MockRTCStatsCollectorCallback : public webrtc::RTCStatsCollectorCallback {
  475. public:
  476. rtc::scoped_refptr<const RTCStatsReport> report() { return report_; }
  477. bool called() const { return called_; }
  478. protected:
  479. void OnStatsDelivered(
  480. const rtc::scoped_refptr<const RTCStatsReport>& report) override {
  481. report_ = report;
  482. called_ = true;
  483. }
  484. private:
  485. bool called_ = false;
  486. rtc::scoped_refptr<const RTCStatsReport> report_;
  487. };
  488. } // namespace webrtc
  489. #endif // PC_TEST_MOCK_PEER_CONNECTION_OBSERVERS_H_