// MathBitsFallback.h
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/clone.h>
#endif

// Needed unconditionally (std::move, std::pair), so it lives outside the
// AT_PER_OPERATOR_HEADERS conditional.
#include <utility>

namespace at {
namespace native {
// This fallback should only be used for operations that are self-inverse and have a corresponding tensor
// bit (internally implemented using DispatchKey) that maintains the state on the tensor.
// Currently there are two tensor bits that trigger this fallback: the conjugate bit and the negative bit.
// The conjugate bit is set on a tensor when `.conj()` is called, and the neg bit is set on a tensor when `.conj().imag` is called.
// NOTE: To use this fallback, `clone` and `copy_` must fully understand and correctly handle the semantics of your math bit.
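//
// A typical user of this fallback subclasses MathOpFallback and registers
// fallback_impl as a boxed fallback on the math bit's dispatch key. A minimal
// sketch for the conjugate bit (modeled on ConjugateFallback.cpp; details may
// differ across PyTorch versions):
//
//   struct ConjFallback : MathOpFallback {
//     ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
//     bool is_bit_set(const Tensor& tensor) override {
//       return tensor.is_conj();
//     }
//   };
//
//   void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
//     ConjFallback object;
//     object.fallback_impl(op, dispatch_keys, stack);
//   }
//
//   TORCH_LIBRARY_IMPL(_, Conjugate, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
//   }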
struct MathOpFallback {
  MathOpFallback(DispatchKey key_, std::string op_name_) : key(key_), op_name(std::move(op_name_)) {}
  virtual bool is_bit_set(const Tensor&) = 0;
  void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
    /*
      Situations to handle:
        1. Out-of-place operation. Easy: materialize all inputs and
           call it a day.
        2. In-place operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
           Materialize other inputs as in (1).
        3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2)).
           Materialize other inputs as in (1).

      It is important to be able to tell if we READ from an argument and if we
      WRITE to an argument. The conservative approach is to assume that we always
      READ from an argument, but in out= operations you can skip
      conjugating inputs on entry that never get used. In the current schema we
      can't easily tell if the operation is an in-place or out= operation.

      Note:
        1. Mutable TensorLists containing tensors whose math bit is set to true are disallowed.
        2. Mutable tensors with the math bit set to true are unconditionally cloned to ensure
           correct behavior in the case where the mutable tensor shares memory with non-mutable arguments.
           If we were to resolve the math bit in place for mutable inputs, then non-mutable inputs sharing
           partial or full memory with these mutable inputs would read wrong values in the following cases:
             1. Non-mutable inputs have their math bit set to false.
             2. The math bit for mutable input(s) is resolved before the non-mutable inputs (with the bit set
                to true and sharing memory with one or more mutable arg(s)) are cloned.
           At the end, the final values of the mutable arguments on the stack are copied back into the
           original mutable tensor inputs.
    */
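    // Walk-through of case (2) above (a sketch; x is a hypothetical tensor with
    // the conjugate bit set): when the dispatcher routes x.add_(2) here, we
    // clone x (which materializes the conjugation, per the NOTE at the top of
    // this file), run add_ on the clone by redispatching, and finally copy the
    // result back into x. x itself keeps its conjugate bit.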
    const auto& arguments = op.schema().arguments();
    const auto num_arguments = arguments.size();
    const auto stack_start = stack->size() - num_arguments;

    c10::optional<bool> is_write;
    for (const auto i : c10::irange(num_arguments)) {
      // Three possible states:
      // 1. alias_info has no value --> out-of-place operation
      // 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
      // 3. alias_info does have a value, alias_info->is_write=False --> view operation
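      // Illustrative schemas for each state (quoted from native_functions.yaml
      // from memory; exact signatures may vary by version):
      //   (1) add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
      //   (2) add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
      //   (3) transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)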
      const AliasInfo* alias_info = arguments[i].alias_info();
      if (alias_info != nullptr) {
        if (is_write.has_value()) {
          TORCH_CHECK(*is_write == alias_info->isWrite(),
            "Unsupported operator for ", op_name, " fallback: ", op.schema().name(), "; ",
            op_name, " fallback doesn't work for operators with a mix of "
            "mutable and non-mutable inputs that alias with outputs; "
            "this must be implemented manually. "
            "If you got this error on a core op, please report a bug to PyTorch.");
        } else {
          is_write = alias_info->isWrite();
        }
      }
    }
    if (is_write.has_value() && !*is_write) {
      // We assume that view operators automatically handle the math bit
      // correctly by propagating the dispatch key in key_set.
      // This is not necessarily always right, so you should test these cases.
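      // For example, t.conj().transpose(0, 1) produces a view whose key set
      // still carries the conjugate bit, so no materialization is needed here;
      // we can simply redispatch past this key.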
      op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
      return;
    }
    // Mutable inputs with the math bit set to true, paired with their clones
    std::vector<std::pair<Tensor, Tensor>> mutable_inputs_with_their_clones;
    for (const auto i : c10::irange(num_arguments)) {
      auto& ivalue = (*stack)[stack_start + i];
      if (!(ivalue.isTensor() || ivalue.isTensorList())) {
        continue;
      }
      const auto& argument = arguments[i];
      bool mut_arg = false;
      if (argument.alias_info()) {
        // Was already tested by the is_write loop above
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
        mut_arg = true;
      }
      if (ivalue.isTensor()) {
        if (!is_bit_set(ivalue.toTensor())) {
          continue;
        }
        auto tensor = std::move(ivalue).toTensor();
        // Per the NOTE above, clone is responsible for resolving the math bit.
        auto resolved_tensor = at::clone(tensor);
        if (mut_arg) {
          TORCH_CHECK(mutable_inputs_with_their_clones.empty(), op_name,
            " fallback does not support operators with more than one mutable tensor with its ",
            op_name, " bit set to true.");
          mutable_inputs_with_their_clones.emplace_back(std::move(tensor), resolved_tensor);
        }
        (*stack)[stack_start + i] = std::move(resolved_tensor);
      } else if (ivalue.isTensorList()) {
        auto tensors = std::move(ivalue).toTensorList();
        for (const auto j : c10::irange(tensors.size())) {
          const auto& tensor = tensors[j];
          if (!is_bit_set(tensor)) {
            continue;
          }
          TORCH_CHECK(!mut_arg, op_name,
            " fallback doesn't currently support mutable TensorLists with ",
            op_name, " inputs. Please materialize all the ", op_name,
            " input tensor(s) in the mutable TensorList inputs before calling ",
            op.schema().name());
          tensors[j] = at::clone(tensor);
        }
        (*stack)[stack_start + i] = std::move(tensors);
      }
    }
    op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

    // At most one mutable input was recorded (enforced above); copy the result
    // of the redispatched op back into the original mutable tensor.
    TORCH_INTERNAL_ASSERT(mutable_inputs_with_their_clones.size() <= 1);
    for (auto& mut_tensors : mutable_inputs_with_their_clones) {
      auto& mutable_input = mut_tensors.first;
      auto& cloned_mutable_input = mut_tensors.second;
      auto& ivalue = (*stack)[stack_start];
      auto returned_output = std::move(ivalue).toTensor();
      // Sanity check to ensure that the tensor on the stack aliases the cloned_mutable_input
      TORCH_INTERNAL_ASSERT(cloned_mutable_input.is_same(returned_output));
      // Necessary for the out= arg
      at::native::resize_output(mutable_input, returned_output.sizes());
      mutable_input.copy_(returned_output);
      (*stack)[stack_start] = std::move(mutable_input);
    }
  }
  virtual ~MathOpFallback() = default;

  DispatchKey key;
  std::string op_name;
};
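// Caller-side usage sketch (assuming the conjugate fallback from the comment
// above is registered, as it is in core PyTorch; `add` is assumed to have no
// conjugate-aware kernel of its own):
//
//   Tensor a = at::randn({2, 2}, at::kComplexFloat);
//   Tensor b = a.conj();  // sets the conjugate bit; no data is copied
//   Tensor c = b.add(b);  // hits this fallback: b is materialized via clone
//                         // (resolving the bit) and add is redispatched on
//                         // the materialized input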
} // namespace native
} // namespace at