123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- #include <ATen/core/Tensor.h>
- #include <ATen/Config.h>
- #include <cstdint>
- #ifdef USE_FBGEMM
- #include <fbgemm/FbgemmEmbedding.h>
- #endif
- namespace at {
- namespace native {
- void check_arguments(
- const Tensor& weight,
- const Tensor& indices,
- const Tensor& offsets,
- const int64_t mode,
- const c10::optional<Tensor>& per_sample_weights,
- bool include_last_offset);
- void make_bag_size_out(
- Tensor& bag_size_out,
- const Tensor& offsets,
- const Tensor& indices,
- const int64_t mode,
- const bool include_last_offset,
- const bool requires_grad);
- void make_max_indices_out(
- Tensor& max_indices_out,
- const Tensor& weight,
- const Tensor& indices,
- const Tensor& offsets,
- const Tensor& bag_size,
- const int64_t mode,
- bool include_last_offset);
- void make_offset2bag_out(
- Tensor& offset2bag,
- Tensor& output,
- const Tensor& weight,
- const Tensor& indices,
- const Tensor& offsets,
- const int64_t mode,
- const c10::optional<Tensor>& per_sample_weights,
- const int64_t padding_idx = -1);
- #ifdef USE_FBGEMM
- template<bool has_weight, typename TIndex, typename TData>
- struct _CallbackAndBlockSize {
- using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;
- int64_t blockSize = -1;
- TCallback callback = nullptr;
- static TCallback generateCallback(int64_t block_size) {
- return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
- block_size,
- has_weight,
- /* normalize_by_lengths */false,
- /* prefetch */16,
- /* is_weight_positional */false,
- /* use_offsets */true);
- }
- _CallbackAndBlockSize() = default;
- explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
- : blockSize(maybe_block_size.value_or(-1))
- , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
- {}
- };
- template<typename... StorageMixins>
- struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {
- _EmbeddingBagKernelCacheImpl() = default;
- // use each of the mixins to store corresponding kernel and block size
- explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
- : StorageMixins(maybe_block_size)...
- {}
- // this method is thread safe (call sites may call from different threads)
- template<bool has_weight, typename TIndex, typename TData>
- typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
- getCallback(int64_t block_size) const {
- // if the cache doesn't store the kernel for the incoming block size
- // (so it is different from the one stored in corresponding mixin)
- // regenerate the kernel (not writing it into the cache so we avoid locks)
- if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
- return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
- }
- // else retrieve the cached kernel from the corresponding mixin
- return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
- }
- };
- // instantiate the cache with the list of storage mixins
- // for each of the 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
- using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
- _CallbackAndBlockSize<true, int32_t, float>,
- _CallbackAndBlockSize<false, int32_t, float>,
- _CallbackAndBlockSize<true, int64_t, float>,
- _CallbackAndBlockSize<false, int64_t, float>,
- _CallbackAndBlockSize<true, int32_t, unsigned short>,
- _CallbackAndBlockSize<false, int32_t, unsigned short>,
- _CallbackAndBlockSize<true, int64_t, unsigned short>,
- _CallbackAndBlockSize<false, int64_t, unsigned short>>;
- #else
- struct _EmbeddingBagKernelCache {
- explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
- };
- #endif
- void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag,
- Tensor& bag_size, Tensor* max_indices,
- const Tensor &weight, const Tensor &indices,
- const Tensor &offsets, const int64_t mode = 0,
- const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
- bool include_last_offset = false,
- int64_t padding_idx = -1,
- _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
- void _embedding_bag_cpu_out(
- at::Tensor& output,
- at::Tensor& offset2bag,
- at::Tensor& bag_size,
- at::Tensor* p_max_indices,
- const at::Tensor& weight,
- const at::Tensor& indices,
- const at::Tensor& offsets,
- const bool scale_grad_by_freq,
- const int64_t mode,
- const bool sparse,
- const c10::optional<at::Tensor>& per_sample_weights,
- const bool include_last_offset,
- const c10::optional<int64_t>& padding_idx,
- _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);
- } // native
- } // at
|