Descriptors.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #pragma once
  2. #include <ATen/miopen/Exceptions.h>
  3. #include <ATen/miopen/miopen-wrapper.h>
  4. #include <ATen/core/Tensor.h>
  5. #include <ATen/TensorUtils.h>
  6. namespace at { namespace native {
  7. inline int dataSize(miopenDataType_t dataType)
  8. {
  9. switch (dataType) {
  10. case miopenHalf: return 2;
  11. case miopenFloat: return 4;
  12. case miopenBFloat16: return 2;
  13. default: return 8;
  14. }
  15. }
  16. template <typename T, miopenStatus_t (*dtor)(T*)>
  17. struct DescriptorDeleter {
  18. void operator()(T* x) {
  19. if (x != nullptr) {
  20. MIOPEN_CHECK(dtor(x));
  21. }
  22. }
  23. };
  24. // A generic class for wrapping MIOpen descriptor types. All you need
  25. // is to give the underlying type the Descriptor_t points to (usually,
  26. // if it's miopenTensorDescriptor_t it points to miopenTensorStruct),
  27. // the constructor and the destructor. Subclasses are responsible
  28. // for defining a set() function to actually set the descriptor.
  29. //
  30. // Descriptors default construct to a nullptr, and have a descriptor
  31. // initialized the first time you call set() or any other initializing
  32. // function.
  33. template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
  34. class Descriptor
  35. {
  36. public:
  37. // Use desc() to access the underlying descriptor pointer in
  38. // a read-only fashion. Most client code should use this.
  39. // If the descriptor was never initialized, this will return
  40. // nullptr.
  41. T* desc() const { return desc_.get(); }
  42. T* desc() { return desc_.get(); }
  43. // Use mut_desc() to access the underlying descriptor pointer
  44. // if you intend to modify what it points to (e.g., using
  45. // miopenSetFooDescriptor). This will ensure that the descriptor
  46. // is initialized. Code in this file will use this function.
  47. T* mut_desc() { init(); return desc_.get(); }
  48. protected:
  49. void init() {
  50. if (desc_ == nullptr) {
  51. T* raw_desc;
  52. MIOPEN_CHECK(ctor(&raw_desc));
  53. desc_.reset(raw_desc);
  54. }
  55. }
  56. private:
  57. std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
  58. };
  59. class TensorDescriptor
  60. : public Descriptor<miopenTensorDescriptor,
  61. &miopenCreateTensorDescriptor,
  62. &miopenDestroyTensorDescriptor>
  63. {
  64. public:
  65. TensorDescriptor() {}
  66. explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
  67. set(t, pad);
  68. }
  69. void set(const at::Tensor &t, size_t pad = 0);
  70. void set(miopenDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);
  71. void print();
  72. private:
  73. void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
  74. MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
  75. }
  76. };
  77. std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
  78. class FilterDescriptor
  79. : public Descriptor<miopenTensorDescriptor,
  80. &miopenCreateTensorDescriptor,
  81. &miopenDestroyTensorDescriptor>
  82. {
  83. public:
  84. void set(const at::Tensor &t, int64_t pad = 0) {
  85. set(t, at::MemoryFormat::Contiguous, pad);
  86. }
  87. void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
  88. private:
  89. void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
  90. MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
  91. }
  92. };
  93. struct ConvolutionDescriptor
  94. : public Descriptor<miopenConvolutionDescriptor,
  95. &miopenCreateConvolutionDescriptor,
  96. &miopenDestroyConvolutionDescriptor>
  97. {
  98. void set(miopenDataType_t dataType, miopenConvolutionMode_t c_mode, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
  99. MIOPEN_CHECK(miopenInitConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, c_mode));
  100. MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
  101. }
  102. };
  103. struct RNNDescriptor
  104. : public Descriptor<miopenRNNDescriptor,
  105. &miopenCreateRNNDescriptor,
  106. &miopenDestroyRNNDescriptor>
  107. {
  108. void set(int64_t hidden_size, int64_t num_layers, miopenRNNInputMode_t input_mode, miopenRNNDirectionMode_t direction, miopenRNNMode_t rnn_mode,
  109. miopenRNNBiasMode_t bias_mode, miopenRNNAlgo_t algorithm, miopenDataType_t datatype) {
  110. MIOPEN_CHECK(miopenSetRNNDescriptor(mut_desc(), hidden_size, num_layers, input_mode, direction, rnn_mode, bias_mode, algorithm, datatype));
  111. }
  112. };
  113. union Constant
  114. {
  115. float f;
  116. double d;
  117. Constant(miopenDataType_t dataType, double value) {
  118. if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
  119. f = static_cast<float>(value);
  120. } else {
  121. d = value;
  122. }
  123. }
  124. };
  125. }} // namespace