MatrixRef.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #pragma once
  2. #include <ATen/Utils.h>
  3. #include <c10/util/ArrayRef.h>
  4. #include <vector>
  5. namespace at {
  6. /// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that
  7. /// we can easily view it as a multidimensional array.
  8. ///
  9. /// Like ArrayRef, this class does not own the underlying data, it is expected
  10. /// to be used in situations where the data resides in some other buffer.
  11. ///
  12. /// This is intended to be trivially copyable, so it should be passed by
  13. /// value.
  14. ///
  15. /// For now, 2D only (so the copies are actually cheap, without having
  16. /// to write a SmallVector class) and contiguous only (so we can
  17. /// return non-strided ArrayRef on index).
  18. ///
  19. /// P.S. dimension 0 indexes rows, dimension 1 indexes columns
  20. template <typename T>
  21. class MatrixRef {
  22. public:
  23. typedef size_t size_type;
  24. private:
  25. /// Underlying ArrayRef
  26. ArrayRef<T> arr;
  27. /// Stride of dim 0 (outer dimension)
  28. size_type stride0;
  29. // Stride of dim 1 is assumed to be 1
  30. public:
  31. /// Construct an empty Matrixref.
  32. /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {}
  33. /// Construct an MatrixRef from an ArrayRef and outer stride.
  34. /*implicit*/ MatrixRef(ArrayRef<T> arr, size_type stride0)
  35. : arr(arr), stride0(stride0) {
  36. TORCH_CHECK(
  37. arr.size() % stride0 == 0,
  38. "MatrixRef: ArrayRef size ",
  39. arr.size(),
  40. " not divisible by stride ",
  41. stride0)
  42. }
  43. /// @}
  44. /// @name Simple Operations
  45. /// @{
  46. /// empty - Check if the matrix is empty.
  47. bool empty() const {
  48. return arr.empty();
  49. }
  50. const T* data() const {
  51. return arr.data();
  52. }
  53. /// size - Get size a dimension
  54. size_t size(size_t dim) const {
  55. if (dim == 0) {
  56. return arr.size() / stride0;
  57. } else if (dim == 1) {
  58. return stride0;
  59. } else {
  60. TORCH_CHECK(
  61. 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1");
  62. }
  63. }
  64. size_t numel() const {
  65. return arr.size();
  66. }
  67. /// equals - Check for element-wise equality.
  68. bool equals(MatrixRef RHS) const {
  69. return stride0 == RHS.stride0 && arr.equals(RHS.arr);
  70. }
  71. /// @}
  72. /// @name Operator Overloads
  73. /// @{
  74. ArrayRef<T> operator[](size_t Index) const {
  75. return arr.slice(Index * stride0, stride0);
  76. }
  77. /// Disallow accidental assignment from a temporary.
  78. ///
  79. /// The declaration here is extra complicated so that "arrayRef = {}"
  80. /// continues to select the move assignment operator.
  81. template <typename U>
  82. typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
  83. operator=(U&& Temporary) = delete;
  84. /// Disallow accidental assignment from a temporary.
  85. ///
  86. /// The declaration here is extra complicated so that "arrayRef = {}"
  87. /// continues to select the move assignment operator.
  88. template <typename U>
  89. typename std::enable_if<std::is_same<U, T>::value, MatrixRef<T>>::type&
  90. operator=(std::initializer_list<U>) = delete;
  91. };
  92. } // end namespace at