Descriptors.h

#pragma once

#include <string>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <cuda.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

namespace at { namespace native {

std::string cudnnTypeToString(cudnnDataType_t dtype);

// TODO: Add constructors for all of the descriptors

inline int dataSize(cudnnDataType_t dataType)
{
  switch (dataType) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8200
    case CUDNN_DATA_BFLOAT16:
#endif
    case CUDNN_DATA_HALF: return 2;
    case CUDNN_DATA_FLOAT: return 4;
    default: return 8;
  }
}
// The stride for a size-1 dimension is not uniquely determined; in
// fact, it can be anything you want, because the fact that the
// tensor is size 1 at this dimension means that you will never actually
// try advancing your pointer by this stride.
//
// However, CuDNN has a much more stringent requirement on strides:
// if you are passing a contiguous input, it had better be the case
// that the stride for dim i is the product of the sizes of dims
// i+1 to the end. This stride is indeed uniquely determined. This
// function modifies 'stride' in place so this invariant holds.
template <typename T>
static inline void fixSizeOneDimStride(int dim, const T *size, T *stride, bool nhwc) {
  int64_t z = 1;
  int index = 0;
  std::vector<int> permutation(dim);

  if (nhwc) {
    permutation[index++] = 1;
  }
  for (int d = dim-1; d > 1; d--) {
    permutation[index++] = d;
  }
  if (!nhwc) {
    permutation[index++] = 1;
  }
  permutation[index++] = 0;
  for (int d : permutation) {
    if (size[d] == 1) {
      stride[d] = z;
    } else {
      z *= size[d];
    }
  }
}
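
// A small worked example (illustrative only, not code from this file): for an
// NCHW tensor where an arbitrary stride was recorded for the size-1 channel
// dimension, e.g.
//
//   int size[4]   = {2, 1, 3, 4};
//   int stride[4] = {12, 999, 4, 1};
//   fixSizeOneDimStride<int>(4, size, stride, /*nhwc=*/false);
//
// the dimensions are visited from innermost to outermost, accumulating
// z = 4, then 12, so stride[1] is rewritten to 12 and the final strides are
// {12, 12, 4, 1} -- the canonical contiguous strides cuDNN expects.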
template <typename T, cudnnStatus_t (*dtor)(T*)>
struct DescriptorDeleter {
  void operator()(T* x) {
    if (x != nullptr) {
      AT_CUDNN_CHECK(dtor(x));
    }
  }
};

// A generic class for wrapping cuDNN descriptor types. All you need
// is to give the underlying type the Descriptor_t points to (usually,
// if it's cudnnTensorDescriptor_t it points to cudnnTensorStruct),
// the constructor and the destructor. Subclasses are responsible
// for defining a set() function to actually set the descriptor.
//
// Descriptors default construct to a nullptr, and have a descriptor
// initialized the first time you call set() or any other initializing
// function.
template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
class TORCH_CUDA_CPP_API Descriptor {
 public:
  // TODO: Figure out why const-correctness doesn't work here
  // Use desc() to access the underlying descriptor pointer in
  // a read-only fashion. Most client code should use this.
  // If the descriptor was never initialized, this will return
  // nullptr.
  T* desc() const { return desc_.get(); }
  T* desc() { return desc_.get(); }

  // Use mut_desc() to access the underlying descriptor pointer
  // if you intend to modify what it points to (e.g., using
  // cudnnSetFooDescriptor). This will ensure that the descriptor
  // is initialized. Code in this file will use this function.
  T* mut_desc() { init(); return desc_.get(); }

 protected:
  void init() {
    if (desc_ == nullptr) {
      T* raw_desc;
      AT_CUDNN_CHECK(ctor(&raw_desc));
      desc_.reset(raw_desc);
    }
  }

 private:
  std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};
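
// Typical usage pattern (an illustrative sketch, not code from this file): a
// subclass is default-constructed, filled in via one of its set() overloads
// (which call mut_desc() and thereby lazily create the underlying cuDNN
// struct), and then passed to cuDNN calls through desc():
//
//   TensorDescriptor idesc(input);   // equivalent to idesc.set(input)
//   // ... cudnnSomeOp(handle, idesc.desc(), ...);
//
// The DescriptorDeleter destroys the underlying struct when the wrapper goes
// out of scope.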
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
                                                cudnnTensorStruct,
                                                &cudnnCreateTensorDescriptor,
                                                &cudnnDestroyTensorDescriptor> {
 public:
  TensorDescriptor() = default;
  explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
    set(t, pad);
  }
  // Note [CuDNN broadcast padding]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // pad specifies the minimum dimensionality of the tensor descriptor
  // we produce (it doesn't have anything to do with, e.g., convolution
  // padding). If 't' is lower-dimensional than 'pad', the remaining
  // dimensions (on the right) are padded with ones. This doesn't
  // affect the underlying data layout. This is particularly useful for
  // dealing with a peculiarity of the CuDNN API, which is that broadcasting
  // in CuDNN is done in two steps: first, the client code is expected to pad
  // out (the dimensions of) input tensors to be the same dimension as the
  // target broadcast, and then second, CuDNN takes care of actually
  // broadcasting size 1 dimensions.
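  //
  // For example (an illustrative sketch of the padding behavior described
  // above): a tensor with sizes {2, 3, 5} set with pad = 4 is described to
  // cuDNN as a 4-d tensor with sizes {2, 3, 5, 1}; the trailing size-1
  // dimension is purely a bookkeeping device and the data layout is
  // unchanged.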
  void set(const at::Tensor &t, size_t pad = 0);
  void set(const at::Tensor &t, at::MemoryFormat memory_format, size_t pad = 0);
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, IntArrayRef sizes, IntArrayRef strides, size_t pad, bool nhwc);

  void set(cudnnDataType_t dataType, int dim, int* size, int* stride, bool nhwc) {
    fixSizeOneDimStride<int>(dim, size, stride, nhwc);
    AT_CUDNN_CHECK(cudnnSetTensorNdDescriptor(mut_desc(), dataType, dim, size, stride));
  }
};

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
                                                cudnnFilterStruct,
                                                &cudnnCreateFilterDescriptor,
                                                &cudnnDestroyFilterDescriptor> {
 public:
  void set(const at::Tensor &t, int64_t pad = 0) {
    set(t, at::MemoryFormat::Contiguous, pad);
  }

  void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);

  void print();

 private:
  void set(cudnnDataType_t dataType, int dim, int* size, cudnnTensorFormat_t filter_format) {
    AT_CUDNN_CHECK(cudnnSetFilterNdDescriptor(mut_desc(), dataType, filter_format, dim, size));
  }
};

std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
    : public Descriptor<
          cudnnConvolutionStruct,
          &cudnnCreateConvolutionDescriptor,
          &cudnnDestroyConvolutionDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* pad, int* stride, int* upscale /* aka dilation */, int groups, bool allow_tf32) {
    cudnnDataType_t mathType = dataType;
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    AT_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale,
                                                   CUDNN_CROSS_CORRELATION, mathType));
    AT_CUDNN_CHECK(cudnnSetConvolutionGroupCount(mut_desc(), groups));
    // See Note [behavior of cudnnFind and cudnnGet]
    AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_DEFAULT_MATH));
    if (dataType == CUDNN_DATA_HALF) {
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_TENSOR_OP_MATH));
    } else if (dataType == CUDNN_DATA_FLOAT && !allow_tf32) {
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      AT_CUDNN_CHECK(cudnnSetConvolutionMathType(mut_desc(), CUDNN_FMA_MATH));
#endif
    }
  }
};
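
// Illustrative usage (a sketch under assumed parameter values, not code from
// this file): describing a 2-d convolution with padding 1, stride 1,
// dilation 1, a single group, fp32 data, and TF32 allowed:
//
//   int pad[2]      = {1, 1};
//   int stride[2]   = {1, 1};
//   int dilation[2] = {1, 1};
//   ConvolutionDescriptor cdesc;
//   cdesc.set(CUDNN_DATA_FLOAT, /*dim=*/2, pad, stride, dilation,
//             /*groups=*/1, /*allow_tf32=*/true);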
struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
    : public Descriptor<
          cudnnSpatialTransformerStruct,
          &cudnnCreateSpatialTransformerDescriptor,
          &cudnnDestroySpatialTransformerDescriptor> {
  void set(cudnnDataType_t dataType, int dim, int* size) {
    AT_CUDNN_CHECK(cudnnSetSpatialTransformerNdDescriptor(mut_desc(), CUDNN_SAMPLER_BILINEAR, dataType, dim, size));
  }
};
struct TORCH_CUDA_CPP_API DropoutDescriptor
    : public Descriptor<
          cudnnDropoutStruct,
          &cudnnCreateDropoutDescriptor,
          &cudnnDestroyDropoutDescriptor> {
  at::Tensor state;

  // Initialize a dropout descriptor's RNG state.
  // WARNING: This function is very expensive, avoid calling this function!
  void initialize_rng(cudnnHandle_t handle, float dropout, long long int seed, const TensorOptions& options) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    size_t state_size;
    AT_CUDNN_CHECK(cudnnDropoutGetStatesSize(handle, &state_size));
    AT_ASSERT(options.device().type() == kCUDA);
    AT_ASSERT(options.dtype() == kByte);
    state = at::empty({static_cast<int64_t>(state_size)}, options);
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
  }

  // Restore a dropout descriptor given a dropout probability and existing RNG state.
  void set(cudnnHandle_t handle, float dropout, at::Tensor state_) {
    TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
    state = state_;
    void *state_ptr = state.data_ptr();
    size_t state_size = state.size(0);
    // NB: The seed doesn't actually matter, so we give a dummy value
    AT_CUDNN_CHECK(cudnnRestoreDropoutDescriptor(mut_desc(), handle, dropout, state_ptr, state_size, 0 /* seed */));
  }

  // Restore a dropout descriptor corresponding to no dropout
  void set_no_dropout(cudnnHandle_t handle) {
    // NB: seed doesn't matter when dropout = 0, because no random number
    // initialization actually takes place when there is no dropout.
    // NB: Empirically, cudnnSetDropoutDescriptor is cheap when
    // dropout == 0
    AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, 0 /* dropout */, nullptr, 0 /* state_size */, 0 /* seed */));
  }
};
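
// Illustrative usage (a sketch, not code from this file): the expensive path
// allocates and seeds the RNG state once; later descriptors can be rebuilt
// cheaply from that saved state:
//
//   DropoutDescriptor dd;
//   dd.initialize_rng(handle, /*dropout=*/0.5f, /*seed=*/42,
//                     at::TensorOptions().device(at::kCUDA).dtype(at::kByte));
//   ...
//   DropoutDescriptor dd2;
//   dd2.set(handle, /*dropout=*/0.5f, dd.state);   // reuses the existing state
//
// When dropout is 0, call set_no_dropout(handle) instead.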
struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
                                              cudnnRNNStruct,
                                              &cudnnCreateRNNDescriptor,
                                              &cudnnDestroyRNNDescriptor> {
  DropoutDescriptor dropout_desc_;
  void set(cudnnHandle_t handle, int hidden_size, int proj_size, int num_layers, DropoutDescriptor&& dropout_desc,
           cudnnRNNInputMode_t input_mode, cudnnDirectionMode_t bidirectional,
           cudnnRNNMode_t mode, cudnnDataType_t datatype, cudnnDataType_t input_type, cudnnRNNAlgo_t algo, bool allow_tf32) {
    dropout_desc_ = std::move(dropout_desc);

    AT_CUDNN_CHECK(cudnnSetRNNDescriptor_v6(
          handle,
          mut_desc(),
          hidden_size,
          num_layers,
          dropout_desc_.desc(),
          input_mode,
          bidirectional,
          mode,
          algo,
          datatype));
    if (proj_size != 0) {
      AT_CUDNN_CHECK(cudnnSetRNNProjectionLayers(
            handle,
            /*rnnDesc=*/mut_desc(),
            /*recProjSize=*/proj_size,
            /*outProjSize=*/0));
    }
    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    if (prop->major >= 7) {
      if (input_type == CUDNN_DATA_HALF) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_TENSOR_OP_MATH);
      }
#if defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
      else if (input_type == CUDNN_DATA_FLOAT && !allow_tf32) {
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_FMA_MATH);
      }
#endif
      else {
        // Technically, as the default it's not necessary to explicitly
        // set this.
        cudnnSetRNNMatrixMathType(mut_desc(), CUDNN_DEFAULT_MATH);
      }
    }
  }
};
struct TORCH_CUDA_CPP_API CTCLossDescriptor
    : public Descriptor<
          cudnnCTCLossStruct,
          &cudnnCreateCTCLossDescriptor,
          &cudnnDestroyCTCLossDescriptor> {
  void set(cudnnDataType_t datatype) {
    AT_CUDNN_CHECK(cudnnSetCTCLossDescriptor(mut_desc(), datatype));
  }
#if CUDNN_VERSION >= 7600
  void setEx(
      cudnnDataType_t datatype,
      cudnnLossNormalizationMode_t normMode,
      cudnnNanPropagation_t gradMode) {
    AT_CUDNN_CHECK(
        cudnnSetCTCLossDescriptorEx(mut_desc(), datatype, normMode, gradMode));
  }
#endif
};
struct TORCH_CUDA_CPP_API ActivationDescriptor
    : public Descriptor<
          cudnnActivationStruct,
          &cudnnCreateActivationDescriptor,
          &cudnnDestroyActivationDescriptor> {
  void set(cudnnActivationMode_t mode) {
    AT_ASSERT(
        mode == CUDNN_ACTIVATION_RELU,
        "TODO: support more cuDNN activation modes");
    AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
        mut_desc(),
        mode,
        cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
        std::numeric_limits<double>::max()));
  }
};
union Constant
{
  float f;
  double d;
  Constant(cudnnDataType_t dataType, double value) {
    if (dataType == CUDNN_DATA_HALF || dataType == CUDNN_DATA_FLOAT) {
      f = static_cast<float>(value);
    } else {
      d = value;
    }
  }
};
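
// Illustrative usage (a sketch, not code from this file): cuDNN takes its
// alpha/beta scaling factors as void pointers whose pointee type depends on
// the data type of the operation (float for half/float data, double for
// double data). Constant picks the matching representation automatically:
//
//   Constant one(dataType, 1);
//   Constant zero(dataType, 0);
//   // ... cudnnConvolutionForward(handle, &one, ..., &zero, ...);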

}} // namespace