// strides.h
  #pragma once
  #include <algorithm>
  #include <c10/util/ArrayRef.h>
  #include <c10/util/DimVector.h>
  4. namespace c10 {
  5. // Computes the contiguous strides of a tensor, given its sizes.
  6. static inline DimVector contiguous_strides(const IntArrayRef sizes) {
  7. using Int = IntArrayRef::value_type;
  8. const Int dims = static_cast<Int>(sizes.size());
  9. // With this intialisation we get the case dim == 0 or 1 right
  10. DimVector strides(dims, 1);
  11. for (auto i = dims - 2; i >= 0; --i) {
  12. // Strides can't be 0 even if sizes are 0.
  13. strides[i] = strides[i + 1] * std::max(sizes[i + 1], Int{1});
  14. }
  15. return strides;
  16. }
  17. } // namespace c10