// WrapDimUtils.h
  1. #pragma once
  2. #include <ATen/core/IListRef.h>
  3. #include <ATen/core/Tensor.h>
  4. #include <c10/core/TensorImpl.h>
  5. #include <c10/core/WrapDimMinimal.h>
  6. #include <c10/util/irange.h>
  7. namespace at {
  8. // if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
  9. // range [-1, 0]. This is a special case for scalar tensors and manifests in
  10. // e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range
  11. // [-dim_post_expr, dim_post_expr-1].
  12. using c10::maybe_wrap_dim;
  13. inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  14. return maybe_wrap_dim(dim, tensor->dim());
  15. }
  16. inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  17. if (tensors.empty()) {
  18. // can't wrap empty TensorList; rely on underlying implementation to throw
  19. // error if necessary.
  20. return dim;
  21. }
  22. return maybe_wrap_dim(dim, tensors[0].dim());
  23. }
  24. inline int64_t maybe_wrap_dim(
  25. int64_t dim,
  26. const std::vector<std::vector<int64_t>>& tensor_sizes) {
  27. if (tensor_sizes.empty()) {
  28. // can't wrap empty list; rely on underlying implementation to throw error
  29. // if necessary
  30. return dim;
  31. }
  32. return maybe_wrap_dim(dim, tensor_sizes[0].size());
  33. }
  34. // Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
  35. // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
  36. // specified using negative indices.
  37. //
  38. // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
  39. // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
  40. // dimensions not in the range [-dim_post_expr, dim_post_expr).
  41. inline void maybe_wrap_dims_n(
  42. int64_t* dims,
  43. int64_t ndims,
  44. int64_t dim_post_expr,
  45. bool wrap_scalars = true) {
  46. if (dim_post_expr <= 0) {
  47. if (wrap_scalars) {
  48. dim_post_expr = 1; // this will make range [-1, 0]
  49. } else {
  50. TORCH_CHECK_INDEX(
  51. ndims == 0,
  52. "Dimension specified as ",
  53. dims[0],
  54. " but tensor has no dimensions");
  55. return;
  56. }
  57. }
  58. int64_t min = -dim_post_expr;
  59. int64_t max = dim_post_expr - 1;
  60. for (const auto i : c10::irange(ndims)) {
  61. auto& dim = dims[i];
  62. if (dim < min || dim > max) {
  63. TORCH_CHECK_INDEX(
  64. false,
  65. "Dimension out of range (expected to be in range of [",
  66. min,
  67. ", ",
  68. max,
  69. "], but got ",
  70. dim,
  71. ")");
  72. }
  73. if (dim < 0)
  74. dim += dim_post_expr;
  75. }
  76. }
  77. // Given a contiguous container of dimensions `dims`, this function "Wraps"
  78. // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
  79. // specified using negative indices.
  80. //
  81. // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
  82. // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
  83. // dimensions not in the range [-dim_post_expr, dim_post_expr).
  84. template <typename Container>
  85. inline void maybe_wrap_dims(
  86. Container& dims,
  87. int64_t dim_post_expr,
  88. bool wrap_scalars = true) {
  89. return maybe_wrap_dims_n(
  90. dims.data(), dims.size(), dim_post_expr, wrap_scalars);
  91. }
  92. // previously, size [0] tensors were the only possible empty tensors; thus, it
  93. // wasn't possible to cat empty tensors unless all the other tensors were
  94. // 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
  95. // dimension behavior and dimension size checking). We maintain this behavior
  96. // for backwards compatibility, but only for this specific size (i.e. other
  97. // empty sizes are not skipped).
  98. template <typename T>
  99. inline int64_t _legacy_cat_wrap_dim(
  100. int64_t dim,
  101. const std::vector<std::vector<T>>& tensor_sizes) {
  102. for (auto& sizes : tensor_sizes) {
  103. if (sizes.size() == 1 && sizes[0] == 0) {
  104. continue;
  105. }
  106. return maybe_wrap_dim(dim, sizes.size());
  107. }
  108. return dim;
  109. }
  110. inline int64_t legacy_cat_wrap_dim(
  111. int64_t dim,
  112. const std::vector<std::vector<int64_t>>& tensor_sizes) {
  113. return _legacy_cat_wrap_dim<int64_t>(dim, tensor_sizes);
  114. }
  115. inline int64_t legacy_cat_wrap_dim_symint(
  116. int64_t dim,
  117. const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  118. return _legacy_cat_wrap_dim<c10::SymInt>(dim, tensor_sizes);
  119. }
  120. inline int64_t legacy_cat_wrap_dim(
  121. int64_t dim,
  122. const MaterializedITensorListRef& tensors) {
  123. for (const Tensor& tensor : tensors) {
  124. if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
  125. continue;
  126. }
  127. return maybe_wrap_dim(dim, tensor.dim());
  128. }
  129. return dim;
  130. }
  131. // wrap negative dims in a vector
  132. inline void wrap_all_dims(
  133. std::vector<int64_t>& dims_to_wrap,
  134. int64_t tensor_total_dims) {
  135. for (const auto i : c10::irange(dims_to_wrap.size())) {
  136. dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  137. }
  138. }
} // namespace at