NonEmptyUtils.h 611 B

123456789101112131415161718192021222324252627
  1. #include <ATen/core/TensorBase.h>
  2. #include <algorithm>
  3. #include <vector>
  4. namespace at { namespace native {
  5. inline int64_t ensure_nonempty_dim(int64_t dim) {
  6. return std::max<int64_t>(dim, 1);
  7. }
  8. inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
  9. return t.dim() == 0 ? 1 : t.size(dim);
  10. }
  11. inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
  12. return t.dim() == 0 ? 1 : t.stride(dim);
  13. }
  14. using IdxVec = std::vector<int64_t>;
  15. inline IdxVec ensure_nonempty_vec(IdxVec vec) {
  16. if (vec.empty()) {
  17. vec.push_back(1);
  18. }
  19. return vec;
  20. }
  21. }} // namespace at::native