- #pragma once
- #include <cassert>      // for the assert() checks in LoadWithCast/StoreWithCast
- #include <cstdint>
- #include <type_traits>
- #include <c10/core/DynamicCast.h>
- #include <c10/util/Exception.h>
- #include <c10/util/Load.h>      // for c10::load
- #include <c10/util/TypeCast.h>
- #include <c10/macros/Macros.h>
- #include <ATen/core/Array.h>
- #include <ATen/detail/FunctionTraits.h>
- #include <ATen/native/TensorIterator.h>      // for TensorIteratorBase
- #include <ATen/cuda/detail/OffsetCalculator.cuh>
- #include <ATen/native/cuda/thread_constants.h>
- #include <thrust/tuple.h>
- // References:
- // https://devblogs.nvidia.com/cuda-pro-tip-increase-performance-with-vectorized-memory-access/
- namespace at { namespace native { namespace memory {
- namespace detail {
- // What does `static_unroll` do?
- //
- // We want to do something like:
- //
- // using args_t = typename traits::ArgsTuple;
- // args_t args;
- // #pragma unroll
- // for (int i = 0; i < traits::arity; i++) {
- // std::get<i>(args) = ....
- // }
- //
- // but unfortunately the above code does not work because the template
- // argument of std::get has to be a compile-time constant. `static_unroll`
- // is created to simulate `#pragma unroll` using template metaprogramming.
- // (A small usage sketch follows the specialization below.)
- template<template<int i> typename func, int end, int current=0>
- struct static_unroll {
- template<typename... Args>
- static inline C10_HOST_DEVICE void with_args(Args&&... args) {
- func<current>::apply(std::forward<Args>(args)...);
- static_unroll<func, end, current+1>::with_args(args...);
- }
- };
- template<template<int i> typename func, int end>
- struct static_unroll<func, end, end> {
- template<typename... Args>
- static inline C10_HOST_DEVICE void with_args(Args... args) {}
- };
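- // Illustrative sketch (hypothetical, not used elsewhere in this file): pairing
- // `static_unroll` with a small indexed functor so that the loop index becomes
- // a template parameter. The names `example_fill_helper` / `example_fill` exist
- // only to show the pattern.
- template<int i>
- struct example_fill_helper {
- template <typename array_t>
- static C10_HOST_DEVICE void apply(array_t &values, int base) {
- values[i] = base + i;  // `i` is a compile-time constant here, unlike a runtime loop index
- }
- };
- template <int N, typename array_t>
- C10_HOST_DEVICE inline void example_fill(array_t &values, int base) {
- // expands into example_fill_helper<0>::apply(values, base), <1>, ..., <N-1>
- static_unroll<example_fill_helper, N>::with_args(values, base);
- }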
- // helper structs to be used with static_unroll to load arguments
- // one by one
- template<int arg_index>
- struct vectorized_load_helper {
- template <typename args_t, typename policy_t>
- static __device__ void apply(policy_t &self, args_t *args, int idx) {
- using arg_t = std::tuple_element_t<arg_index, args_t>;
- // `data` holds the data_ptrs for tensors [output, input0, input1, ...], so we
- // need a +1 offset to get the input
- auto ptr = reinterpret_cast<arg_t *>(self.data[arg_index + 1]) + block_work_size() * idx;
- auto args_accessor = [&args] __device__ (int thread_unroll_idx) -> arg_t & { return std::get<arg_index>(args[thread_unroll_idx]); };
- self.load_single_arg(args_accessor, ptr);
- }
- };
- template<int arg_index>
- struct unroll_load_helper {
- template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
- static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
- using arg_t = std::tuple_element_t<arg_index, args_t>;
- // `data` holds the data_ptrs for tensors [output..., input0, input1, ...], so we
- // need a +num_outputs offset to get the input
- std::get<arg_index>(args[j]) = loader.template load<arg_t>(self.data[arg_index + num_outputs], offset[arg_index], arg_index);
- }
- };
- template <int current>
- struct multi_outputs_store_helper {
- template<int ntensors, int num_outputs, typename ...Args>
- C10_HOST_DEVICE static void apply(
- at::detail::Array<char*, ntensors> data,
- at::detail::Array<uint32_t, num_outputs> offsets,
- thrust::tuple<Args...> ret) {
- using T = typename thrust::tuple_element<current, thrust::tuple<Args...>>::type;
- T *to = reinterpret_cast<T *>(data[current]) + offsets[current];
- *to = thrust::get<current>(ret);
- }
- };
- } // namespace detail
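- // The loaders and storers below share a small duck-typed protocol that the
- // policies in this file rely on:
- //   loader.template load<scalar_t>(base_ptr, offset, arg)  -> scalar_t
- //   storer.store(value, base_ptr, offset, arg)
- // The *WithoutCast variants assume memory already holds scalar_t; the
- // *WithCast variants look up the dtype recorded for operand `arg` and convert
- // through c10::fetch_and_cast / c10::cast_and_store.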
- struct LoadWithoutCast {
- template<typename scalar_t>
- __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
- return c10::load(reinterpret_cast<scalar_t *>(base_ptr) + offset);
- }
- };
- template <int N>
- struct LoadWithCast {
- using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
- using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
- array_t dtypes;
- size_array_t element_sizes;
- LoadWithCast(const TensorIteratorBase& iter) {
- assert(iter.ninputs() == N);
- #pragma unroll
- for (auto i = 0; i < N; ++i) {
- this->dtypes[i] = iter.dtype(i + iter.noutputs());
- element_sizes[i] = c10::elementSize(iter.dtype(i + iter.noutputs()));
- }
- }
- template<typename scalar_t>
- __device__ scalar_t load(char *base_ptr, uint32_t offset, int arg) {
- void *ptr = base_ptr + element_sizes[arg] * offset;
- return c10::fetch_and_cast<scalar_t>(dtypes[arg], ptr);
- }
- };
- struct StoreWithoutCast {
- template<typename scalar_t>
- __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
- *(reinterpret_cast<scalar_t *>(base_ptr) + offset) = value;
- }
- };
- template <int N = 1>
- struct StoreWithCast {
- using array_t = at::detail::Array<at::ScalarType, std::max<int>(N, 1)>;
- using size_array_t = at::detail::Array<uint32_t, std::max<int>(N, 1)>;
- array_t dtypes;
- size_array_t element_sizes;
- StoreWithCast(const TensorIteratorBase& iter) {
- assert(iter.noutputs() == N);
- #pragma unroll
- for (auto i = 0; i < N; ++i) {
- this->dtypes[i] = iter.dtype(i);
- element_sizes[i] = c10::elementSize(iter.dtype(i));
- }
- }
- template<typename scalar_t>
- __device__ void store(scalar_t value, char *base_ptr, uint32_t offset, int arg = 0) {
- void *ptr = base_ptr + element_sizes[arg] * offset;
- c10::cast_and_store<scalar_t>(dtypes[arg], ptr, value);
- }
- };
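- // Minimal sketch of that protocol in isolation; `copy_one_element` is
- // hypothetical and not used elsewhere in this file.
- template <typename scalar_t>
- __device__ inline void copy_one_element(char *dst, char *src, uint32_t offset) {
- LoadWithoutCast loader;
- StoreWithoutCast storer;
- scalar_t value = loader.load<scalar_t>(src, offset, /*arg=*/0);
- storer.store(value, dst, offset, /*arg=*/0);
- }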
- // aligned vector generates vectorized load/store on CUDA
- template<typename scalar_t, int vec_size>
- struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
- scalar_t val[vec_size];
- };
- template <int vec_size, typename scalar_t>
- __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
- using vec_t = aligned_vector<scalar_t, vec_size>;
- auto *from = reinterpret_cast<const vec_t *>(base_ptr);
- return from[offset];
- }
- template <int vec_size>
- __device__ aligned_vector<bool, vec_size> load_vector(const bool *base_ptr, uint32_t offset) {
- // See NOTE [Loading boolean values]
- auto tmp = load_vector<vec_size>(reinterpret_cast<const uint8_t*>(base_ptr), offset);
- aligned_vector<bool, vec_size> ret;
- for (int i = 0; i < vec_size; ++i) {
- ret.val[i] = bool(tmp.val[i]);
- }
- return ret;
- }
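- // Minimal sketch of a vectorized copy built on `aligned_vector`; `copy_vec4` is
- // hypothetical and assumes both pointers are aligned to sizeof(scalar_t) * 4
- // (i.e. can_vectorize_up_to below would return 4 for them).
- template <typename scalar_t>
- __device__ inline void copy_vec4(scalar_t *dst, const scalar_t *src, uint32_t vec_index) {
- using vec_t = aligned_vector<scalar_t, 4>;
- vec_t v = load_vector<4>(src, vec_index);  // one wide load instead of 4 scalar loads
- reinterpret_cast<vec_t *>(dst)[vec_index] = v;  // one wide store instead of 4 scalar stores
- }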
- namespace policies {
- // Assumption:
- // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
- template<typename data_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t, int num_outputs = 1>
- struct unroll {
- data_t data;
- int remaining;
- inp_calc_t input_offset_calculator;
- out_calc_t output_offset_calculator;
- loader_t loader;
- storer_t storer;
- __device__ unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s):
- data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc), loader(l), storer(s) {}
- __device__ inline bool check_inbounds(int thread_work_elem) {
- return ((threadIdx.x + thread_work_elem*num_threads()) < remaining);
- }
- template<typename args_t>
- __device__ inline void load(args_t *args, int idx) {
- constexpr int arity = std::tuple_size<args_t>::value;
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < thread_work_size(); i++) {
- if (thread_idx >= remaining) {
- return;
- }
- int linear_idx = thread_idx + block_work_size() * idx;
- auto offset = input_offset_calculator.get(linear_idx);
- detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
- thread_idx += num_threads();
- }
- }
- template<typename scalar_t>
- __device__ inline void store(scalar_t *from, int idx) {
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < thread_work_size(); i++) {
- if (thread_idx >= remaining) {
- return;
- }
- int linear_idx = thread_idx + block_work_size() * idx;
- int offset = output_offset_calculator.get(linear_idx)[0];
- storer.store(from[i], data[0], offset);
- thread_idx += num_threads();
- }
- }
- };
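- // How a caller typically drives one of these policies (hypothetical sketch,
- // not part of this header; `f`, `policy`, and `idx` are assumed to come from
- // the surrounding kernel):
- //
- //   using traits = function_traits<func_t>;
- //   typename traits::result_type results[thread_work_size()];
- //   typename traits::ArgsTuple args[thread_work_size()];
- //   policy.load(args, idx);                  // gather this block's inputs
- //   #pragma unroll
- //   for (int i = 0; i < thread_work_size(); i++) {
- //     if (policy.check_inbounds(i)) {
- //       results[i] = apply(f, args[i]);      // unpack the tuple and call f
- //     }
- //   }
- //   policy.store(results, idx);              // scatter this block's outputs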
- // Assumption:
- // all tensors are contiguous, that is: stride == sizeof(type) for all tensors
- // Note:
- // Functions in the vectorized policy do not do boundary checks. They assume the
- // whole block has work to do, so any remainder must be handled by the caller
- // manually (see the sketch after this struct).
- template <int vec_size, typename data_t> // vec_size: number of scalars, can be 1, 2, or 4.
- struct vectorized {
- static_assert(thread_work_size() % vec_size == 0, "The workload per thread must be a multiple of vec_size");
- static constexpr int loop_size = thread_work_size() / vec_size;
- data_t data;
- __device__ vectorized(data_t data) : data(data) {}
- __device__ inline constexpr bool check_inbounds(int thread_work_elem) {
- return true;
- }
- template<typename accessor_t, typename scalar_t>
- __device__ inline void load_single_arg(accessor_t to, scalar_t *from) {
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < loop_size; i++) {
- int index = thread_idx + i * num_threads();
- auto v = load_vector<vec_size>(from, index);
- #pragma unroll
- for (int j = 0; j < vec_size; j++) {
- to(vec_size * i + j) = v.val[j];
- }
- }
- }
- template<typename args_t>
- __device__ inline void load(args_t *args, int idx) {
- constexpr int arity = std::tuple_size<args_t>::value;
- detail::static_unroll<detail::vectorized_load_helper, arity>::with_args(*this, args, idx);
- }
- template<typename scalar_t>
- __device__ inline void store(scalar_t *from, int idx) {
- using vec_t = aligned_vector<scalar_t, vec_size>;
- scalar_t *to = reinterpret_cast<scalar_t *>(data[0]) + block_work_size() * idx;
- vec_t *to_ = reinterpret_cast<vec_t *>(to);
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < loop_size; i++) {
- int index = thread_idx + i * num_threads();
- vec_t v;
- for (int j = 0; j < vec_size; j++) {
- v.val[j] = from[vec_size * i + j];
- }
- to_[index] = v;
- }
- }
- };
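- // Because `vectorized` never checks bounds, a caller typically dispatches on
- // whether the block is full (hypothetical sketch; `N` is the total number of
- // elements):
- //
- //   int remaining = N - block_work_size() * blockIdx.x;
- //   if (remaining < block_work_size()) {
- //     // tail block: fall back to the bounds-checked `unroll` policy
- //   } else {
- //     // full block: safe to use `vectorized<vec_size, data_t>`
- //   }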
- template <typename data_t, typename inp_calc_t, typename out_calc_t, int num_outputs>
- struct multi_outputs_unroll {
- // multi_outputs_unroll's members and its check_inbounds and load methods are copy-pasted from the unroll struct.
- // We don't use inheritance because of a compiler bug in CUDA 10.2+.
- data_t data;
- int remaining;
- inp_calc_t input_offset_calculator;
- out_calc_t output_offset_calculator;
- LoadWithoutCast loader;
- StoreWithoutCast storer;
- __device__ multi_outputs_unroll(data_t data, int remaining, inp_calc_t ic, out_calc_t oc):
- data(data), remaining(remaining), input_offset_calculator(ic), output_offset_calculator(oc) {}
- __device__ inline bool check_inbounds(int thread_work_elem) {
- return ((threadIdx.x + thread_work_elem*num_threads()) < remaining);
- }
- template<typename args_t>
- __device__ inline void load(args_t *args, int idx) {
- constexpr int arity = std::tuple_size<args_t>::value;
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < thread_work_size(); i++) {
- if (thread_idx >= remaining) {
- return;
- }
- int linear_idx = thread_idx + block_work_size() * idx;
- auto offset = input_offset_calculator.get(linear_idx);
- detail::static_unroll<detail::unroll_load_helper, arity>::with_args(*this, args, offset, loader, i, num_outputs);
- thread_idx += num_threads();
- }
- }
- template <typename return_t>
- __device__ inline void store(return_t *from, int idx) {
- int thread_idx = threadIdx.x;
- #pragma unroll
- for (int i = 0; i < thread_work_size(); i++) {
- if (thread_idx >= this->remaining) {
- return;
- }
- int linear_idx = thread_idx + block_work_size() * idx;
- auto offsets = this->output_offset_calculator.get(linear_idx);
- memory::detail::static_unroll<detail::multi_outputs_store_helper, num_outputs>::with_args(this->data, offsets, from[i]);
- thread_idx += num_threads();
- }
- }
- };
- } // namespace policies
- // This is only used on the host, but we will wrap it into templates
- // that are C10_HOST_DEVICE, so it has to be C10_HOST_DEVICE as well
- // in order to compile
- template<typename scalar_t>
- inline C10_HOST_DEVICE int can_vectorize_up_to(char *pointer) {
- uint64_t address = reinterpret_cast<uint64_t>(pointer);
- constexpr int vec2_alignment = std::alignment_of<aligned_vector<scalar_t, 2>>::value;
- constexpr int vec4_alignment = std::alignment_of<aligned_vector<scalar_t, 4>>::value;
- if (address % vec4_alignment == 0) {
- return 4;
- } else if (address % vec2_alignment == 0) {
- return 2;
- }
- return 1;
- }
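- // Worked example: for float, aligned_vector<float, 4> is 16-byte aligned and
- // aligned_vector<float, 2> is 8-byte aligned, so an address that is a multiple
- // of 8 but not of 16 (e.g. 0x...08) returns 2, and an address that is only a
- // multiple of 4 (e.g. 0x...04) returns 1.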
- template<int i>
- struct can_vectorize_up_to_helper {
- template <typename array_t, typename traits>
- static C10_HOST_DEVICE void apply(int &result, array_t pointers, traits _) {
- using arg_t = typename traits::template arg<i>::type;
- // `pointers` holds the data_ptrs for tensors [output, input0, input1, ...], so we
- // need a +1 offset to get the input
- result = std::min<int>(result, can_vectorize_up_to<arg_t>(pointers[i + 1]));
- }
- };
- template<typename func_t, typename array_t>
- inline int can_vectorize_up_to(array_t pointers) {
- using traits = function_traits<func_t>;
- using return_t = typename traits::result_type;
- constexpr int arity = traits::arity;
- int result = can_vectorize_up_to<return_t>(pointers[0]);
- // We need to get the type of each argument of `func_t`; this can only
- // be done at compile time, hence the static_unroll below.
- detail::static_unroll<can_vectorize_up_to_helper, arity>::with_args(result, pointers, traits());
- return result;
- }
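- // Hypothetical usage: for a binary functor with pointers laid out as
- // [output, input0, input1], pick the widest vector width that is safe for
- // every operand (the names below are illustrative only):
- //
- //   auto add_fn = [] __device__ (float a, float b) -> float { return a + b; };
- //   at::detail::Array<char*, 3> ptrs;
- //   ptrs[0] = out_ptr; ptrs[1] = a_ptr; ptrs[2] = b_ptr;
- //   int vec_size = can_vectorize_up_to<decltype(add_fn)>(ptrs);  // 1, 2, or 4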
- }}} // namespace at::native::memory