TensorMeta.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #pragma once
  2. #include <ATen/DimVector.h>
  3. #include <ATen/core/Dimname.h>
  4. #include <c10/core/TensorOptions.h>
  5. #include <c10/util/strides.h>
  6. C10_CLANG_DIAGNOSTIC_PUSH()
  7. #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
  8. C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
  9. #endif
  10. namespace at {
  11. class Tensor;
  12. namespace impl {
// Use this to define the prototype for a meta function. There are two
// versions; one that takes one argument (just the operator name), or FUNC2
// variant that takes two arguments (operator name and overload name).
//
// The macro pastes tokens to name the `meta` method of the codegen'd
// structured_<name>[_<overload>] class, so a definition reads like a free
// function but is actually an out-of-line member definition.
//
// Example usage:
//
// TORCH_META_FUNC2(add, Tensor) (
// const Tensor& self, const Tensor& other
// ) {
// ... compute sizes and options ...
// set_output(sizes, options);
// }
//
#define TORCH_META_FUNC(name) void structured_##name::meta
#define TORCH_META_FUNC2(name, overload) \
  void structured_##name##_##overload::meta
// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
// as a return value. They should be used when the kernel in question has
// precomputed values declared in native_functions.yaml and the corresponding
// implementation should return an instance of the aforementioned struct.
// `meta_return_ty` is declared by the codegen'd structured_* class.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
  structured_##name::meta_return_ty structured_##name::meta
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  structured_##name##_##overload::meta_return_ty \
  structured_##name##_##overload::meta
// Use this to create a precompute struct in a meta function.
// Expands to the (empty-argument) precompute_out template of the
// codegen'd structured_* class; fields are then filled via its setters.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  structured_##name##_##overload::precompute_out<>
// Use this to define the prototype for an implementation. This takes only
// one argument, which is the name of the dispatch key entry you're
// implementing.
//
// Example usage:
//
// TORCH_IMPL_FUNC(add_cpu) (
// Tensor& result, const Tensor& self, const Tensor& other
// ) {
// ... do the actual implementation ...
// }
//
#define TORCH_IMPL_FUNC(name) void structured_##name::impl
  55. // Base class for all structured kernel classes. The set_output virtual
  56. // method is varied depending whether or not the operator is
  57. // functional/out/inplace, and could also be specialized for CPU/CUDA/etc
  58. // (although presently it isn't).
  59. //
  60. // A notable subclass of this interface is TensorIteratorBase.
  61. struct TORCH_API MetaBase {
  62. virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
  63. // Note: [set_output_*]
  64. // See: https://github.com/pytorch/pytorch/issues/69813
  65. // Whenever defining the output properties in the META function of a
  66. // structured kernel (what was usually done with `set_output`), use one of
  67. // these 3 variants, instead. In order to decide which variant to use, check
  68. // the following decision tree:
  69. //
  70. // - Can the kernel you are going to implement support output tensors
  71. // with arbitrary strides?
  72. // |
  73. // -- YES: `set_output_raw_strided`
  74. // |
  75. // -- NO: Should the output tensor strides be contiguous?
  76. // |
  77. // -- YES: `set_output_contiguous`
  78. // |
  79. // -- NO: `set_output_strided`
  80. //
  81. // Use this function whenever the kernel requires specific strides for the
  82. // output. If `strides` does not match the given output strides, proxy outputs
  83. // will be created and passed to the IMPL function.
  84. virtual void set_output_strided(
  85. int64_t output_idx,
  86. IntArrayRef sizes,
  87. IntArrayRef strides,
  88. TensorOptions options,
  89. DimnameList names = {}) {
  90. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  91. }
  92. // Use this function whenever the kernel knows how to handle arbitrary strided
  93. // outputs. This function has the same behavior as the old `set_output`: it
  94. // will only re-stride if the given output was resized.
  95. virtual void set_output_raw_strided(
  96. int64_t output_idx,
  97. IntArrayRef sizes,
  98. IntArrayRef strides_hint,
  99. TensorOptions options,
  100. DimnameList names = {}) {
  101. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  102. }
  103. // Use this function if the kernel requires contiguous strides.
  104. // Alias for `set_output_strided`, but with contiguous strides.
  105. void set_output_contiguous(
  106. int64_t output_idx,
  107. IntArrayRef sizes,
  108. TensorOptions options,
  109. DimnameList names = {}) {
  110. auto strides = c10::contiguous_strides(sizes);
  111. set_output_strided(output_idx, sizes, strides, options, names);
  112. }
  113. // Returns a reference to an undefined tensor if there is no presupplied
  114. // output
  115. const Tensor& maybe_get_output() {
  116. return maybe_get_output(0);
  117. }
  118. virtual ~MetaBase() = default;
  119. };
  120. } // namespace impl
  121. } // namespace at
  122. C10_CLANG_DIAGNOSTIC_POP()