suppression_gain.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /*
  2. * Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
  3. *
  4. * Use of this source code is governed by a BSD-style license
  5. * that can be found in the LICENSE file in the root of the source
  6. * tree. An additional intellectual property rights grant can be found
  7. * in the file PATENTS. All contributing project authors may
  8. * be found in the AUTHORS file in the root of the source tree.
  9. */
  10. #ifndef MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_
  11. #define MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_
  12. #include <array>
  13. #include <memory>
  14. #include <vector>
  15. #include "absl/types/optional.h"
  16. #include "api/array_view.h"
  17. #include "api/audio/echo_canceller3_config.h"
  18. #include "modules/audio_processing/aec3/aec3_common.h"
  19. #include "modules/audio_processing/aec3/aec_state.h"
  20. #include "modules/audio_processing/aec3/fft_data.h"
  21. #include "modules/audio_processing/aec3/moving_average.h"
  22. #include "modules/audio_processing/aec3/nearend_detector.h"
  23. #include "modules/audio_processing/aec3/render_signal_analyzer.h"
  24. #include "modules/audio_processing/logging/apm_data_dumper.h"
  25. #include "rtc_base/constructor_magic.h"
  26. namespace webrtc {
  27. class SuppressionGain {
  28. public:
  29. SuppressionGain(const EchoCanceller3Config& config,
  30. Aec3Optimization optimization,
  31. int sample_rate_hz,
  32. size_t num_capture_channels);
  33. ~SuppressionGain();
  34. void GetGain(
  35. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
  36. nearend_spectrum,
  37. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum,
  38. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
  39. residual_echo_spectrum,
  40. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
  41. comfort_noise_spectrum,
  42. const RenderSignalAnalyzer& render_signal_analyzer,
  43. const AecState& aec_state,
  44. const std::vector<std::vector<std::vector<float>>>& render,
  45. float* high_bands_gain,
  46. std::array<float, kFftLengthBy2Plus1>* low_band_gain);
  47. // Toggles the usage of the initial state.
  48. void SetInitialState(bool state);
  49. private:
  50. // Computes the gain to apply for the bands beyond the first band.
  51. float UpperBandsGain(
  52. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> echo_spectrum,
  53. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
  54. comfort_noise_spectrum,
  55. const absl::optional<int>& narrow_peak_band,
  56. bool saturated_echo,
  57. const std::vector<std::vector<std::vector<float>>>& render,
  58. const std::array<float, kFftLengthBy2Plus1>& low_band_gain) const;
  59. void GainToNoAudibleEcho(const std::array<float, kFftLengthBy2Plus1>& nearend,
  60. const std::array<float, kFftLengthBy2Plus1>& echo,
  61. const std::array<float, kFftLengthBy2Plus1>& masker,
  62. std::array<float, kFftLengthBy2Plus1>* gain) const;
  63. void LowerBandGain(
  64. bool stationary_with_low_power,
  65. const AecState& aec_state,
  66. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>>
  67. suppressor_input,
  68. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> residual_echo,
  69. rtc::ArrayView<const std::array<float, kFftLengthBy2Plus1>> comfort_noise,
  70. std::array<float, kFftLengthBy2Plus1>* gain);
  71. void GetMinGain(rtc::ArrayView<const float> weighted_residual_echo,
  72. rtc::ArrayView<const float> last_nearend,
  73. rtc::ArrayView<const float> last_echo,
  74. bool low_noise_render,
  75. bool saturated_echo,
  76. rtc::ArrayView<float> min_gain) const;
  77. void GetMaxGain(rtc::ArrayView<float> max_gain) const;
  78. class LowNoiseRenderDetector {
  79. public:
  80. bool Detect(const std::vector<std::vector<std::vector<float>>>& render);
  81. private:
  82. float average_power_ = 32768.f * 32768.f;
  83. };
  84. struct GainParameters {
  85. explicit GainParameters(
  86. const EchoCanceller3Config::Suppressor::Tuning& tuning);
  87. const float max_inc_factor;
  88. const float max_dec_factor_lf;
  89. std::array<float, kFftLengthBy2Plus1> enr_transparent_;
  90. std::array<float, kFftLengthBy2Plus1> enr_suppress_;
  91. std::array<float, kFftLengthBy2Plus1> emr_transparent_;
  92. };
  93. static int instance_count_;
  94. std::unique_ptr<ApmDataDumper> data_dumper_;
  95. const Aec3Optimization optimization_;
  96. const EchoCanceller3Config config_;
  97. const size_t num_capture_channels_;
  98. const int state_change_duration_blocks_;
  99. std::array<float, kFftLengthBy2Plus1> last_gain_;
  100. std::vector<std::array<float, kFftLengthBy2Plus1>> last_nearend_;
  101. std::vector<std::array<float, kFftLengthBy2Plus1>> last_echo_;
  102. LowNoiseRenderDetector low_render_detector_;
  103. bool initial_state_ = true;
  104. int initial_state_change_counter_ = 0;
  105. std::vector<aec3::MovingAverage> nearend_smoothers_;
  106. const GainParameters nearend_params_;
  107. const GainParameters normal_params_;
  108. std::unique_ptr<NearendDetector> dominant_nearend_detector_;
  109. RTC_DISALLOW_COPY_AND_ASSIGN(SuppressionGain);
  110. };
  111. } // namespace webrtc
  112. #endif // MODULES_AUDIO_PROCESSING_AEC3_SUPPRESSION_GAIN_H_