123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- #pragma once
- #include <ATen/Utils.h>
- #include <c10/util/ArrayRef.h>
- #include <vector>
- namespace at {
- /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
- /// we can easily view it as a multidimensional array.
- ///
- /// Like ArrayRef, this class does not own the underlying data, it is expected
- /// to be used in situations where the data resides in some other buffer.
- ///
- /// This is intended to be trivially copyable, so it should be passed by
- /// value.
- ///
- /// For now, 2D only (so the copies are actually cheap, without having
- /// to write a SmallVector class) and contiguous only (so we can
- /// return non-strided ArrayRef on index).
- ///
- /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
- template <typename T>
- class MatrixRef {
- public:
- typedef size_t size_type;
- private:
- /// Underlying ArrayRef
- ArrayRef<T> arr;
- /// Stride of dim 0 (outer dimension)
- size_type stride0;
- // Stride of dim 1 is assumed to be 1
- public:
- /// Construct an empty Matrixref.
- /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
- /// Construct an MatrixRef from an ArrayRef and outer stride.
- /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
- : arr(arr), stride0(stride0) {
- TORCH_CHECK(
- arr.size() % stride0 == 0,
- "MatrixRef: ArrayRef size ",
- arr.size(),
- " not divisible by stride ",
- stride0)
- }
- /// @}
- /// @name Simple Operations
- /// @{
- /// empty - Check if the matrix is empty.
- bool empty() const {
- return arr.empty();
- }
- const T* data() const {
- return arr.data();
- }
- /// size - Get size a dimension
- size_t size(size_t dim) const {
- if (dim == 0) {
- return arr.size() / stride0;
- } else if (dim == 1) {
- return stride0;
- } else {
- TORCH_CHECK(
- 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
- }
- }
- size_t numel() const {
- return arr.size();
- }
- /// equals - Check for element-wise equality.
- bool equals(MatrixRef RHS) const {
- return stride0 == RHS.stride0 && arr.equals(RHS.arr);
- }
- /// @}
- /// @name Operator Overloads
- /// @{
- ArrayRef<T> operator[](size_t Index) const {
- return arr.slice(Index * stride0, stride0);
- }
- /// Disallow accidental assignment from a temporary.
- ///
- /// The declaration here is extra complicated so that "arrayRef = {}"
- /// continues to select the move assignment operator.
- template <typename U>
- typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
- operator=(U&& Temporary) = delete;
- /// Disallow accidental assignment from a temporary.
- ///
- /// The declaration here is extra complicated so that "arrayRef = {}"
- /// continues to select the move assignment operator.
- template <typename U>
- typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
- operator=(std::initializer_list<U>) = delete;
- };
- } // end namespace at
|