MemoryAccess.cuh

#pragma once

#include <cstdint>
#include <type_traits>
#include <c10/core/DynamicCast.h>
#include <c10/util/Exception.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

// References:
// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
namespace at { namespace native { namespace memory {

namespace detail {

// What does `static_unroll` do?
//
// We want to do something like:
//
//   using args_t = typename traits::ArgsTuple;
//   args_t args;
//   #pragma unroll
//   for (int i = 0; i < traits::arity; i++) {
//     std::get<i>(args) = ....
//   }
//
// but unfortunately the above code does not work because
// the template argument has to be a compile-time constant,
// so `static_unroll` is created to simulate `#pragma unroll`
// using template metaprogramming.

template<template<int i> typename func, int end, int current=0>
struct static_unroll {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current+1>::with_args(args...);
  }
};

template<template<int i> typename func, int end>
struct static_unroll<func, end, end> {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args... args) {}
};
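
// As a rough illustration (not part of the original file), `static_unroll`
// instantiates a user-provided functor template for indices 0..end-1 at
// compile time. A hypothetical functor `fill_with_index` could be used like
// this to set tuple element i to the value i:
//
//   template<int i>
//   struct fill_with_index {
//     template <typename args_t>
//     static C10_HOST_DEVICE void apply(args_t &args) {
//       std::get<i>(args) = i;
//     }
//   };
//
//   std::tuple<int, int, int> t;
//   static_unroll<fill_with_index, 3>::with_args(t);  // t becomes (0, 1, 2)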

// helper structs to be used with static_unroll to load arguments
// one by one

template<int arg_index>
struct vectorized_load_helper {
  template <typename args_t, typename policy_t>
  static __device__ void apply(policy_t &self, args_t *args, int idx) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
    self.load_single_arg(args_accessor, ptr);
  }
};

template<int arg_index>
struct unroll_load_helper {
  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data_ptr for tensors [output0, ..., input0, input1, ...], so we
    // need a +num_outputs offset to get the input
    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
  }
};

template <int current>
struct multi_outputs_store_helper {
  template<int ntensors, int num_outputs, typename ...Args>
  C10_HOST_DEVICE static void apply(
      at::detail::Array<char*, ntensors> data,
      at::detail::Array<uint32_t, num_outputs> offsets,
      thrust::tuple<Args...> ret) {
    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
    *to = thrust::get<current>(ret);
  }
};

} // namespace detail

struct LoadWithoutCast {
  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
  }
};

template <int N>
struct LoadWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  LoadWithCast(const TensorIteratorBase& iter) {
    assert(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + iter.noutputs());
      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
    }
  }

  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};

struct StoreWithoutCast {
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }
};

template <int N = 1>
struct StoreWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  StoreWithCast(const TensorIteratorBase& iter) {
    assert(iter.noutputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i);
      element_sizes[i] = c10::elementSize(iter.dtype(i));
    }
  }

  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};

// aligned_vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
  using vec_t = aligned_vector<scalar_t, vec_size>;
  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
  return from[offset];
}

template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
  // See NOTE [Loading boolean values]
  auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
  aligned_vector<bool, vec_size> ret;
  for (int i = 0; i < vec_size; ++i) {
    ret.val[i] = bool(tmp.val[i]);
  }
  return ret;
}
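
// Illustration (not part of the original file): because `aligned_vector` is
// aligned to `sizeof(scalar_t) * vec_size`, the compiler can turn one struct
// copy into a single wide memory instruction (e.g. a 16-byte load for four
// floats), assuming the underlying pointers are suitably aligned. A
// hypothetical kernel `copy4` would move four floats per thread:
//
//   __global__ void copy4(const float *in, float *out) {
//     // one 16-byte load and one 16-byte store per thread
//     auto v = load_vector<4>(in, threadIdx.x);
//     *reinterpret_cast<aligned_vector<float, 4> *>(out + 4 * threadIdx.x) = v;
//   }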

namespace policies {

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
struct unroll {
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  loader_t loader;
  storer_t storer;

  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((threadIdx.x + thread_work_elem * num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    int thread_idx = threadIdx.x;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      int offset = output_offset_calculator.get(linear_idx)[0];
      storer.store(from[i], data[0], offset);
      thread_idx += num_threads();
    }
  }
};
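
// Usage sketch (hypothetical, not part of the original file): an elementwise
// kernel typically constructs one policy per block, loads all arguments,
// applies the functor, and stores the results. The names `data`, `ic`, `oc`,
// `remaining`, `results`, and the block index `N` below are placeholders:
//
//   auto policy = policies::unroll<decltype(data), decltype(ic), decltype(oc),
//                                  LoadWithoutCast, StoreWithoutCast>(
//       data, remaining, ic, oc, LoadWithoutCast(), StoreWithoutCast());
//   args_t args[thread_work_size()];
//   policy.load(args, N);       // gather the inputs for this thread's work items
//   // ... apply the functor to each in-bounds args[i], writing results[i] ...
//   policy.store(results, N);   // scatter the outputs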

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
//
// Note:
// Functions in the vectorized policy do not do boundary checks. They assume
// the whole block has work to do, so the remainder has to be handled by the
// caller manually.
template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {
  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
  static constexpr int loop_size = thread_work_size() / vec_size;

  data_t data;

  __device__ vectorized(data_t data) : data(data) {}

  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
    return true;
  }

  template<typename accessor_t, typename scalar_t>
  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      auto v = load_vector<vec_size>(from, index);
      #pragma unroll
      for (int j = 0; j < vec_size; j++) {
        to(vec_size * i + j) = v.val[j];
      }
    }
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    vec_t *to_ = reinterpret_cast<vec_t *>(to);
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      vec_t v;
      for (int j = 0; j < vec_size; j++) {
        v.val[j] = from[vec_size * i + j];
      }
      to_[index] = v;
    }
  }
};
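
// Illustration (not part of the original file): since `vectorized` skips
// bounds checks, a caller would typically decide per block whether the
// block's work fits entirely inside the tensor, roughly like:
//
//   int remaining = numel - block_work_size() * blockIdx.x;
//   if (remaining < block_work_size()) {
//     // tail block: fall back to the bounds-checked `unroll` policy
//   } else {
//     // full block: safe to use `vectorized<vec_size, data_t>`
//   }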

template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
struct multi_outputs_unroll {
  // The multi_outputs_unroll struct members and its check_inbounds and load
  // methods are copy-pasted from the unroll struct above; we don't use
  // inheritance because of a compiler bug in CUDA 10.2+.
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  LoadWithoutCast loader;
  StoreWithoutCast storer;

  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((threadIdx.x + thread_work_elem * num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template <typename return_t>
  __device__ inline void store(return_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= this->remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offsets = this->output_offset_calculator.get(linear_idx);
      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
      thread_idx += num_threads();
    }
  }
};

} // namespace policies

// This is only used on the host, but we will wrap it into templates that are
// C10_HOST_DEVICE, so it has to be C10_HOST_DEVICE in order to compile.
template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
  constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
  if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}
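
// Example (not part of the original file): for a float pointer, the aligned
// vector types require 8-byte (vec2) and 16-byte (vec4) alignment, so:
//
//   can_vectorize_up_to<float>(ptr) == 4  when ptr is 16-byte aligned
//   can_vectorize_up_to<float>(ptr) == 2  when ptr is 8- but not 16-byte aligned
//   can_vectorize_up_to<float>(ptr) == 1  otherwise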

template<int i>
struct can_vectorize_up_to_helper {
  template <typename array_t, typename traits>
  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
    using arg_t = typename traits::template arg<i>::type;
    // `pointers` holds the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
  }
};

template<typename func_t, typename array_t>
inline int can_vectorize_up_to(array_t pointers) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  constexpr int arity = traits::arity;
  int result = can_vectorize_up_to<return_t>(pointers[0]);
  // We need to get the type for each argument of `func_t`; this can only
  // be done at compile time.
  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
  return result;
}
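
// Usage sketch (hypothetical, not part of the original file): given a binary
// functor and the [output, input0, input1] data pointers, the result is the
// largest vec_size supported by every operand's type and alignment. The
// pointer names below are placeholders:
//
//   auto add = [] __device__ (float a, float b) { return a + b; };
//   at::detail::Array<char*, 3> ptrs;
//   ptrs[0] = out_ptr; ptrs[1] = a_ptr; ptrs[2] = b_ptr;
//   int vec_size = can_vectorize_up_to<decltype(add)>(ptrs);  // returns 1, 2, or 4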

}}} // namespace at::native::memory