123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- #pragma once
- #include <ATen/core/IListRef.h>
- #include <ATen/core/Tensor.h>
- #include <c10/core/TensorImpl.h>
- #include <c10/core/WrapDimMinimal.h>
- #include <c10/util/irange.h>
- namespace at {
// If dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in,
// e.g., torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr-1].
- using c10::maybe_wrap_dim;
- inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
- return maybe_wrap_dim(dim, tensor->dim());
- }
- inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
- if (tensors.empty()) {
- // can't wrap empty TensorList; rely on underlying implementation to throw
- // error if necessary.
- return dim;
- }
- return maybe_wrap_dim(dim, tensors[0].dim());
- }
- inline int64_t maybe_wrap_dim(
- int64_t dim,
- const std::vector<std::vector<int64_t>>& tensor_sizes) {
- if (tensor_sizes.empty()) {
- // can't wrap empty list; rely on underlying implementation to throw error
- // if necessary
- return dim;
- }
- return maybe_wrap_dim(dim, tensor_sizes[0].size());
- }
- // Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
- // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
- // specified using negative indices.
- //
- // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
- // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
- // dimensions not in the range [-dim_post_expr, dim_post_expr).
- inline void maybe_wrap_dims_n(
- int64_t* dims,
- int64_t ndims,
- int64_t dim_post_expr,
- bool wrap_scalars = true) {
- if (dim_post_expr <= 0) {
- if (wrap_scalars) {
- dim_post_expr = 1; // this will make range [-1, 0]
- } else {
- TORCH_CHECK_INDEX(
- ndims == 0,
- "Dimension specified as ",
- dims[0],
- " but tensor has no dimensions");
- return;
- }
- }
- int64_t min = -dim_post_expr;
- int64_t max = dim_post_expr - 1;
- for (const auto i : c10::irange(ndims)) {
- auto& dim = dims[i];
- if (dim < min || dim > max) {
- TORCH_CHECK_INDEX(
- false,
- "Dimension out of range (expected to be in range of [",
- min,
- ", ",
- max,
- "], but got ",
- dim,
- ")");
- }
- if (dim < 0)
- dim += dim_post_expr;
- }
- }
// Container convenience wrapper around maybe_wrap_dims_n: wraps, in place,
// every element of `dims` for a tensor of rank `dim_post_expr`, so that dims
// may be given as negative indices counting back from the end.
//
// `Container` must expose contiguous storage through `data()` (yielding a
// pointer usable as `int64_t*`) and an element count through `size()`, e.g.
// std::vector<int64_t>.
//
// If `wrap_scalars` is true, a non-positive `dim_post_expr` (e.g. a rank-0
// tensor) is treated as rank 1, allowing dims in [-1, 0]. If it is false, any
// dim requested against such a rank raises an IndexError (an empty container
// is a no-op). In all other cases an IndexError is raised for dims outside
// [-dim_post_expr, dim_post_expr).
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  // size() is typically an unsigned size_t; convert explicitly rather than
  // relying on an implicit sign-changing conversion.
  maybe_wrap_dims_n(
      dims.data(),
      static_cast<int64_t>(dims.size()),
      dim_post_expr,
      wrap_scalars);
}
// Historically, a tensor of shape [0] (1-D with zero elements) was the only
// possible kind of empty tensor, so an empty tensor could only be concatenated
// when all the other tensors were 1-D as well. Such tensors were therefore
// "skipped" for both dim wrapping and size checking. That behavior is kept for
// backwards compatibility, but only for this exact shape: empty tensors of any
// other shape are not skipped.
//
// Wraps `dim` against the rank of the first entry of `tensor_sizes` that is not
// the legacy shape [0]. When every entry has that shape, or `tensor_sizes` is
// empty, `dim` is returned unchanged.
template <typename T>
inline int64_t _legacy_cat_wrap_dim(
    int64_t dim,
    const std::vector<std::vector<T>>& tensor_sizes) {
  for (const std::vector<T>& shape : tensor_sizes) {
    const bool is_legacy_empty = shape.size() == 1 && shape[0] == 0;
    if (!is_legacy_empty) {
      return maybe_wrap_dim(dim, shape.size());
    }
  }
  return dim;
}
- inline int64_t legacy_cat_wrap_dim(
- int64_t dim,
- const std::vector<std::vector<int64_t>>& tensor_sizes) {
- return _legacy_cat_wrap_dim<int64_t>(dim, tensor_sizes);
- }
- inline int64_t legacy_cat_wrap_dim_symint(
- int64_t dim,
- const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
- return _legacy_cat_wrap_dim<c10::SymInt>(dim, tensor_sizes);
- }
- inline int64_t legacy_cat_wrap_dim(
- int64_t dim,
- const MaterializedITensorListRef& tensors) {
- for (const Tensor& tensor : tensors) {
- if (tensor.dim() == 1 && tensor.sizes()[0] == 0) {
- continue;
- }
- return maybe_wrap_dim(dim, tensor.dim());
- }
- return dim;
- }
- // wrap negative dims in a vector
- inline void wrap_all_dims(
- std::vector<int64_t>& dims_to_wrap,
- int64_t tensor_total_dims) {
- for (const auto i : c10::irange(dims_to_wrap.size())) {
- dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
- }
- }
- } // namespace at
|