qembeddingbag.h 1020 B

12345678910111213141516171819202122232425262728293031323334
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <cstdint>
  4. namespace at {
  5. namespace native {
  // Declaration only; the definition is not part of this header.
  //
  // Signature summary (as declared here):
  //   output                      - Tensor taken by mutable reference.
  //   weight, indices             - Tensors taken by const reference.
  //   offsets_in                  - optional Tensor (may be absent).
  //   (unnamed bool, int64_t)     - names are commented out in this
  //                                 declaration: scale_grad_by_freq, mode.
  //   pruned_weights              - bool flag.
  //   per_sample_weights_         - optional Tensor (may be absent).
  //   compressed_indices_mapping  - optional Tensor (may be absent).
  //   include_last_offset         - bool flag.
  //   returns                     - Tensor&.
  //
  // NOTE(review): judging by the name only, this looks like the out-variant of
  // an embedding-bag lookup over byte (8-bit) row-wise quantized weights that
  // writes into `output` and returns it — confirm against the definition.
  // NOTE(review): whether the two unnamed parameters are actually read by the
  // definition cannot be determined from this header — verify.
  6. Tensor& embedding_bag_byte_rowwise_offsets_out(
  7. Tensor& output,
  8. const Tensor& weight,
  9. const Tensor& indices,
  10. const c10::optional<Tensor>& offsets_in,
  11. const bool /* scale_grad_by_freq */,
  12. const int64_t /* mode */,
  13. bool pruned_weights,
  14. const c10::optional<Tensor>& per_sample_weights_,
  15. const c10::optional<Tensor>& compressed_indices_mapping,
  16. bool include_last_offset);
  // Declaration only; the definition is not part of this header.
  //
  // The parameter list (types, order and names) is identical to that of
  // embedding_bag_byte_rowwise_offsets_out declared in this same header:
  //   output                      - Tensor taken by mutable reference.
  //   weight, indices             - Tensors taken by const reference.
  //   offsets_in                  - optional Tensor (may be absent).
  //   (unnamed bool, int64_t)     - names are commented out in this
  //                                 declaration: scale_grad_by_freq, mode.
  //   pruned_weights              - bool flag.
  //   per_sample_weights_         - optional Tensor (may be absent).
  //   compressed_indices_mapping  - optional Tensor (may be absent).
  //   include_last_offset         - bool flag.
  //   returns                     - Tensor&.
  //
  // NOTE(review): judging by the name only, this looks like the 4-bit
  // row-wise quantized counterpart of the byte variant, writing into `output`
  // and returning it — confirm against the definition.
  17. Tensor& embedding_bag_4bit_rowwise_offsets_out(
  18. Tensor& output,
  19. const Tensor& weight,
  20. const Tensor& indices,
  21. const c10::optional<Tensor>& offsets_in,
  22. const bool /* scale_grad_by_freq */,
  23. const int64_t /* mode */,
  24. bool pruned_weights,
  25. const c10::optional<Tensor>& per_sample_weights_,
  26. const c10::optional<Tensor>& compressed_indices_mapping,
  27. bool include_last_offset);
  // Declaration only; the definition is not part of this header.
  // Takes `output` by mutable reference and `packed_weight` by const
  // reference, and returns a Tensor&.
  // NOTE(review): presumably unpacks byte-quantized embedding weights from
  // `packed_weight` into `output` and returns `output` — inferred from the
  // name only; confirm against the definition.
  28. Tensor& qembeddingbag_byte_unpack_out(Tensor& output, const Tensor& packed_weight);
  29. } // native
  30. } // at