linear_least_squares.h 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. /*
  2. * Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
  3. *
  4. * Use of this source code is governed by a BSD-style license
  5. * that can be found in the LICENSE file in the root of the source
  6. * tree. An additional intellectual property rights grant can be found
  7. * in the file PATENTS. All contributing project authors may
  8. * be found in the AUTHORS file in the root of the source tree.
  9. */
  10. #ifndef RTC_TOOLS_FRAME_ANALYZER_LINEAR_LEAST_SQUARES_H_
  11. #define RTC_TOOLS_FRAME_ANALYZER_LINEAR_LEAST_SQUARES_H_
  12. #include <stdint.h>
  13. #include <valarray>
  14. #include <vector>
  15. #include "absl/types/optional.h"
  16. namespace webrtc {
  17. namespace test {
  18. // This class is used for finding a matrix b that roughly solves the equation:
  19. // y = x * b. This is generally impossible to do exactly, so the problem is
  20. // rephrased as finding the matrix b that minimizes the difference:
  21. // |y - x * b|^2. Calling multiple AddObservations() is equivalent to
  22. // concatenating the observation vectors and calling AddObservations() once. The
  23. // reason for doing it incrementally is that we can't store the raw YUV values
  24. // for a whole video file in memory at once. This class has a constant memory
  25. // footprint, regardless how may times AddObservations() is called.
  26. class IncrementalLinearLeastSquares {
  27. public:
  28. IncrementalLinearLeastSquares();
  29. ~IncrementalLinearLeastSquares();
  30. // Add a number of observations. The subvectors of x and y must have the same
  31. // length.
  32. void AddObservations(const std::vector<std::vector<uint8_t>>& x,
  33. const std::vector<std::vector<uint8_t>>& y);
  34. // Calculate and return the best linear solution, given the observations so
  35. // far.
  36. std::vector<std::vector<double>> GetBestSolution() const;
  37. private:
  38. // Running sum of x^T * x.
  39. absl::optional<std::valarray<std::valarray<uint64_t>>> sum_xx;
  40. // Running sum of x^T * y.
  41. absl::optional<std::valarray<std::valarray<uint64_t>>> sum_xy;
  42. };
  43. } // namespace test
  44. } // namespace webrtc
  45. #endif // RTC_TOOLS_FRAME_ANALYZER_LINEAR_LEAST_SQUARES_H_