// EmbeddingBag.h
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <cstdint>

#ifdef USE_FBGEMM
#include <fbgemm/FbgemmEmbedding.h>
#endif

namespace at {
namespace native {

// Helpers used by the EmbeddingBag implementation to validate inputs and to
// allocate/populate the auxiliary outputs (bag_size, max_indices, offset2bag).
void check_arguments(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const c10::optional<Tensor>& per_sample_weights,
    bool include_last_offset);

void make_bag_size_out(
    Tensor& bag_size_out,
    const Tensor& offsets,
    const Tensor& indices,
    const int64_t mode,
    const bool include_last_offset,
    const bool requires_grad);

void make_max_indices_out(
    Tensor& max_indices_out,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const Tensor& bag_size,
    const int64_t mode,
    bool include_last_offset);

void make_offset2bag_out(
    Tensor& offset2bag,
    Tensor& output,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode,
    const c10::optional<Tensor>& per_sample_weights,
    const int64_t padding_idx = -1);

#ifdef USE_FBGEMM

// Stores one generated FBGEMM embedding kernel together with the block size
// (embedding dimension) it was generated for.
template<bool has_weight, typename TIndex, typename TData>
struct _CallbackAndBlockSize {
  using TCallback = typename fbgemm::EmbeddingSpMDMKernelSignature<TData, TIndex, TIndex, TData>::Type;

  int64_t blockSize = -1;
  TCallback callback = nullptr;

  static TCallback generateCallback(int64_t block_size) {
    return fbgemm::GenerateEmbeddingSpMDM<TData, TIndex, TIndex, TData>(
        block_size,
        has_weight,
        /* normalize_by_lengths */ false,
        /* prefetch */ 16,
        /* is_weight_positional */ false,
        /* use_offsets */ true);
  }

  _CallbackAndBlockSize() = default;

  explicit _CallbackAndBlockSize(c10::optional<int64_t> maybe_block_size)
    : blockSize(maybe_block_size.value_or(-1))
    , callback(maybe_block_size.has_value() ? generateCallback(maybe_block_size.value()) : nullptr)
  {}
};
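// For reference, TCallback is a std::function performing one fused
// gather-and-reduce pass. The parameter list below is a sketch of the fbgemm
// signature this header compiles against (see fbgemm/FbgemmEmbedding.h for
// the authoritative one):
//
//   bool callback(
//       int64_t output_size,    // number of bags produced
//       int64_t index_size,     // total number of indices
//       int64_t data_size,      // number of rows in the weight table
//       const TData* input,     // weight table, block_size columns per row
//       const TIndex* indices,
//       const TIndex* offsets,  // bag boundaries, since use_offsets == true
//       const float* weights,   // per-sample weights, null when has_weight == false
//       TData* out);            // output_size x block_size
//
// The boolean return reports whether all indices were in bounds.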
// Kernel cache: inherits one _CallbackAndBlockSize mixin per kernel variant.
template<typename... StorageMixins>
struct _EmbeddingBagKernelCacheImpl : private StorageMixins... {

  _EmbeddingBagKernelCacheImpl() = default;

  // Use each of the mixins to store the corresponding kernel and block size.
  explicit _EmbeddingBagKernelCacheImpl(c10::optional<int64_t> maybe_block_size)
    : StorageMixins(maybe_block_size)...
  {}

  // This method is thread safe (call sites may call from different threads).
  template<bool has_weight, typename TIndex, typename TData>
  typename _CallbackAndBlockSize<has_weight, TIndex, TData>::TCallback
  getCallback(int64_t block_size) const {
    // If the cache doesn't store the kernel for the incoming block size
    // (i.e. it differs from the one stored in the corresponding mixin),
    // regenerate the kernel without writing it into the cache, avoiding locks.
    if (block_size != _CallbackAndBlockSize<has_weight, TIndex, TData>::blockSize) {
      return _CallbackAndBlockSize<has_weight, TIndex, TData>::generateCallback(block_size);
    }
    // Otherwise retrieve the cached kernel from the corresponding mixin.
    return _CallbackAndBlockSize<has_weight, TIndex, TData>::callback;
  }
};
// Instantiate the cache with the list of storage mixins, one for each of the
// 8 _EmbeddingBagKernelCache* usages in the EmbeddingBag.cpp impl file
// (has_weight x index type x data type, where unsigned short carries
// half-precision bits).
using _EmbeddingBagKernelCache = _EmbeddingBagKernelCacheImpl<
    _CallbackAndBlockSize<true, int32_t, float>,
    _CallbackAndBlockSize<false, int32_t, float>,
    _CallbackAndBlockSize<true, int64_t, float>,
    _CallbackAndBlockSize<false, int64_t, float>,
    _CallbackAndBlockSize<true, int32_t, unsigned short>,
    _CallbackAndBlockSize<false, int32_t, unsigned short>,
    _CallbackAndBlockSize<true, int64_t, unsigned short>,
    _CallbackAndBlockSize<false, int64_t, unsigned short>>;
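// A minimal usage sketch (illustrative, not part of this header): the cache
// is typically constructed once, optionally pre-generating kernels for an
// expected block size, then queried on every call. Variable names here are
// hypothetical.
//
//   _EmbeddingBagKernelCache cache(c10::make_optional<int64_t>(128));
//   // Block size matches: the kernel stored by the matching mixin is returned.
//   auto hit = cache.getCallback</*has_weight=*/false, int64_t, float>(128);
//   // Block size differs: a kernel is regenerated on the fly and not cached,
//   // which is what keeps getCallback lock-free.
//   auto miss = cache.getCallback</*has_weight=*/false, int64_t, float>(256);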

#else

// Fallback when FBGEMM is unavailable: a stateless stand-in so call sites can
// construct and pass the cache unconditionally.
struct _EmbeddingBagKernelCache {
  explicit _EmbeddingBagKernelCache(c10::optional<int64_t> /* maybe_block_size */) {}
};

#endif
void _embedding_bag_cpu_impl_out(
    Tensor& output,
    Tensor& offset2bag,
    Tensor& bag_size,
    Tensor* max_indices,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offsets,
    const int64_t mode = 0,
    const c10::optional<Tensor>& per_sample_weights = c10::nullopt,
    bool include_last_offset = false,
    int64_t padding_idx = -1,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

void _embedding_bag_cpu_out(
    at::Tensor& output,
    at::Tensor& offset2bag,
    at::Tensor& bag_size,
    at::Tensor* p_max_indices,
    const at::Tensor& weight,
    const at::Tensor& indices,
    const at::Tensor& offsets,
    const bool scale_grad_by_freq,
    const int64_t mode,
    const bool sparse,
    const c10::optional<at::Tensor>& per_sample_weights,
    const bool include_last_offset,
    const c10::optional<int64_t>& padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache = nullptr);

} // namespace native
} // namespace at
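
// A hedged usage sketch (not part of this header): summing three bags from a
// 10 x 4 float table via the _out entry point. Being an _out variant, the
// destination tensors are preallocated by the caller; the exact allocation
// contract lives in EmbeddingBag.cpp, so the shapes below are illustrative
// assumptions and every name except the declared function is hypothetical.
//
//   at::Tensor weight = at::randn({10, 4});
//   at::Tensor indices = at::tensor({1, 2, 4, 5, 4, 3}, at::kLong);
//   at::Tensor offsets = at::tensor({0, 2, 4}, at::kLong);   // 3 bags
//   at::Tensor output = at::empty({3, 4}, weight.options()); // one row per bag
//   at::Tensor offset2bag = at::empty({indices.numel()}, indices.options());
//   at::Tensor bag_size = at::empty({3}, indices.options());
//   at::native::_embedding_bag_cpu_impl_out(
//       output, offset2bag, bag_size, /*max_indices=*/nullptr,
//       weight, indices, offsets, /*mode=*/0);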