#pragma once

#include <cassert>
#include <cstdint>
#include <type_traits>
#include <c10/core/DynamicCast.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>
#include <ATen/core/Array.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/thread_constants.h>
#include <thrust/tuple.h>

// References:
// https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/

namespace at { namespace native { namespace memory {

namespace detail {

// What does the `static_unroll` do?
//
// We want to do something like:
//
//   using args_t = typename traits::ArgsTuple;
//   args_t args;
//   #pragma unroll
//   for (int i = 0; i < traits::arity; i++) {
//     std::get<i>(args) = ....
//   }
//
// but unfortunately the above code does not work because
// the template argument has to be a compile-time constant,
// so `static_unroll` is created to simulate `#pragma unroll`
// using template metaprogramming.

template<template<int i> typename func, int end, int current=0>
struct static_unroll {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current+1>::with_args(args...);
  }
};

template<template<int i> typename func, int end>
struct static_unroll<func, end, end> {
  template<typename... Args>
  static inline C10_HOST_DEVICE void with_args(Args... args) {}
};

// helper structs to be used with static_unroll to load arguments
// one by one

template<int arg_index>
struct vectorized_load_helper {
  template <typename args_t, typename policy_t>
  static __device__ void apply(policy_t &self, args_t *args, int idx) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
    auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & {
      return std::get<arg_index>(args[thread_unroll_idx]);
    };
    self.load_single_arg(args_accessor, ptr);
  }
};

template<int arg_index>
struct unroll_load_helper {
  template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
  static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
    using arg_t = std::tuple_element_t<arg_index, args_t>;
    // `data` holds the data_ptr for tensors [output0, ..., input0, input1, ...], so we
    // need a +num_outputs offset to get the inputs
    std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
  }
};

template <int current>
struct multi_outputs_store_helper {
  template<int ntensors, int num_outputs, typename ...Args>
  C10_HOST_DEVICE static void apply(
      at::detail::Array<char*, ntensors> data,
      at::detail::Array<uint32_t, num_outputs> offsets,
      thrust::tuple<Args...> ret) {
    using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
    T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
    *to = thrust::get<current>(ret);
  }
};

}  // namespace detail

struct LoadWithoutCast {
  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
  }
};

template <int N>
struct LoadWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  LoadWithCast(const TensorIteratorBase& iter) {
    assert(iter.ninputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i + iter.noutputs());
      element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
    }
  }

  template<typename scalar_t>
  __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
  }
};

struct StoreWithoutCast {
  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
  }
};

template <int N>
struct StoreWithCast {
  using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
  using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;

  array_t dtypes;
  size_array_t element_sizes;

  StoreWithCast(const TensorIteratorBase& iter) {
    assert(iter.noutputs() == N);
    #pragma unroll
    for (auto i = 0; i < N; ++i) {
      this->dtypes[i] = iter.dtype(i);
      element_sizes[i] = c10::elementSize(iter.dtype(i));
    }
  }

  template<typename scalar_t>
  __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
    void *ptr = base_ptr + element_sizes[arg] * offset;
    c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
  }
};
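
// Example (illustrative sketch, not part of this header's API): a device helper
// that copies one element through the loader/storer abstractions above.
// `copy_one`, `in_ptr`, `out_ptr` and `offset` are hypothetical names; with
// LoadWithCast/StoreWithCast the same two calls additionally perform the
// dynamic dtype conversion.
//
//   template <typename scalar_t, typename loader_t, typename storer_t>
//   __device__ void copy_one(char *out_ptr, char *in_ptr, uint32_t offset,
//                            loader_t loader, storer_t storer) {
//     scalar_t value = loader.template load<scalar_t>(in_ptr, offset, /*arg=*/0);
//     storer.store(value, out_ptr, offset);
//   }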

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
  using vec_t = aligned_vector<scalar_t, vec_size>;
  auto *from = reinterpret_cast<const vec_t *>(base_ptr);
  return from[offset];
}

template <int vec_size>
__device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
  // See NOTE [Loading boolean values]
  auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
  aligned_vector<bool, vec_size> ret;
  for (int i = 0; i < vec_size; ++i) {
    ret.val[i] = bool(tmp.val[i]);
  }
  return ret;
}

namespace policies {

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors
template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
struct unroll {
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  loader_t loader;
  storer_t storer;

  __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    int thread_idx = threadIdx.x;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      int offset = output_offset_calculator.get(linear_idx)[0];
      storer.store(from[i], data[0], offset);
      thread_idx += num_threads();
    }
  }
};
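
// Example (illustrative sketch of how a caller typically drives a policy like
// `unroll`; `elementwise_kernel_helper` and `f` are assumptions, not part of
// this header):
//
//   template <typename func_t, typename policy_t>
//   __device__ void elementwise_kernel_helper(func_t f, policy_t policy) {
//     using traits = function_traits<func_t>;
//     using return_t = typename traits::result_type;
//     using args_t = typename traits::ArgsTuple;
//
//     return_t results[thread_work_size()];
//     args_t args[thread_work_size()];
//
//     policy.load(args, blockIdx.x);       // gather this block's inputs
//     #pragma unroll
//     for (int i = 0; i < thread_work_size(); i++) {
//       if (policy.check_inbounds(i)) {
//         results[i] = c10::guts::apply(f, args[i]);
//       }
//     }
//     policy.store(results, blockIdx.x);   // scatter this block's outputs
//   }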

// Assumption:
// all tensors are contiguous, that is: stride == sizeof(type) for all tensors

// Note:
// Functions in the vectorized policy do not do boundary checks. They assume
// the whole block has a full tile of work to do, so the remainder must be
// handled by the caller manually.
template <int vec_size, typename data_t>  // vec_size: number of scalars, can be 1, 2, or 4.
struct vectorized {

  static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
  static constexpr int loop_size = thread_work_size() / vec_size;

  data_t data;

  __device__ vectorized(data_t data) : data(data) {}

  __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
    return true;
  }

  template<typename accessor_t, typename scalar_t>
  __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      auto v = load_vector<vec_size>(from, index);
      #pragma unroll
      for (int j = 0; j < vec_size; j++) {
        to(vec_size * i + j) = v.val[j];
      }
    }
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
  }

  template<typename scalar_t>
  __device__ inline void store(scalar_t *from, int idx) {
    using vec_t = aligned_vector<scalar_t, vec_size>;
    scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
    vec_t *to_ = reinterpret_cast<vec_t *>(to);
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < loop_size; i++) {
      int index = thread_idx + i * num_threads();
      vec_t v;
      for (int j = 0; j < vec_size; j++) {
        v.val[j] = from[vec_size * i + j];
      }
      to_[index] = v;
    }
  }
};

template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
struct multi_outputs_unroll {
  // The members and the check_inbounds and load methods of multi_outputs_unroll
  // are copy-pasted from the unroll struct above; we don't use inheritance
  // because of a compiler bug in CUDA 10.2+.
  data_t data;
  int remaining;
  inp_calc_t input_offset_calculator;
  out_calc_t output_offset_calculator;
  LoadWithoutCast loader;
  StoreWithoutCast storer;

  __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
    data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}

  __device__ inline bool check_inbounds(int thread_work_elem) {
    return ((threadIdx.x + thread_work_elem*num_threads()) < remaining);
  }

  template<typename args_t>
  __device__ inline void load(args_t *args, int idx) {
    constexpr int arity = std::tuple_size<args_t>::value;
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offset = input_offset_calculator.get(linear_idx);
      detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
      thread_idx += num_threads();
    }
  }

  template <typename return_t>
  __device__ inline void store(return_t *from, int idx) {
    int thread_idx = threadIdx.x;
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      if (thread_idx >= this->remaining) {
        return;
      }
      int linear_idx = thread_idx + block_work_size() * idx;
      auto offsets = this->output_offset_calculator.get(linear_idx);
      memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
      thread_idx += num_threads();
    }
  }
};

}  // namespace policies
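
// Example (illustrative): for a contiguous float kernel with one output and one
// input, thread_work_size() == 4 and vec_size == 4 give loop_size == 1, i.e.
// each thread moves its 4 elements with a single 16-byte access. The names
// below are hypothetical:
//
//   at::detail::Array<char*, 2> data;    // data[0] = output, data[1] = input
//   // ... fill in the data pointers ...
//   policies::vectorized<4, decltype(data)> policy(data);
//   std::tuple<float> args[thread_work_size()];
//   float results[thread_work_size()];
//   policy.load(args, blockIdx.x);    // one aligned_vector<float, 4> load per thread
//   // ... compute results from args ...
//   policy.store(results, blockIdx.x);
//
// A partial tail block must instead go through the `unroll` policy, which
// bounds-checks every element against `remaining`.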

// This is only used on the host, but we will wrap it into templates
// which are C10_HOST_DEVICE, so we have to make it C10_HOST_DEVICE
// in order to compile.
template<typename scalar_t>
inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
  uint64_t address = reinterpret_cast<uint64_t>(pointer);
  constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
  constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
  if (address % vec4_alignment == 0) {
    return 4;
  } else if (address % vec2_alignment == 0) {
    return 2;
  }
  return 1;
}

template<int i>
struct can_vectorize_up_to_helper {
  template <typename array_t, typename traits>
  static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
    using arg_t = typename traits::template arg<i>::type;
    // `pointers` holds the data_ptr for tensors [output, input0, input1, ...], so we
    // need a +1 offset to get the input
    result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
  }
};

template<typename func_t, typename array_t>
inline int can_vectorize_up_to(array_t pointers) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  constexpr int arity = traits::arity;
  int result = can_vectorize_up_to<return_t>(pointers[0]);
  // We need to get the type of each argument of `func_t`, which can only
  // be done at compile time.
  detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
  return result;
}

}}} // namespace at::native::memory
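
// Example (illustrative sketch of how a kernel launcher typically uses
// can_vectorize_up_to to pick a policy; the launch helpers named here are
// hypothetical):
//
//   at::detail::Array<char*, 2> data;
//   data[0] = out_ptr;
//   data[1] = in_ptr;
//   int vec_size = at::native::memory::can_vectorize_up_to<func_t>(data);
//   switch (vec_size) {
//     case 4: launch_vectorized_kernel<4>(N, f, data); break;
//     case 2: launch_vectorized_kernel<2>(N, f, data); break;
//     default: launch_unrolled_kernel(N, f, data);     break;
//   }
//
// Note that the alignment check only covers the base pointers; vectorization
// additionally requires all tensors to be contiguous, per the policy
// assumptions above.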