// CuFFTPlanCache.h
#pragma once

#include <ATen/Config.h>
#include <ATen/core/DimVector.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/CuFFTUtils.h>
#include <ATen/native/utils/ParamsHash.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <cufft.h>
#include <cufftXt.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <list>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
namespace at { namespace native { namespace detail {

// Enum representing the FFT type
enum class CuFFTTransformType : int8_t {
  C2C,  // Complex-to-complex
  R2C,  // Real-to-complex
  C2R,  // Complex-to-real
};
// This struct is used to let us easily compute hashes of the
// parameters.
// It will be the **key** to the plan cache.
struct CuFFTParams
{
  int64_t signal_ndim_; // between 1 and max_rank, i.e., 1 <= signal_ndim <= 3
  // These sizes include an additional batch dimension as well.
  int64_t sizes_[max_rank + 1];
  int64_t input_strides_[max_rank + 1];
  int64_t output_strides_[max_rank + 1];
  CuFFTTransformType fft_type_;
  ScalarType value_type_;

  CuFFTParams() = default;

  CuFFTParams(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef signal_sizes, CuFFTTransformType fft_type, ScalarType value_type) {
    // Padding bits must be zeroed for hashing
    memset(this, 0, sizeof(*this));
    signal_ndim_ = signal_sizes.size() - 1;
    fft_type_ = fft_type;
    value_type_ = value_type;

    TORCH_INTERNAL_ASSERT(in_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(out_strides.size() == signal_sizes.size());
    TORCH_INTERNAL_ASSERT(1 <= signal_ndim_ && signal_ndim_ <= max_rank);

    std::copy(signal_sizes.cbegin(), signal_sizes.cend(), sizes_);
    std::copy(in_strides.cbegin(), in_strides.cend(), input_strides_);
    std::copy(out_strides.cbegin(), out_strides.cend(), output_strides_);
  }
};

static_assert(std::is_trivial<CuFFTParams>::value, "");
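
// Because CuFFTParams is trivial, ParamsHash/ParamsEqual (see
// native/utils/ParamsHash.h) can hash and compare it as raw bytes. That is
// why the constructor zeroes the entire object first: without the memset,
// uninitialized padding bytes would make otherwise-equal keys hash and
// compare differently.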
// Returns true if the transform type has complex input
inline bool cufft_complex_input(CuFFTTransformType type) {
  switch (type) {
    case CuFFTTransformType::C2C:
    case CuFFTTransformType::C2R:
      return true;
    case CuFFTTransformType::R2C:
      return false;
  }
  TORCH_INTERNAL_ASSERT(false);
}

// Returns true if the transform type has complex output
inline bool cufft_complex_output(CuFFTTransformType type) {
  switch (type) {
    case CuFFTTransformType::C2C:
    case CuFFTTransformType::R2C:
      return true;
    case CuFFTTransformType::C2R:
      return false;
  }
  TORCH_INTERNAL_ASSERT(false);
}

// Create transform type enum from bools representing whether the input
// and output are complex
inline CuFFTTransformType GetCuFFTTransformType(bool complex_input, bool complex_output) {
  if (complex_input && complex_output) {
    return CuFFTTransformType::C2C;
  } else if (complex_input && !complex_output) {
    return CuFFTTransformType::C2R;
  } else if (!complex_input && complex_output) {
    return CuFFTTransformType::R2C;
  }
  TORCH_INTERNAL_ASSERT(false, "Real to real FFTs are not supported");
}
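
// For example, GetCuFFTTransformType(/*complex_input=*/false,
// /*complex_output=*/true) yields CuFFTTransformType::R2C, i.e. a forward
// real-to-complex transform.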
class CuFFTHandle {
  ::cufftHandle handle_;
public:

  CuFFTHandle() {
    CUFFT_CHECK(cufftCreate(&handle_));
  }

  ::cufftHandle & get() { return handle_; }
  const ::cufftHandle & get() const { return handle_; }

  ~CuFFTHandle() {
// Not using fftDestroy() for rocFFT to work around double freeing of handles
#if !defined(USE_ROCM)
    cufftDestroy(handle_);
#endif
  }
};

__forceinline__
static bool is_pow_of_two(int64_t x) {
  // Note: also returns true for x == 0; callers only pass positive sizes.
  return (x & (x - 1)) == 0;
}
#if defined(USE_ROCM)
  using cufft_size_type = int;
#else
  using cufft_size_type = long long int;
#endif

using CuFFTDimVector = c10::SmallVector<cufft_size_type, at::kDimVectorStaticSize>;

// Struct representing a tensor in CuFFT's data layout for planning transforms
// See NOTE [ cuFFT Embedded Strides ].
struct CuFFTDataLayout {
  CuFFTDimVector embed;
  cufft_size_type stride, dist;
  bool must_clone, simple;
};
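
// In cuFFT's advanced data layout terms (cufftXtMakePlanMany), `embed`
// corresponds to inembed/onembed, `stride` to istride/ostride, and `dist`
// to idist/odist. `must_clone` signals that the strides cannot be expressed
// as such an embedding, so the tensor must first be copied to contiguous
// memory; `simple` means the layout is plain contiguous data, in which case
// the embedding can be omitted entirely when making the plan.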
// Returns a cufft embedding for a contiguous signal of the given size.
// e.g. if the input is cloned, this will be the resulting data layout
// See NOTE [ cuFFT Embedded Strides ].
inline CuFFTDataLayout cufft_simple_embed(IntArrayRef sizes, bool onesided) {
  CuFFTDataLayout layout;
  layout.simple = true;
  layout.must_clone = false;
  layout.embed.assign(sizes.cbegin() + 1, sizes.cend());
  if (onesided) {
    layout.embed.back() = sizes.back() / 2 + 1;
  }
  layout.stride = 1;
  layout.dist = 1;
  for (const auto& len : layout.embed) {
    layout.dist *= len;
  }
  return layout;
}
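
// For instance, with sizes = {8, 4, 5} (a batch of 8 signals of shape 4x5)
// and onesided == false, this yields embed = {4, 5}, stride = 1, dist = 20;
// with onesided == true the last dimension shrinks to 5 / 2 + 1 = 3, so
// embed = {4, 3} and dist = 12.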
// Convert strides to a CuFFT embedded representation.
// If strides cannot be embedded, returns a simple layout and sets must_clone flag
// See NOTE [ cuFFT Embedded Strides ].
inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bool onesided) {
  const auto signal_ndim = strides.size() - 1;
  CuFFTDataLayout layout;
  auto last_stride = strides[signal_ndim];
  layout.must_clone = (last_stride <= 0);

  const auto last_dim_size = onesided ?
      sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
  const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;

  // Zero strides are not allowed, even if the batch size is one.
  // If that happens, just set a dummy dist.
  if (sizes[0] == 1) {
    layout.dist = signal_numel;
  } else if (strides[0] == 0) {
    layout.must_clone = true;
  } else {
    layout.dist = strides[0];
  }

  // Calculate the embedding shape, or set must_clone if the strides cannot be embedded
  layout.embed.resize(signal_ndim);
  for (auto i = signal_ndim - 1; !layout.must_clone && i > 0; i--) {
    auto stride = strides[i];
    if (sizes[i] == 1) {
      layout.embed[i] = 1;
    } else if (stride > 0 && stride % last_stride == 0) {
      layout.embed[i] = stride / last_stride;
      last_stride = stride;
    } else {
      layout.must_clone = true;
    }
  }

  if (layout.must_clone) {
    // If the input needs to be cloned, assume it will be contiguous
    layout = cufft_simple_embed(sizes, onesided);
    layout.must_clone = true;
  } else {
    layout.embed[0] = sizes[1];
    layout.stride = strides[signal_ndim];
    // Determine if layout represents a simple embedding (contiguous data)
    layout.simple = [&] {
      for (const auto i : c10::irange(1, signal_ndim - 1)) {
        if (layout.embed[i] != sizes[i + 1]) {
          return false;
        }
      }
      return (layout.stride == 1 && layout.dist == signal_numel &&
          layout.embed.back() == last_dim_size);
    }();
  }
  return layout;
}
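
// A worked example: for a contiguous batch of 2 signals of shape 4x5,
// sizes = {2, 4, 5} and strides = {20, 5, 1} (two-sided). Then
// last_stride = 1, signal_numel = 20, dist = strides[0] = 20,
// embed = {4, 5}, stride = 1, and the lambda above classifies the layout
// as simple. A transposed or otherwise non-embeddable input instead sets
// must_clone and falls back to cufft_simple_embed.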
// This class contains all the information needed to execute a cuFFT plan:
//   1. the plan
//   2. whether to clone input before executing the plan
//   3. the workspace size needed
//
// This class will be the **value** in the plan cache.
// It **owns** the raw plan via a RAII CuFFTHandle.
class CuFFTConfig {
public:

  // Only move semantics are needed for this class. Although the plan is
  // already owned through a RAII handle, we also delete the copy constructor
  // and copy assignment so we don't accidentally copy a config and take a
  // perf hit.
  CuFFTConfig(const CuFFTConfig&) = delete;
  CuFFTConfig& operator=(CuFFTConfig const&) = delete;

  explicit CuFFTConfig(const CuFFTParams& params):
      CuFFTConfig(
          IntArrayRef(params.input_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.output_strides_, params.signal_ndim_ + 1),
          IntArrayRef(params.sizes_, params.signal_ndim_ + 1),
          params.fft_type_,
          params.value_type_) {}
  // For complex types, strides are in units of 2 * element_size(dtype).
  // Sizes are for the full signal, including batch size, and always two-sided.
  CuFFTConfig(IntArrayRef in_strides, IntArrayRef out_strides,
      IntArrayRef sizes, CuFFTTransformType fft_type, ScalarType dtype):
      fft_type_(fft_type), value_type_(dtype) {

    // signal sizes (excluding batch dim)
    CuFFTDimVector signal_sizes(sizes.begin() + 1, sizes.end());

    // input batch size
    const int64_t batch = sizes[0];
    const int64_t signal_ndim = sizes.size() - 1;

    // Since cuFFT has limited non-unit stride support and various constraints,
    // we use a flag to keep track throughout this function of whether we need
    // to clone the input (i.e., input = input.clone()).
#if defined(USE_ROCM)
    // clone input to avoid issues with hipfft clobbering the input and failing tests
    clone_input = true;
#else
    clone_input = false;
#endif

    // For half, base strides on the real part of real-to-complex and
    // complex-to-real transforms are not supported. Since our output is always
    // contiguous, we only need to check the real-to-complex (input) case.
    if (dtype == ScalarType::Half) {
      // cuFFT on half requires compute capability of at least SM_53
      auto dev_prop = at::cuda::getCurrentDeviceProperties();
      TORCH_CHECK(dev_prop->major >= 5 && !(dev_prop->major == 5 && dev_prop->minor < 3),
          "cuFFT doesn't support signals of half type with compute "
          "capability less than SM_53, but the device containing input half "
          "tensor only has SM_", dev_prop->major, dev_prop->minor);
      for (const auto i : c10::irange(signal_ndim)) {
        TORCH_CHECK(is_pow_of_two(sizes[i + 1]),
            "cuFFT only supports dimensions whose sizes are powers of two when"
            " computing in half precision, but got a signal size of ",
            sizes.slice(1));
      }
      clone_input |= in_strides.back() != 1;
    }
    CuFFTDataLayout in_layout;
    if (clone_input) {
      in_layout = cufft_simple_embed(sizes, fft_type == CuFFTTransformType::C2R);
    } else {
      in_layout = as_cufft_embed(in_strides, sizes, fft_type == CuFFTTransformType::C2R);
    }
    auto out_layout = as_cufft_embed(out_strides, sizes, fft_type == CuFFTTransformType::R2C);
    TORCH_INTERNAL_ASSERT(!out_layout.must_clone, "Out strides cannot be represented as CuFFT embedding");
    clone_input |= in_layout.must_clone;

    // Check if we can take advantage of simple data layout.
    //
    // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
    const bool simple_layout = in_layout.simple && out_layout.simple;
#if defined(USE_ROCM)
    hipfftType exec_type = [&]{
      if (dtype == kFloat) {
        switch (fft_type) {
          case CuFFTTransformType::C2C: return HIPFFT_C2C;
          case CuFFTTransformType::R2C: return HIPFFT_R2C;
          case CuFFTTransformType::C2R: return HIPFFT_C2R;
        }
      } else if (dtype == kDouble) {
        switch (fft_type) {
          case CuFFTTransformType::C2C: return HIPFFT_Z2Z;
          case CuFFTTransformType::R2C: return HIPFFT_D2Z;
          case CuFFTTransformType::C2R: return HIPFFT_Z2D;
        }
      }
      TORCH_CHECK(false, "hipFFT doesn't support transforms of type: ", dtype);
    }();
#else
    cudaDataType itype, otype, exec_type;
    const auto complex_input = cufft_complex_input(fft_type);
    const auto complex_output = cufft_complex_output(fft_type);
    if (dtype == ScalarType::Float) {
      itype = complex_input ? CUDA_C_32F : CUDA_R_32F;
      otype = complex_output ? CUDA_C_32F : CUDA_R_32F;
      exec_type = CUDA_C_32F;
    } else if (dtype == ScalarType::Double) {
      itype = complex_input ? CUDA_C_64F : CUDA_R_64F;
      otype = complex_output ? CUDA_C_64F : CUDA_R_64F;
      exec_type = CUDA_C_64F;
    } else if (dtype == ScalarType::Half) {
      itype = complex_input ? CUDA_C_16F : CUDA_R_16F;
      otype = complex_output ? CUDA_C_16F : CUDA_R_16F;
      exec_type = CUDA_C_16F;
    } else {
      TORCH_CHECK(false, "cuFFT doesn't support tensor of type: ", dtype);
    }
#endif
    // Disable auto allocation of the workspace so we can allocate it through
    // the caching allocator instead.
    CUFFT_CHECK(cufftSetAutoAllocation(plan(), /* autoAllocate */ 0));

    size_t ws_size_t;

    // make plan
    if (simple_layout) {
      // For a contiguous, unit-stride layout, we tell cuFFT by setting
      // inembed == onembed == NULL. In that case, cuFFT ignores istride,
      // ostride, idist, and odist by assuming istride = ostride = 1.
      //
      // See NOTE [ cuFFT Embedded Strides ] in native/cuda/SpectralOps.cu.
#if defined(USE_ROCM)
      CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1,
          /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1,
          exec_type, batch, &ws_size_t));
#else
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          /* inembed */ nullptr, /* base_istride */ 1, /* idist */ 1, itype,
          /* onembed */ nullptr, /* base_ostride */ 1, /* odist */ 1, otype,
          batch, &ws_size_t, exec_type));
#endif
    } else {
#if defined(USE_ROCM)
      CUFFT_CHECK(hipfftMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          in_layout.embed.data(), in_layout.stride, in_layout.dist,
          out_layout.embed.data(), out_layout.stride, out_layout.dist,
          exec_type, batch, &ws_size_t));
#else
      CUFFT_CHECK(cufftXtMakePlanMany(plan(), signal_ndim, signal_sizes.data(),
          in_layout.embed.data(), in_layout.stride, in_layout.dist, itype,
          out_layout.embed.data(), out_layout.stride, out_layout.dist, otype,
          batch, &ws_size_t, exec_type));
#endif
    }
    ws_size = static_cast<int64_t>(ws_size_t);
  }
  const cufftHandle& plan() const { return plan_ptr.get(); }

  CuFFTTransformType transform_type() const { return fft_type_; }
  ScalarType data_type() const { return value_type_; }
  bool should_clone_input() const { return clone_input; }
  int64_t workspace_size() const { return ws_size; }

private:
  CuFFTHandle plan_ptr;
  bool clone_input;
  int64_t ws_size;
  CuFFTTransformType fft_type_;
  ScalarType value_type_;
};
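
// A minimal usage sketch (hypothetical shapes; stream and error handling
// details elided). Since auto allocation is disabled above, the caller must
// attach an externally allocated work area after planning:
//
//   CuFFTParams params(input.strides(), output.strides(),
//       /* signal_sizes */ {batch, n}, CuFFTTransformType::C2C,
//       ScalarType::Float);
//   CuFFTConfig config(params);
//   auto workspace = at::empty({config.workspace_size()},
//       at::device(at::kCUDA).dtype(at::kByte));
//   CUFFT_CHECK(cufftSetStream(config.plan(), at::cuda::getCurrentCUDAStream()));
//   CUFFT_CHECK(cufftSetWorkArea(config.plan(), workspace.data_ptr()));
//   // ... then execute with cufftXtExec / the hipfftExec* family.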
#if defined(USE_ROCM)
  // The maximum plan number is capped at 1023: CUDA versions < 10 had a bug
  // that failed on the 1024th plan, and the same limit is kept here.
  constexpr int64_t CUFFT_MAX_PLAN_NUM = 1023;
  constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = CUFFT_MAX_PLAN_NUM;
#else
  constexpr int64_t CUFFT_MAX_PLAN_NUM = std::numeric_limits<int64_t>::max();
  // The default max cache size is arbitrary. This number puts a limit on how
  // big of a plan cache we maintain by default. Users can always configure it
  // via cufft_set_plan_cache_max_size.
  constexpr int64_t CUFFT_DEFAULT_CACHE_SIZE = 4096;
#endif
static_assert(0 <= CUFFT_MAX_PLAN_NUM && CUFFT_MAX_PLAN_NUM <= std::numeric_limits<int64_t>::max(),
              "CUFFT_MAX_PLAN_NUM not in size_t range");
static_assert(CUFFT_DEFAULT_CACHE_SIZE >= 0 && CUFFT_DEFAULT_CACHE_SIZE <= CUFFT_MAX_PLAN_NUM,
              "CUFFT_DEFAULT_CACHE_SIZE not in [0, CUFFT_MAX_PLAN_NUM] range");
// This cache assumes that the mapping from key to value never changes.
// This is **NOT** thread-safe. Please use a mutex when using it **AND** the
// value returned from lookup().
// The contract of using this cache is that lookup() should only be used when
// the max_size is positive.
class CuFFTParamsLRUCache {
public:
  using kv_t = typename std::pair<CuFFTParams, CuFFTConfig>;
  using map_t = typename std::unordered_map<std::reference_wrapper<CuFFTParams>,
                                            typename std::list<kv_t>::iterator,
                                            ParamsHash<CuFFTParams>,
                                            ParamsEqual<CuFFTParams>>;
  using map_kkv_iter_t = typename map_t::iterator;

  CuFFTParamsLRUCache() : CuFFTParamsLRUCache(CUFFT_DEFAULT_CACHE_SIZE) {}
  CuFFTParamsLRUCache(int64_t max_size) {
    _set_max_size(max_size);
  }

  CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
      _usage_list(std::move(other._usage_list)),
      _cache_map(std::move(other._cache_map)),
      _max_size(other._max_size) {}

  CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
    _usage_list = std::move(other._usage_list);
    _cache_map = std::move(other._cache_map);
    _max_size = other._max_size;
    return *this;
  }
  // If key is in this cache, return the cached config. Otherwise, emplace the
  // config in this cache and return it.
  // Return const reference because CuFFTConfig shouldn't be tampered with once
  // created.
  const CuFFTConfig &lookup(CuFFTParams params) {
    AT_ASSERT(_max_size > 0);

    map_kkv_iter_t map_it = _cache_map.find(params);
    // Hit, put to list front
    if (map_it != _cache_map.end()) {
      _usage_list.splice(_usage_list.begin(), _usage_list, map_it->second);
      return map_it->second->second;
    }

    // Miss
    // remove if needed
    if (_usage_list.size() >= _max_size) {
      auto last = _usage_list.end();
      last--;
      _cache_map.erase(last->first);
      _usage_list.pop_back();
    }

    // construct new plan at list front, then insert into _cache_map
    _usage_list.emplace_front(std::piecewise_construct,
                              std::forward_as_tuple(params),
                              std::forward_as_tuple(params));
    auto kv_it = _usage_list.begin();
    _cache_map.emplace(std::piecewise_construct,
                       std::forward_as_tuple(kv_it->first),
                       std::forward_as_tuple(kv_it));
    return kv_it->second;
  }
  void clear() {
    _cache_map.clear();
    _usage_list.clear();
  }

  void resize(int64_t new_size) {
    _set_max_size(new_size);
    auto cur_size = _usage_list.size();
    if (cur_size > _max_size) {
      auto delete_it = _usage_list.end();
      for (size_t i = 0; i < cur_size - _max_size; i++) {
        delete_it--;
        _cache_map.erase(delete_it->first);
      }
      _usage_list.erase(delete_it, _usage_list.end());
    }
  }
  size_t size() const { return _cache_map.size(); }

  size_t max_size() const noexcept { return _max_size; }

  std::mutex mutex;

private:
  // Only sets size and does value check. Does not resize the data structures.
  void _set_max_size(int64_t new_size) {
    // We check that 0 <= new_size <= CUFFT_MAX_PLAN_NUM here. Since new_size
    // is signed and _max_size is a size_t, we must do the non-negativity
    // check before casting.
    TORCH_CHECK(new_size >= 0,
                "cuFFT plan cache size must be non-negative, but got ", new_size);
    TORCH_CHECK(new_size <= CUFFT_MAX_PLAN_NUM,
                "cuFFT plan cache size can not be larger than ", CUFFT_MAX_PLAN_NUM, ", but got ", new_size);
    _max_size = static_cast<size_t>(new_size);
  }

  std::list<kv_t> _usage_list;
  map_t _cache_map;
  size_t _max_size;
};
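
// A minimal sketch of the intended locking discipline (hypothetical caller;
// see the thread-safety note above the class):
//
//   CuFFTParamsLRUCache& cache = ...;  // e.g. a per-device cache instance
//   std::lock_guard<std::mutex> guard(cache.mutex);
//   const CuFFTConfig& config = cache.lookup(params);
//   // `config` must only be used while the lock is held, since another
//   // thread could otherwise evict the entry out from under us.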
// Since ATen is separated into CPU build and CUDA build, we need a way to call
// these functions only when CUDA is loaded. We use CUDA hooks for this purpose
// (at cuda/detail/CUDAHooks.cpp), and call the hooked functions from the actual
// native function counterparts (at native/SpectralOps.cpp), i.e.,
// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size,
// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
void cufft_clear_plan_cache_impl(int64_t device_index);

}}} // namespace at::native::detail