QTensorImpl.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. #pragma once
  2. #include <ATen/quantized/Quantizer.h>
  3. #include <c10/core/TensorImpl.h>
  4. #include <c10/util/Exception.h>
  5. namespace at {
  6. /**
  7. * QTensorImpl is a TensorImpl for Quantized Tensors, it stores Quantizer which
  8. * specifies the quantization scheme and parameters, for more information please
  9. * see ATen/quantized/Quantizer.h
  10. *
  11. * We'll use QTensor in code or documentation to refer to a Tensor with QTensorImpl.
  12. */
  13. struct TORCH_API QTensorImpl : public c10::TensorImpl {
  14. public:
  15. QTensorImpl(
  16. Storage&& storage,
  17. DispatchKeySet key_set,
  18. const caffe2::TypeMeta data_type,
  19. QuantizerPtr quantizer);
  20. // See Note [Enum ImplType]
  21. QTensorImpl(
  22. ImplType type,
  23. Storage&& storage,
  24. DispatchKeySet key_set,
  25. const caffe2::TypeMeta data_type,
  26. QuantizerPtr quantizer);
  27. // TODO: Expose in PyTorch Frontend
  28. QuantizerPtr quantizer() {
  29. return quantizer_;
  30. }
  31. void set_quantizer_(QuantizerPtr quantizer) {
  32. quantizer_ = quantizer;
  33. }
  34. /**
  35. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  36. *
  37. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  38. * see NOTE [ TensorImpl Shallow-Copying ].
  39. */
  40. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  41. const c10::VariableVersion& version_counter,
  42. bool allow_tensor_metadata_change) const override {
  43. auto impl = c10::make_intrusive<QTensorImpl>(
  44. Storage(storage()), key_set(), data_type_, quantizer_);
  45. copy_tensor_metadata(
  46. /*src_impl=*/this,
  47. /*dest_impl=*/impl.get(),
  48. /*version_counter=*/version_counter,
  49. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  50. impl->refresh_numel();
  51. impl->refresh_contiguous();
  52. return impl;
  53. }
  54. /**
  55. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  56. *
  57. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  58. * see NOTE [ TensorImpl Shallow-Copying ].
  59. */
  60. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  61. c10::VariableVersion&& version_counter,
  62. bool allow_tensor_metadata_change) const override {
  63. auto impl = c10::make_intrusive<QTensorImpl>(
  64. Storage(storage()), key_set(), data_type_, quantizer_);
  65. copy_tensor_metadata(
  66. /*src_impl=*/this,
  67. /*dest_impl=*/impl.get(),
  68. /*version_counter=*/std::move(version_counter),
  69. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  70. impl->refresh_numel();
  71. impl->refresh_contiguous();
  72. return impl;
  73. }
  74. /**
  75. * Shallow-copies data from another TensorImpl into this TensorImpl.
  76. *
  77. * For why this function doesn't check this TensorImpl's `allow_tensor_metadata_change_`,
  78. * see NOTE [ TensorImpl Shallow-Copying ].
  79. */
  80. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  81. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  82. auto q_impl = static_cast<const QTensorImpl*>(impl.get());
  83. copy_tensor_metadata(
  84. /*src_impl=*/q_impl,
  85. /*dest_impl=*/this,
  86. /*version_counter=*/version_counter(),
  87. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  88. refresh_numel();
  89. refresh_contiguous();
  90. }
  91. private:
  92. QuantizerPtr quantizer_;
  93. const char* tensorimpl_type_name() const override;
  94. /**
  95. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset)
  96. * from one TensorImpl to another TensorImpl.
  97. *
  98. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE [ TensorImpl Shallow-Copying ].
  99. */
  100. static void copy_tensor_metadata(
  101. const QTensorImpl* src_q_impl,
  102. QTensorImpl* dest_q_impl,
  103. const c10::VariableVersion& version_counter,
  104. bool allow_tensor_metadata_change) {
  105. TensorImpl::copy_tensor_metadata(src_q_impl, dest_q_impl, version_counter, allow_tensor_metadata_change);
  106. // OpaqueTensorImpl-specific fields.
  107. dest_q_impl->quantizer_ = src_q_impl->quantizer_;
  108. }
  109. };
  110. } // namespace at