Utils.h 401 B

123456789101112131415161718
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/miopen/miopen-wrapper.h>
  4. #include <ATen/miopen/Handle.h>
  5. namespace at { namespace native {
  6. // This function makes tensors which have zero stride contiguous, by
  7. // setting the strides to 1.
  8. inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
  9. for (auto s : t.strides()) {
  10. if (s == 0) return t.contiguous();
  11. }
  12. return t;
  13. }
  14. }}