TensorShape.h 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <c10/util/irange.h>
  4. #include <ATen/core/IListRef.h>
  5. namespace at {
  6. namespace native {
  // Declaration only; no definition is visible in this header. Takes a tensor
  // by const reference and returns a tensor by value.
  // NOTE(review): the name suggests the result is a copy of `self` that keeps
  // the same strides -- confirm against the definition.
  7. TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
  8. inline bool cat_should_skip_tensor(const Tensor& t) {
  9. return t.numel() == 0 && t.dim() == 1;
  10. }
  11. // Check to see if the shape of tensors is compatible
  12. // for being concatenated along a given dimension.
  13. inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
  14. int64_t first_dims = first.dim();
  15. int64_t second_dims = second.dim();
  16. TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
  17. first_dims, " and ", second_dims);
  18. for (const auto dim : c10::irange(first_dims)) {
  19. if (dim == dimension) {
  20. continue;
  21. }
  22. int64_t first_dim_size = first.sizes()[dim];
  23. int64_t second_dim_size = second.sizes()[dim];
  24. TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
  25. dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
  26. }
  27. }
  28. inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
  29. int64_t i = 0;
  30. for(const Tensor& t : tensors) {
  31. TORCH_CHECK(t.dim() > 0,
  32. "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
  33. i++;
  34. }
  35. }
  36. inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
  37. TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
  38. TORCH_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
  39. int64_t dim_size = self.size(dim);
  40. TORCH_CHECK(split_size > 0 || dim_size == 0,
  41. "split_size can only be 0 if dimension size is 0, "
  42. "but got dimension size of ", dim_size);
  43. // if split_size is 0 and dimension size is 0, there is 1 split.
  44. int64_t num_splits = 1;
  45. if (split_size != 0) {
  46. // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
  47. // (returns a single split). We might want to error here, but keep it for BC.
  48. num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  49. }
  50. return num_splits;
  51. }
  52. }} // namespace at::native