rnn.h

/*
 *  Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_

#include <stddef.h>
#include <sys/types.h>

#include <array>
#include <vector>

#include "api/array_view.h"
#include "api/function_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/system/arch.h"

namespace webrtc {
namespace rnn_vad {

// Maximum number of units for a fully-connected layer. This value is used to
// over-allocate space for fully-connected layers output vectors (implemented
// as std::array). The value should equal the number of units of the largest
// fully-connected layer.
constexpr size_t kFullyConnectedLayersMaxUnits = 24;

// Maximum number of units for a recurrent layer. This value is used to
// over-allocate space for recurrent layers state vectors (implemented as
// std::array). The value should equal the number of units of the largest
// recurrent layer.
constexpr size_t kRecurrentLayersMaxUnits = 24;

// Fully-connected layer.
class FullyConnectedLayer {
 public:
  FullyConnectedLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      rtc::FunctionView<float(float)> activation_function,
                      Optimization optimization);
  FullyConnectedLayer(const FullyConnectedLayer&) = delete;
  FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
  ~FullyConnectedLayer();

  size_t input_size() const { return input_size_; }
  size_t output_size() const { return output_size_; }
  Optimization optimization() const { return optimization_; }
  rtc::ArrayView<const float> GetOutput() const;

  // Computes the fully-connected layer output.
  void ComputeOutput(rtc::ArrayView<const float> input);

 private:
  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
  const std::vector<float> weights_;
  rtc::FunctionView<float(float)> activation_function_;
  // The output vector of a fully-connected layer has length equal to
  // |output_size_|. However, for efficiency, over-allocation is used.
  std::array<float, kFullyConnectedLayersMaxUnits> output_;
  const Optimization optimization_;
};
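
// Usage sketch (illustrative only): the layer sizes, weight tables, the
// |Sigmoid| activation and |Optimization::kNone| below are placeholders and
// assumptions, not values defined in this header.
//
//   float Sigmoid(float x);  // Some activation defined elsewhere; a free
//                            // function avoids binding the FunctionView to a
//                            // temporary callable.
//   const int8_t kBias[24] = { /* ... */ };
//   const int8_t kWeights[42 * 24] = { /* ... */ };
//   FullyConnectedLayer fc(/*input_size=*/42, /*output_size=*/24,
//                          kBias, kWeights, &Sigmoid, Optimization::kNone);
//   fc.ComputeOutput(input_features);  // |input_features|: 42 floats.
//   rtc::ArrayView<const float> out = fc.GetOutput();  // 24 floats.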

// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
 public:
  GatedRecurrentLayer(size_t input_size,
                      size_t output_size,
                      rtc::ArrayView<const int8_t> bias,
                      rtc::ArrayView<const int8_t> weights,
                      rtc::ArrayView<const int8_t> recurrent_weights,
                      Optimization optimization);
  GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
  GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
  ~GatedRecurrentLayer();

  size_t input_size() const { return input_size_; }
  size_t output_size() const { return output_size_; }
  Optimization optimization() const { return optimization_; }
  rtc::ArrayView<const float> GetOutput() const;

  void Reset();
  // Computes the recurrent layer output and updates the internal state.
  void ComputeOutput(rtc::ArrayView<const float> input);

 private:
  const size_t input_size_;
  const size_t output_size_;
  const std::vector<float> bias_;
  const std::vector<float> weights_;
  const std::vector<float> recurrent_weights_;
  // The state vector of a recurrent layer has length equal to |output_size_|.
  // However, to avoid dynamic allocation, over-allocation is used.
  std::array<float, kRecurrentLayersMaxUnits> state_;
  const Optimization optimization_;
};
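
// Usage sketch (illustrative only): the sizes, the |kGru*| weight tables and
// |Optimization::kNone| are placeholders/assumptions. It shows that, unlike a
// fully-connected layer, the GRU keeps state across calls to ComputeOutput().
//
//   GatedRecurrentLayer gru(/*input_size=*/24, /*output_size=*/24,
//                           kGruBias, kGruWeights, kGruRecurrentWeights,
//                           Optimization::kNone);
//   gru.Reset();                 // Clears the internal state.
//   gru.ComputeOutput(frame_0);  // Output depends on |frame_0| only.
//   gru.ComputeOutput(frame_1);  // Output depends on |frame_1| and on the
//                                // state left by the previous call.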

// Recurrent network based VAD.
class RnnBasedVad {
 public:
  RnnBasedVad();
  RnnBasedVad(const RnnBasedVad&) = delete;
  RnnBasedVad& operator=(const RnnBasedVad&) = delete;
  ~RnnBasedVad();

  void Reset();
  // Computes and returns the probability of voice (range: [0.0, 1.0]).
  float ComputeVadProbability(
      rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
      bool is_silence);

 private:
  FullyConnectedLayer input_layer_;
  GatedRecurrentLayer hidden_layer_;
  FullyConnectedLayer output_layer_;
};
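
// Usage sketch (illustrative only): |HasNextFrame|, |NextFrame| and
// |ComputeFeatures| are hypothetical helpers standing in for the audio source
// and the feature extraction step, which are not declared in this header.
//
//   RnnBasedVad vad;
//   vad.Reset();  // Call when a new audio stream starts.
//   while (HasNextFrame()) {
//     std::array<float, kFeatureVectorSize> features;
//     const bool is_silence = ComputeFeatures(NextFrame(), features);
//     const float p = vad.ComputeVadProbability(features, is_silence);
//     // |p| is the voice probability in [0.0, 1.0] for the current frame.
//   }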

}  // namespace rnn_vad
}  // namespace webrtc

#endif  // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_