1234567891011121314151617181920212223242526272829 |
- #pragma once
- #include <ATen/core/Tensor.h>
- #include <ATen/core/ivalue.h>
- struct EmbeddingPackedParamsBase : public torch::jit::CustomClassHolder {
- virtual at::Tensor embeddingbag_byte(
- const at::Tensor& indices,
- const c10::optional<at::Tensor>& offsets,
- bool pruned_weights,
- const c10::optional<at::Tensor>& per_sample_weights_,
- const c10::optional<at::Tensor>& compressed_indices_mapping,
- bool include_last_offset,
- bool is_embedding_op) = 0;
- virtual at::Tensor embeddingbag_4bit(
- const at::Tensor& indices,
- const c10::optional<at::Tensor>& offsets,
- bool pruned_weights,
- const c10::optional<at::Tensor>& per_sample_weights_,
- const c10::optional<at::Tensor>& compressed_indices_mapping,
- bool include_last_offset,
- bool is_embedding_op) = 0;
- virtual at::Tensor unpack() = 0;
- virtual int64_t bit_rate() const = 0;
- virtual int64_t version() const = 0;
- };
|