// EmbeddingPackedParams.h

#pragma once

#include <cstdint>

#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>

  4. struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
  5. virtual at::Tensor embeddingbag_byte(
  6. const at::Tensor& indices,
  7. const c10::optional<at::Tensor>& offsets,
  8. bool pruned_weights,
  9. const c10::optional<at::Tensor>& per_sample_weights_,
  10. const c10::optional<at::Tensor>& compressed_indices_mapping,
  11. bool include_last_offset,
  12. bool is_embedding_op) = 0;
  13. virtual at::Tensor embeddingbag_4bit(
  14. const at::Tensor& indices,
  15. const c10::optional<at::Tensor>& offsets,
  16. bool pruned_weights,
  17. const c10::optional<at::Tensor>& per_sample_weights_,
  18. const c10::optional<at::Tensor>& compressed_indices_mapping,
  19. bool include_last_offset,
  20. bool is_embedding_op) = 0;
  21. virtual at::Tensor unpack() = 0;
  22. virtual int64_t bit_rate() const = 0;
  23. virtual int64_t version() const = 0;
  24. };