accumulate.h 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. // Copyright 2004-present Facebook. All Rights Reserved.
  2. #pragma once
  3. #include <c10/util/ArrayRef.h>
  4. #include <iterator>
  5. #include <numeric>
  6. #include <type_traits>
  7. namespace c10 {
  /// Sums the integral elements of `container`.
  ///
  /// The running total is held in an int64_t regardless of the element type,
  /// so e.g. a container of int32_t values can sum past INT32_MAX without
  /// wrapping at 32 bits. Overflow of the int64_t total itself is not checked.
  ///
  /// @param container Container whose `value_type` is integral; traversed via
  ///   its begin()/end().
  /// @return Sum of all elements as int64_t; 0 for an empty container.
  template <
      typename C,
      std::enable_if_t<std::is_integral<typename C::value_type>::value, int> =
          0>
  inline int64_t sum_integers(const C& container) {
    int64_t total = 0;
    for (const auto& element : container) {
      total += element;
    }
    return total;
  }
  /// Sums the integral elements in the half-open range [begin, end).
  ///
  /// The running total is held in an int64_t regardless of the iterator's
  /// value type, so narrow element types do not wrap at their own width.
  /// Overflow of the int64_t total itself is not checked.
  ///
  /// @param begin Iterator to the first element to add.
  /// @param end   Iterator one past the last element to add.
  /// @return Sum of the elements as int64_t; 0 for an empty range.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral<
              typename std::iterator_traits<Iter>::value_type>::value,
          int> = 0>
  inline int64_t sum_integers(Iter begin, Iter end) {
    int64_t total = 0;
    for (Iter cursor = begin; cursor != end; ++cursor) {
      total += *cursor;
    }
    return total;
  }
  /// Multiplies together the integral elements of `container`.
  ///
  /// The running product is held in an int64_t regardless of the element type,
  /// so e.g. int32_t factors can produce results beyond INT32_MAX. Overflow of
  /// the int64_t product itself is not checked.
  ///
  /// @param container Container whose `value_type` is integral; traversed via
  ///   its begin()/end().
  /// @return Product of all elements as int64_t; 1 (the empty product) for an
  ///   empty container.
  template <
      typename C,
      std::enable_if_t<std::is_integral<typename C::value_type>::value, int> =
          0>
  inline int64_t multiply_integers(const C& container) {
    int64_t product = 1;
    for (const auto& element : container) {
      product *= element;
    }
    return product;
  }
  /// Multiplies together the integral elements in the half-open range
  /// [begin, end).
  ///
  /// The running product is held in an int64_t regardless of the iterator's
  /// value type. Overflow of the int64_t product itself is not checked.
  ///
  /// @param begin Iterator to the first factor.
  /// @param end   Iterator one past the last factor.
  /// @return Product of the elements as int64_t; 1 (the empty product) for an
  ///   empty range.
  template <
      typename Iter,
      std::enable_if_t<
          std::is_integral<
              typename std::iterator_traits<Iter>::value_type>::value,
          int> = 0>
  inline int64_t multiply_integers(Iter begin, Iter end) {
    int64_t product = 1;
    for (Iter cursor = begin; cursor != end; ++cursor) {
      product *= *cursor;
    }
    return product;
  }
  66. /// Return product of all dimensions starting from k
  67. /// Returns 1 if k>=dims.size()
  68. template <
  69. typename C,
  70. typename std::enable_if<
  71. std::is_integral<typename C::value_type>::value,
  72. int>::type = 0>
  73. inline int64_t numelements_from_dim(const int k, const C& dims) {
  74. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(k >= 0);
  75. if (k > static_cast<int>(dims.size())) {
  76. return 1;
  77. } else {
  78. auto cbegin = dims.cbegin();
  79. std::advance(cbegin, k);
  80. return multiply_integers(cbegin, dims.cend());
  81. }
  82. }
  83. /// Product of all dims up to k (not including dims[k])
  84. /// Throws an error if k>dims.size()
  85. template <
  86. typename C,
  87. typename std::enable_if<
  88. std::is_integral<typename C::value_type>::value,
  89. int>::type = 0>
  90. inline int64_t numelements_to_dim(const int k, const C& dims) {
  91. TORCH_INTERNAL_ASSERT(0 <= k);
  92. TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size());
  93. auto cend = dims.cbegin();
  94. std::advance(cend, k);
  95. return multiply_integers(dims.cbegin(), cend);
  96. }
  97. /// Product of all dims between k and l (including dims[k] and excluding
  98. /// dims[l]) k and l may be supplied in either order
  99. template <
  100. typename C,
  101. typename std::enable_if<
  102. std::is_integral<typename C::value_type>::value,
  103. int>::type = 0>
  104. inline int64_t numelements_between_dim(int k, int l, const C& dims) {
  105. TORCH_INTERNAL_ASSERT(0 <= k);
  106. TORCH_INTERNAL_ASSERT(0 <= l);
  107. if (k > l) {
  108. std::swap(k, l);
  109. }
  110. TORCH_INTERNAL_ASSERT((unsigned)l < dims.size());
  111. auto cbegin = dims.cbegin();
  112. auto cend = dims.cbegin();
  113. std::advance(cbegin, k);
  114. std::advance(cend, l);
  115. return multiply_integers(cbegin, cend);
  116. }
  117. } // namespace c10