- #pragma once
- #include <c10/core/ScalarType.h>
- #include <c10/util/irange.h>
- #include <c10/util/Exception.h>
- #include <c10/util/strides.h>
- #include <ATen/core/Tensor.h>
- #include <ATen/ExpandUtils.h>
- #include <ATen/TensorUtils.h>
- #include <ATen/native/TensorIterator.h>
- #include <ATen/native/TransposeType.h>
- #include <limits>
- #include <type_traits>
- #include <sstream>
- #include <cstring>
- #include <cctype>
- #ifndef AT_PER_OPERATOR_HEADERS
- #include <ATen/Functions.h>
- #else
- #include <ATen/ops/arange.h>
- #include <ATen/ops/empty.h>
- #include <ATen/ops/empty_like.h>
- #include <ATen/ops/empty_strided.h>
- #include <ATen/ops/zeros.h>
- #endif
- namespace at { namespace native {
- static inline c10::MaybeOwned<Tensor> expect_resolved_conj(const Tensor& tensor) {
- if (tensor.is_conj()) {
- return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
- } else {
- return c10::MaybeOwned<Tensor>::borrowed(tensor);
- }
- }
- static inline DimVector batched_matrix_contiguous_strides(
- const IntArrayRef sizes,
- const bool f_contig = false) {
- // f_contig chooses between returning the strides of a batch of
- // Fortran-contiguous (column-major) matrices and those of C-contiguous matrices
- auto strides = c10::contiguous_strides(sizes);
- auto dim = strides.size();
- if (f_contig && dim >= 2) {
- // Fix the strides of the last two dimensions, so that we return
- // C-contiguous batches of F-contiguous matrices.
- strides[dim - 1] = std::max(sizes[dim - 2], static_cast<int64_t>(1));
- strides[dim - 2] = 1;
- }
- return strides;
- }
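- // Illustrative example (not part of the original header; sizes are assumed):
- // for sizes (2, 3, 4) the default C-contiguous strides are (12, 4, 1), while
- // f_contig=true yields (12, 1, 3), i.e. a C-contiguous batch of column-major
- // 3x4 matrices:
- //   auto c_strides = batched_matrix_contiguous_strides({2, 3, 4});        // (12, 4, 1)
- //   auto f_strides = batched_matrix_contiguous_strides({2, 3, 4}, true);  // (12, 1, 3)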
- /*
- * Clones a Tensor so that the following conditions hold:
- * If we think of a Tensor as having size (B, M, N), where B is any number
- * of batch dimensions, then:
- * - Each (M, N) matrix is in column major form
- * - Let Tensor P have size (B, M, N) and Q have size (B, M', N').
- * Then, when laid out in memory, for any batch index b the M by N matrix
- * starting at P.data_ptr()[b * M * N] belongs to the same batch as the
- * M' by N' matrix starting at Q.data_ptr()[b * M' * N'].
- */
- static inline Tensor cloneBatchedColumnMajor(const Tensor& src) {
- // If src is already in batched column major format, then
- // this will be efficient (no reordering of the data will occur)
- // because the first transpose will make the tensor contiguous,
- // and cloning a contiguous tensor is fast.
- auto result = src.mT().clone(at::MemoryFormat::Contiguous);
- result.transpose_(-2, -1);
- return result;
- }
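- // Usage sketch (illustrative; `src` is a hypothetical batched input): the
- // result keeps the same sizes, but each (M, N) matrix becomes column-major:
- //   Tensor src = at::zeros({2, 3, 4});
- //   Tensor dst = cloneBatchedColumnMajor(src);
- //   // dst.sizes()   == (2, 3, 4)
- //   // dst.strides() == (12, 1, 3)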
- /*
- * contig chooses between C-contig (true) and F-contig (false)
- */
- static inline c10::MaybeOwned<Tensor> borrow_else_clone(const bool cond, const Tensor& borrow, const Tensor& clone, const bool contig) {
- return cond ? c10::MaybeOwned<Tensor>::borrowed(borrow)
- : c10::MaybeOwned<Tensor>::owned(contig ? clone.clone(MemoryFormat::Contiguous)
- : cloneBatchedColumnMajor(clone));
- }
- /*
- * This method is designed to be a faster alternative to
- * `cloneBatchedColumnMajor` with some additional features,
- * namely:
- * 1. It uses `copy` instead of `clone`, which can be much faster.
- * 2. The `nrows` parameter is used to create copies with more rows than the
- * original input, which is required for some LAPACK/MAGMA methods.
- * 3. `desired_batch_sizes` is used to create copies whose batch size is either
- * the original batch size of the input or its larger broadcasted shape.
- */
- static inline Tensor copyBatchedColumnMajor(const Tensor& src, int64_t nrows = -1,
- at::OptionalIntArrayRef desired_batch_sizes = c10::nullopt) {
- nrows = (nrows == -1) ? src.size(-2) : nrows;
- auto copy_sizes = desired_batch_sizes.has_value()
- ? desired_batch_sizes.value().vec()
- : IntArrayRef(src.sizes().data(), src.dim() - 2).vec();
- copy_sizes.insert(copy_sizes.end(), {nrows, src.size(-1)});
- const auto copy_strides = batched_matrix_contiguous_strides(copy_sizes, /*f-contig*/true);
- auto copy = at::empty_strided(copy_sizes, copy_strides, src.options());
- copy.narrow(-2, 0, src.size(-2)).copy_(src);
- return copy;
- }
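- // Illustrative example (shapes are assumed, not from the original header):
- // padding a (2, 3, 4) input to 5 rows per matrix allocates a column-major
- // (2, 5, 4) buffer and copies `src` into its first 3 rows:
- //   Tensor src  = at::zeros({2, 3, 4});
- //   Tensor copy = copyBatchedColumnMajor(src, /*nrows=*/5);
- //   // copy.sizes()   == (2, 5, 4)
- //   // copy.strides() == (20, 1, 5)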
- /*
- * Given batches of matrices with arbitrary batch dim,
- * computes the number of batches.
- */
- static inline int64_t batchCount(const Tensor& batched_matrices) {
- int64_t result = 1;
- for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
- result *= batched_matrices.size(i);
- }
- return result;
- }
- // Computes the number of elements of a matrix in a batched matrix tensor
- static inline int64_t matrixStride(const Tensor& batched_matrices) {
- return batched_matrices.size(-1) * batched_matrices.size(-2);
- }
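- // Quick sketch (illustrative): for a tensor `t` of shape (2, 3, 4, 5),
- // batchCount(t) == 2 * 3 == 6 and matrixStride(t) == 4 * 5 == 20.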
- // Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
- static inline void checkIsMatrix(const Tensor& A, const char* const f_name, const char* const arg_name = "A") {
- TORCH_CHECK(A.dim() >= 2, f_name, ": The input tensor ", arg_name, " must have at least 2 dimensions.");
- }
- static inline void squareCheckInputs(const Tensor& self, const char* const f_name, const char* const arg_name = "A") {
- checkIsMatrix(self, f_name, arg_name);
- TORCH_CHECK(self.size(-1) == self.size(-2),
- f_name,
- ": ", arg_name, " must be batches of square matrices, "
- "but they are ", self.size(-2), " by ", self.size(-1), " matrices");
- }
- static inline void checkInputsSolver(const Tensor& A,
- const Tensor& B,
- const bool left,
- const char* const f_name) {
- squareCheckInputs(A, f_name, "A");
- checkIsMatrix(B, f_name, "B");
- TORCH_CHECK(left ? A.size(-2) == B.size(-2) : A.size(-1) == B.size(-1),
- f_name, ": Incompatible shapes of A and B for the equation ",
- left ? "AX = B" : "XA = B",
- " (", A.size(-2), "x", A.size(-1), " and ", B.size(-2), "x", B.size(-1), ")");
- }
- static inline bool is_row_or_column_contiguous(const Tensor& t) {
- // This could be made more general, similar to how it's checked in matmul, which would allow us
- // to elide the copy for strides such as (6, 12, 1, 3) or (3, 1, 9), but this is quite tricky.
- // We choose to be conservative for simplicity
- return t.is_contiguous() || t.transpose(-2, -1).is_contiguous();
- }
- static inline TransposeType to_transpose_type(const bool contig, const bool conj) {
- if (conj) {
- if (contig) { TORCH_INTERNAL_ASSERT(false, "Invalid transpose type"); }
- else { return TransposeType::ConjTranspose; }
- } else {
- if (contig) { return TransposeType::NoTranspose; }
- else { return TransposeType::Transpose; }
- }
- }
- // This function is designed to be used with linear algebra methods that minimize
- // L(ax - b), where L is generally the identity map (`solve`, for example, where the
- // minimum is exactly zero) or the L2 norm (`lstsq`).
- // It is expected that `a` and `b` are contiguous tensors of column-major matrices
- // (so that a.view({-1, a.size(-2), a.size(-1)}) succeeds, same for `b`),
- // with the following additional properties:
- //
- // 1. a.dim() == b.dim()
- // 2. a.shape[:-2] broadcasts over b.shape[:-2]
- // 3. a.size(i) <= b.size(i) for i=0,..., a.dim() - 3 (only for batch dimensions)
- //
- // MAGMA/LAPACK modify tensor `a` in-place, and the main goal of this method
- // is to be memory efficient, which means that if there exists an index i such that
- // a.shape[i] < b.shape[i], 0 <= i <= a.dim() - 3,
- // then instead of materializing copies of `a` in the broadcasted shape, we keep
- // a buffer copy of `a` along with flags that check whether specific batch dimension
- // indices for `a` were already accessed. If they were, we copy the data from the buffer
- // into `a`. The number of copies does not exceed
- // prod(max(a.shape[:-2], b.shape[:-2]) - a.shape[:-2] + 1)
- // and this value is attained by tensors with non-empty batch dimensions.
- //
- // func_t `f` is a callable that is being supplied with
- // scalar_t* a_working_ptr, scalar_t* b_working_ptr, int64_t a_linear_batch_idx.
- // a_working_ptr and b_working_ptr can directly be passed to LAPACK/MAGMA routines,
- // and a_linear_batch_idx is an index in the 3d representation which corresponds to
- // the memory a_working_ptr points to, in other words:
- // a_working_ptr == a.view({-1, a.size(-2), a.size(-1)}).select(0, a_linear_batch_idx).data_ptr<scalar_t>();
- // a_linear_batch_idx is useful for storing metadata related to `a`, such as, for example,
- // its rank or singular values (see linalg_lstsq).
- template<typename scalar_t, typename func_t>
- void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
- IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
- IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
- auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
- auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
- TensorIterator iter = TensorIteratorConfig()
- .set_check_mem_overlap(false)
- .check_all_same_dtype(false)
- .resize_outputs(false)
- .add_output(b_linear_batch_idx)
- .add_input(a_linear_batch_idx)
- .build();
- auto m = a.size(-2);
- auto n = a.size(-1);
- auto a_3d = a.view({batchCount(a), m, n});
- auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
- auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
- Tensor a_buffer, a_was_accessed, a_buffer_3d;
- std::function<void(int64_t)> check_if_copy_needed_for_a
- = [](int64_t /*a_curr_linear_batch_idx*/){};
- if (a_broadcasts_over_b) {
- a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
- .copy_(a);
- a_was_accessed = at::zeros(batchCount(a), at::kBool);
- a_buffer_3d = a_buffer.view({batchCount(a), m, n});
- check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
- auto* a_was_accessed_flag = a_was_accessed
- .select(0, a_curr_linear_batch_idx)
- .data_ptr<bool>();
- if (!(*a_was_accessed_flag)) {
- *a_was_accessed_flag = true;
- }
- else {
- a_3d.select(0, a_curr_linear_batch_idx)
- .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
- }
- };
- }
- auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
- auto* b_batch_idx_ptr = data[0];
- auto* a_batch_idx_ptr = data[1];
- for (const auto elem C10_UNUSED : c10::irange(nelems)) {
- auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
- auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
- check_if_copy_needed_for_a(a_curr_linear_batch_idx);
- auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
- .data_ptr<scalar_t>();
- auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
- .data_ptr<scalar_t>();
- f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
- b_batch_idx_ptr += strides[0];
- a_batch_idx_ptr += strides[1];
- }
- };
- iter.serial_for_each(loop, {0, batchCount(b)});
- }
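- // Usage sketch (illustrative; the lambda body is a placeholder for a real
- // LAPACK/MAGMA call, and `a`/`b` are assumed to be column-major batches with
- // a's batch shape broadcasting over b's, e.g. (1, 3, 3) and (4, 3, 2)):
- //   batch_iterator_with_broadcasting<double>(a, b,
- //       [&](double* a_working_ptr, double* b_working_ptr, int64_t a_linear_batch_idx) {
- //         // invoke the LAPACK/MAGMA routine on a_working_ptr / b_working_ptr here
- //       });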
- // Returns the epsilon value for floating types except half
- static inline double _get_epsilon(const ScalarType& sc_type) {
- switch (sc_type) {
- case at::ScalarType::Float:
- return static_cast<double>(std::numeric_limits<float>::epsilon());
- case at::ScalarType::Double:
- return std::numeric_limits<double>::epsilon();
- default:
- AT_ERROR("This function doesn't handle types other than float and double");
- }
- }
- // Validates input shapes and devices
- // for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
- static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, const char* name) {
- TORCH_CHECK(self.device() == A.device(),
- "Expected b and A to be on the same device, but found b on ",
- self.device(), " and A on ", A.device(), " instead.");
- TORCH_CHECK(self.scalar_type() == A.scalar_type(),
- "Expected b and A to have the same dtype, but found b of type ",
- self.scalar_type(), " and A of type ", A.scalar_type(), " instead.");
- TORCH_CHECK(A.size(-1) == A.size(-2),
- "A must be batches of square matrices, "
- "but they are ", A.size(-2), " by ", A.size(-1), " matrices");
- TORCH_CHECK(A.size(-1) == self.size(-2),
- "Incompatible matrix sizes for ", name, ": each A "
- "matrix is ", A.size(-1), " by ", A.size(-1),
- " but each b matrix is ", self.size(-2), " by ", self.size(-1));
- }
- static inline void checkFloatingOrComplex(const Tensor& t, const char* const f_name, const bool allow_low_precision_dtypes=true) {
- auto dtype = t.scalar_type();
- TORCH_CHECK((at::isFloatingType(dtype) || at::isComplexType(dtype)),
- f_name, ": Expected a floating point or complex tensor as input. Got ", dtype);
- if (!allow_low_precision_dtypes) {
- TORCH_CHECK(dtype == kFloat || dtype == kDouble || dtype == kComplexFloat || dtype == kComplexDouble,
- f_name, ": Low precision dtypes not supported. Got ", dtype);
- }
- }
- // Checks that all the Tensors in a TensorList have the same number of dimensions
- static inline void checkAllSameDim(TensorList tensors, int64_t dim) {
- for (auto &t : tensors) {
- TORCH_CHECK(t.dim() == dim, "Tensor dimension is ", t.dim(), ", expected ", dim, " instead.");
- }
- }
- static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) {
- // broadcast the batch dimensions of arg1 and arg2.
- IntArrayRef arg1_batch_sizes(arg1.sizes().data(), arg1.ndimension() - 2);
- IntArrayRef arg2_batch_sizes(arg2.sizes().data(), arg2.ndimension() - 2);
- std::vector<int64_t> expand_batch_portion = infer_size(arg1_batch_sizes, arg2_batch_sizes);
- std::vector<int64_t> arg1_expand_size({expand_batch_portion});
- arg1_expand_size.insert(arg1_expand_size.end(), { arg1.size(-2), arg1.size(-1) });
- std::vector<int64_t> arg2_expand_size({expand_batch_portion});
- arg2_expand_size.insert(arg2_expand_size.end(), { arg2.size(-2), arg2.size(-1) });
- return std::make_tuple(std::move(arg1_expand_size), std::move(arg2_expand_size));
- }
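- // Example (illustrative shapes): for arg1 of shape (2, 1, 4, 4) and arg2 of
- // shape (3, 4, 2), the batch dimensions (2, 1) and (3,) broadcast to (2, 3),
- // so the returned expand sizes are (2, 3, 4, 4) and (2, 3, 4, 2).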
- static inline std::tuple<Tensor,Tensor> _linalg_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2, const char* name) {
- // If there's no name, we assume we don't want to check for errors
- if (name != nullptr) {
- linearSolveCheckInputs(arg1, arg2, name);
- }
- std::vector<int64_t> arg1_expand_size, arg2_expand_size;
- std::tie(arg1_expand_size, arg2_expand_size) = at::native::_linalg_broadcast_batch_dims(arg1, arg2);
- auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size);
- auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size);
- return std::make_tuple(arg1_broadcasted, arg2_broadcasted);
- }
- static inline std::vector<int64_t> broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) {
- IntArrayRef t1_batch_sizes(t1.sizes().data(), n_batch_dims);
- IntArrayRef t2_batch_sizes(t2.sizes().data(), n_batch_dims);
- auto broadcasted_batch_sizes = infer_size(t1_batch_sizes, t2_batch_sizes);
- return broadcasted_batch_sizes;
- }
- // Returns `self` with the given axes permuted to the end, preserving the order of the remaining dimensions.
- static inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) {
- const std::vector<int64_t> a = axes.vec();
- const int64_t ndim = self.ndimension();
- std::vector<int64_t> perm;
- for (const auto i : c10::irange(ndim)) {
- auto it = std::find(a.begin(), a.end(), i);
- if (it == a.end()) {
- perm.push_back(i);
- }
- }
- for (auto i : a) {
- perm.push_back(i);
- }
- TORCH_CHECK((int64_t)perm.size() == ndim,
- "duplicate or invalid axis in 'dim' argument for tensor with ndim==", ndim);
- return self.permute(perm);
- }
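- // Example (illustrative): for a 4-D tensor and axes = {1, 3}, the computed
- // permutation is (0, 2, 1, 3), so dimensions 1 and 3 are moved to the end
- // while the relative order of the remaining dimensions is preserved:
- //   Tensor t = at::zeros({2, 3, 4, 5});
- //   Tensor p = _move_to_end(t, {1, 3});
- //   // p.sizes() == (2, 4, 3, 5)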
- // parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
- static inline std::tuple<bool, bool> _parse_qr_mode(c10::string_view mode) {
- bool compute_q;
- bool reduced;
- if (mode == "reduced") {
- compute_q = true;
- reduced = true;
- } else if (mode == "complete") {
- compute_q = true;
- reduced = false;
- } else if (mode == "r") {
- compute_q = false;
- reduced = true; // this is actually irrelevant in this mode
- } else {
- TORCH_CHECK(false, "qr received unrecognized mode '", mode,
- "' but expected one of 'reduced' (default), 'r', or 'complete'");
- }
- return std::make_tuple(compute_q, reduced);
- }
- // Function to compute the sizes, strides and number of columns for the Q matrix in the QR decomposition
- static inline std::tuple<DimVector, DimVector, int64_t> _compute_geometry_for_Q(
- const Tensor& input,
- bool reduced) {
- int64_t m = input.size(-2), n = input.size(-1);
- int64_t n_columns_q;
- // We need to compute the required size of Q based on the `reduced` option
- DimVector q_sizes(input.sizes());
- if (!reduced && m > n) {
- q_sizes[input.dim() - 1] = m;
- n_columns_q = m;
- } else {
- q_sizes[input.dim() - 1] = n;
- n_columns_q = std::min(m, n);
- }
- auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true);
- return std::make_tuple(q_sizes, q_strides, n_columns_q);
- }
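- // Example (illustrative): for a tall input of shape (5, 3), the reduced QR
- // needs a 5x3 Q (n_columns_q == 3), whereas the complete QR (reduced == false)
- // needs a square 5x5 Q (n_columns_q == 5); the returned strides are
- // column-major (F-contiguous) in both cases.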
- static inline bool svd_uses_cusolver(const Tensor& A) {
- // cuSOLVER is used for CUDA inputs whenever it is available, unless the user explicitly prefers MAGMA
- return A.is_cuda()
- && at::globalContext().hasCuSOLVER()
- && at::globalContext().linalgPreferredBackend() != at::LinalgBackend::Magma;
- }
- // Function used instead of .to so that the original strides are retained
- // .to doesn't retain strides and makes the output tensor contiguous
- static inline Tensor same_stride_to(const Tensor& original_tensor, const at::TensorOptions& options) {
- auto strided_to = at::empty_strided(original_tensor.sizes(),
- original_tensor.strides(),
- options);
- strided_to.copy_(original_tensor);
- return strided_to;
- }
- // Creates a dimension permutation array that can be given to `at::permute()`, which will shift
- // the two specified dimensions to the end of a tensor, without changing the order of
- // the other dimensions. `dim1` will be placed at the very end, and `dim0` will be
- // placed just to the left of it.
- //
- // For instance, given a 4-D tensor, dimensions 1 and 3 can be shifted to the end by
- // calling `create_dim_backshift_permutation(1, 3, 4)`. The resulting vector will
- // be `vec(0, 2, 1, 3)`.
- static inline std::vector<int64_t> create_dim_backshift_permutation(int64_t dim0, int64_t dim1, int64_t ndim) {
- TORCH_CHECK(
- (dim0 != dim1) && (dim0 < ndim) && (dim0 >= 0) && (dim1 < ndim) && (dim1 >= 0),
- "duplicate or invalid dimensions");
- std::vector<int64_t> permutation(ndim);
- int64_t cur_permuted_dim = 0;
- for (const auto dim_ind : c10::irange(ndim)) {
- if ((dim_ind != dim0) && (dim_ind != dim1)) {
- permutation[cur_permuted_dim++] = dim_ind;
- }
- }
- permutation[cur_permuted_dim++] = dim0;
- permutation[cur_permuted_dim] = dim1;
- return permutation;
- }
- // Creates a dimension permutation array that can be given to `at::permute()`, which
- // will reverse a given permutation.
- // The reverse permutation array is created by swapping the indices and their
- // associated values from the given permutation array.
- static inline std::vector<int64_t> create_reverse_permutation(std::vector<int64_t> permutation) {
- int64_t ndim = permutation.size();
- std::vector<int64_t> reverse_permutation(ndim);
- for (const auto dim_ind : c10::irange(ndim)) {
- reverse_permutation[permutation[dim_ind]] = dim_ind;
- }
- return reverse_permutation;
- }
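- // Example (illustrative): the reverse of the permutation (2, 0, 1) is
- // (1, 2, 0), so permuting by one and then by the other restores the original
- // dimension order:
- //   auto perm = create_dim_backshift_permutation(0, 1, 3);  // (2, 0, 1)
- //   auto rev  = create_reverse_permutation(perm);           // (1, 2, 0)
- //   // t.permute(perm).permute(rev) has the same sizes and strides as t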
- // Compute R-work array size for MAGMA/LAPACK cgesdd/zgesdd
- // See https://github.com/Reference-LAPACK/lapack/blob/122506cd8b6ce050a200920c3d4c0b153b150fd8/SRC/cgesdd.f#L186
- static inline int64_t computeLRWorkDim(const char jobz, int64_t m, int64_t n) {
- auto mn = std::min(m, n);
- auto mx = std::max(m, n);
- if (jobz == 'N') {
- #ifdef __APPLE__
- // According to `vecLib.framework/Headers/clapack.h` Accelerate.framework is based on LAPACK 3.2.1
- return 7 * mn;
- #else
- // This setting is valid for LAPACK 3.6+
- return 5 * mn;
- #endif
- }
- if (mx > 10 * mn) {
- return 5 * mn * mn + 5 * mn;
- }
- return std::max(5 * mn * mn + 5 * mn, 2 * mx * mn + 2 * mn * mn + mn);
- }
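- // Example (illustrative, assuming LAPACK 3.6+ rather than Accelerate): for a
- // square 64x64 input, jobz == 'N' needs 5 * 64 = 320 reals of rwork, while
- // the other jobz modes need max(5*64*64 + 5*64, 2*64*64 + 2*64*64 + 64) = 20800.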
- // This function checks whether the uplo argument input is valid
- // Allowed strings are "u", "U", "l", "L"
- static inline void checkUplo(const c10::string_view uplo) {
- // To use std::toupper safely with plain chars (or signed chars), the argument should first be converted to unsigned char
- char uplo_uppercase = static_cast<char>(std::toupper(static_cast<unsigned char>(uplo[0])));
- TORCH_CHECK(uplo.size() == 1 && (uplo_uppercase == 'U' || uplo_uppercase == 'L'),
- "Expected UPLO argument to be 'L' or 'U', but got ", uplo);
- }
- static inline void checkSameDevice(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
- TORCH_CHECK(
- result.device() == input.device(),
- fn_name,
- ": Expected ", result_name, " and input tensors to be on the same device, but got ",
- result_name, " on ", result.device(), " and input on ", input.device());
- }
- // Check the dtype of result and input tensors (for _out variants).
- // Most linear algebra functions have the same dtype for input and output
- // (either floating or complex type input), so we can check whether input's dtype can be cast to result's dtype.
- // According to https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
- // c10::canCast is used for checking the "safe copy" dtype requirements.
- static inline void checkLinalgCompatibleDtype(const std::string& fn_name, Tensor result, Tensor input, const std::string& result_name = "result") {
- bool can_cast = c10::canCast(input.scalar_type(), result.scalar_type());
- TORCH_CHECK(
- can_cast,
- fn_name,
- ": Expected ", result_name, " to be safely castable from ", input.scalar_type(), " dtype, but got ",
- result_name, " with dtype ", result.scalar_type());
- }
- // Alternatively, we can check whether the specific expected output type (result_type) can be safely cast to the out tensor dtype (out_type)
- static inline void checkLinalgCompatibleDtype(const std::string& fn_name, ScalarType out_type, ScalarType result_type, const std::string& out_name = "result") {
- bool can_cast = c10::canCast(result_type, out_type);
- TORCH_CHECK(
- can_cast,
- fn_name,
- ": Expected ", out_name, " to be safely castable from ", result_type, " dtype, but got ",
- out_name, " with dtype ", out_type);
- }
- static inline void checkNotComplexTolerance(const Tensor& tol, const c10::string_view f_name, const c10::string_view tol_name) {
- TORCH_CHECK(!at::isComplexType(tol.scalar_type()),
- f_name, ": ", tol_name, " tensor of complex type is not supported. Got ", tol.scalar_type());
- }
- /*
- Two types of 'other' tensors are supported when solving
- a system of linear equations matmul(input, x) = other:
- * 1-dimensional (1D) tensor or batch of 1D tensors (vector case)
- * 2-dimensional (2D) tensor or batch of 2D tensors (matrix case).
- The original torch.solve supported only the matrix case, while NumPy works for both cases.
- For batched inputs we need to be able to distinguish the two cases.
- Let input.shape = (batch_dimensions, m, n), then 'other' is of vector type if other.shape == (batch_dimensions, m).
- This rule is compatible with NumPy, see https://github.com/numpy/numpy/blob/v1.20.0/numpy/linalg/linalg.py#L384-L389
- */
- static inline bool linalg_solve_is_vector_rhs(const Tensor& input, const Tensor& other) {
- auto expected_batched_rhs_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); // input.shape[:-1]
- bool vector_case = other.dim() == 1 || (input.dim() - 1 == other.dim() && other.sizes().equals(expected_batched_rhs_shape));
- return vector_case;
- }
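- // Example (illustrative): with input of shape (2, 3, 3), an `other` of shape
- // (2, 3) is treated as a batch of vectors (returns true), while an `other`
- // of shape (2, 3, 1) is treated as a batch of 3x1 matrices (returns false).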
- /*
- Computes linear indices for a tensor with original_shape to access its elements as if it were a materialized broadcast tensor.
- */
- static inline Tensor get_linear_indices(int64_t numel, IntArrayRef original_shape, IntArrayRef broadcast_shape) {
- TensorOptions options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
- return at::arange(numel, options).view(original_shape).broadcast_to(broadcast_shape).contiguous();
- }
- class BroadcastLinearIndices {
- private:
- Tensor linear_indices_;
- bool is_broadcasting_;
- public:
- BroadcastLinearIndices(
- int64_t numel,
- IntArrayRef original_shape,
- IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
- // The assumption is that the broadcast_shape is a materialized broadcast
- // shape of the original_shape. We need to compute the linear indices
- // compatible with the original_shape to access the elements in the original
- // tensor corresponding to the broadcast tensor.
- if (is_broadcasting_) {
- linear_indices_ =
- get_linear_indices(numel, original_shape, broadcast_shape);
- }
- }
- int64_t operator()(int64_t broadcast_linear_index) {
- return is_broadcasting_
- ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
- : broadcast_linear_index;
- }
- };
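- // Usage sketch (illustrative shapes): broadcasting an original shape (1, 3)
- // to (2, 3) repeats the three original elements once, so the functor maps
- // broadcast linear indices 0..5 back to 0, 1, 2, 0, 1, 2:
- //   BroadcastLinearIndices idx(/*numel=*/3, {1, 3}, {2, 3});
- //   // idx(4) == 1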
- static inline bool is_blas_compatible_column_major_order(const Tensor& input) {
- IntArrayRef input_strides = input.strides();
- IntArrayRef input_sizes = input.sizes();
- auto ndim = input.dim();
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
- if (ndim > 3) {
- return input.transpose(-2, -1).is_contiguous();
- }
- auto leading_dimension = input_strides[ndim - 1];
- auto rows = input_sizes[ndim - 2];
- bool batch_stride_compatible = true;
- if (ndim == 3) {
- auto cols = input_sizes[ndim - 1];
- batch_stride_compatible =
- input_strides[ndim - 3] >= leading_dimension * cols;
- }
- return (input_strides[ndim - 2] == 1) &&
- (leading_dimension >= std::max<int64_t>(1, rows)) &&
- batch_stride_compatible;
- }
- static inline bool is_blas_compatible_row_major_order(const Tensor& input) {
- IntArrayRef input_strides = input.strides();
- IntArrayRef input_sizes = input.sizes();
- auto ndim = input.dim();
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(ndim >= 2);
- if (ndim > 3) {
- return input.is_contiguous();
- }
- auto leading_dimension = input_strides[ndim - 2];
- auto cols = input_sizes[ndim - 1];
- bool batch_stride_compatible = true;
- if (ndim == 3) {
- auto rows = input_sizes[ndim - 2];
- batch_stride_compatible =
- input_strides[ndim - 3] >= leading_dimension * rows;
- }
- return (input_strides[ndim - 1] == 1) &&
- (leading_dimension >= std::max<int64_t>(1, cols)) &&
- batch_stride_compatible;
- }
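- // Example (illustrative strides): a tensor of shape (5, 3, 4) with strides
- // (12, 1, 3) is a batch of column-major matrices with leading dimension 3 and
- // passes the column-major check, while the default contiguous strides
- // (12, 4, 1) (leading dimension 4) pass the row-major check instead.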
- }} // namespace at::native