ForeachUtils.h
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/util/irange.h>
#include <ATen/Dispatch.h>

#include <algorithm>
#include <vector>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/result_type_native.h>
#endif
namespace at {
namespace native {
namespace {

// Check if the tensor list has either a boolean tensor or an integer tensor.
bool has_integral_tensor(TensorList tensors, const bool includeBool) {
  return std::any_of(
      tensors.begin(), tensors.end(), [&includeBool](const auto& t) {
        return at::isIntegralType(t.scalar_type(), includeBool);
      });
}

// Check if the tensor list has any bool tensors.
bool has_bool_tensor(TensorList tensors) {
  return std::any_of(tensors.begin(), tensors.end(), [](const auto& t) -> bool {
    return t.scalar_type() == ScalarType::Bool;
  });
}
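
// Illustrative behavior (example added for exposition, not in the original
// source): given tensors = {a float tensor, a long tensor},
// has_integral_tensor(tensors, /*includeBool=*/false) returns true because
// the long tensor is integral, while has_bool_tensor(tensors) returns false
// since neither tensor is of ScalarType::Bool.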
// Check foreach API restrictions:
// - Tensor lists must be non-empty.
// - All TensorLists and ScalarLists must have the same number of elements.
// - Corresponding tensors must have the same size.
void check_foreach_api_restrictions(TensorList tensors) {
  TORCH_CHECK(!tensors.empty(), "Tensor list must have at least one tensor.");
}

void check_foreach_api_restrictions(TensorList tensors, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors);
  TORCH_CHECK(tensors.size() == scalars.size(),
              "Tensor list must have same number of elements as scalar list.");
}

void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(),
              "Tensor lists must have the same number of tensors, got ",
              tensors1.size(), " and ", tensors2.size());
}

void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  TORCH_CHECK(!tensors1.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors2.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(!tensors3.empty(), "Tensor list must have at least one tensor.");
  TORCH_CHECK(tensors1.size() == tensors2.size(),
              "Tensor lists must have the same number of tensors, got ",
              tensors1.size(), " and ", tensors2.size());
  TORCH_CHECK(tensors1.size() == tensors3.size(),
              "Tensor lists must have the same number of tensors, got ",
              tensors1.size(), " and ", tensors3.size());
}

void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<Scalar> scalars) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  TORCH_CHECK(tensors1.size() == scalars.size(),
              "Tensor list must have same number of elements as scalar list, got ",
              tensors1.size(), " and ", scalars.size());
}
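
// Usage sketch (hypothetical, for illustration only): a binary foreach op
// would typically validate its inputs up front before choosing a kernel:
//
//   void my_foreach_op_check(TensorList self, TensorList other) {
//     check_foreach_api_restrictions(self, other);  // throws via TORCH_CHECK on mismatch
//   }
//
// `my_foreach_op_check` is an invented name; the real foreach ops call these
// helpers directly from their implementations.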
// To go via the 'fast' path, several conditions must be satisfied:
// - All tensors in all lists must have the same dtype.
// - All tensors must be on the same device.
// - All tensors must have strided layout.
// - All tensors must be non-overlapping and dense.
// - The resulting tensor must have the same dtype as the inputs.
// Make sure to call check_foreach_api_restrictions before calling this
// method; it establishes the preconditions this function relies on.
bool check_fast_path_restrictions(
    ArrayRef<TensorList> tensorLists,
    ArrayRef<Scalar> scalarList = {},
    bool does_op_promote_integer_inputs_to_float = false) {
  const auto expected_dtype = tensorLists[0][0].dtype();
  const auto expected_device = tensorLists[0][0].device();

  auto is_tensor_okay = [&](const Tensor& tensor) {
    return tensor.dtype() == expected_dtype &&
        tensor.device() == expected_device &&
        tensor.layout() == at::kStrided &&
        tensor.is_non_overlapping_and_dense();
  };

  for (const auto& tensorList : tensorLists) {
    for (const auto& tensor : tensorList) {
      if (!is_tensor_okay(tensor)) {
        return false;
      }
    }
  }

  // Check that corresponding tensors in the tensor lists have the same
  // sizes and strides.
  for (const auto& tensor_list : tensorLists) {
    for (const auto j : c10::irange(tensorLists[0].size())) {
      if (tensorLists[0][j].sizes() != tensor_list[j].sizes()) {
        return false;
      }
      if (tensorLists[0][j].strides() != tensor_list[j].strides()) {
        return false;
      }
    }
  }

  // `is_tensor_okay` above has already verified that `tensorLists[j][i]` has
  // the same dtype for all j, i. This means we only need to check whether
  // {tensorLists[0][0], tensorLists[0][1], tensorLists[0][2], ...} type-promote
  // with scalarList.
  for (const auto i : c10::irange(tensorLists[0].size())) {
    // For division, integer inputs will result in float.
    if (does_op_promote_integer_inputs_to_float) {
      if (at::isIntegralType(tensorLists[0][i].scalar_type(), /*includeBool*/ true)) {
        return false;
      }
    }
    if (!scalarList.empty()) {
      const auto& scalar = scalarList.size() == 1 ? scalarList[0] : scalarList[i];
      const auto& tensor = tensorLists[0][i];
      // note(mkozuki): This check might be responsible for `_foreach_add(bool_tensors, bool_tensors)`
      // being pushed to the slow path.
      if (tensor.scalar_type() != at::native::result_type(scalar, tensor)) {
        return false;
      }
    }
  }

  return true;
}
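
// Worked example (an assumption for exposition, not in the original source):
// for a float tensor `t` and scalar 1.5, result_type(1.5, t) == kFloat, so
// the dtype check passes and the fast path remains eligible. For a long
// tensor and the same scalar, the result type promotes to kFloat != kLong,
// so this function returns false and the op falls back to the slow path.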
// Unpack a 1-dimensional CPU tensor of scalars into a std::vector<c10::Scalar>.
std::vector<c10::Scalar> convert_tensor_to_scalar_list(
    const Tensor& scalarList_,
    int64_t expect_length) {
  std::vector<c10::Scalar> scalarList;
  TORCH_CHECK(
      scalarList_.device() == c10::kCPU,
      "Expected scalars to be on CPU, got ",
      scalarList_.device(),
      " instead.");
  TORCH_CHECK(
      scalarList_.is_contiguous(), "Expected scalars to be contiguous.");
  TORCH_CHECK(
      scalarList_.dim() == 1,
      "Expected packed scalar Tensor to be of dimension 1. Got ",
      scalarList_.dim(),
      " instead.");
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf,
      kHalf,
      kBool,
      kBFloat16,
      scalarList_.scalar_type(),
      "convert_tensor_to_scalar_list",
      [&]() {
        const scalar_t* scalar_data = scalarList_.data_ptr<scalar_t>();
        TORCH_CHECK(
            (expect_length == scalarList_.size(0)),
            "Expected length of scalars to match input of length ",
            expect_length,
            " but got ",
            scalarList_.size(0),
            " instead.");
        for (int64_t i = 0; i < scalarList_.size(0); i++) {
          scalarList.push_back(c10::Scalar(scalar_data[i]));
        }
      });
  return scalarList;
}
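
// Usage sketch (hypothetical, for illustration only): unpacking a packed
// scalar tensor before forwarding to a scalar-list overload:
//
//   // `scalars` is a 1-D CPU tensor holding one scalar per input tensor.
//   std::vector<c10::Scalar> unpacked =
//       convert_tensor_to_scalar_list(scalars, static_cast<int64_t>(tensors.size()));
//   check_foreach_api_restrictions(tensors, unpacked);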
bool can_use_fast_route(ArrayRef<TensorList> tensorLists,
                        ArrayRef<Scalar> scalarList = {},
                        bool does_op_promote_integer_inputs_to_float = false) {
  return check_fast_path_restrictions(tensorLists, scalarList, does_op_promote_integer_inputs_to_float);
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2, bool does_op_promote_integer_inputs_to_float = false) {
  return can_use_fast_route({tensors1, tensors2}, {}, does_op_promote_integer_inputs_to_float);
}
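
// Dispatch sketch (hypothetical, for illustration only): a typical foreach
// op picks between a fused multi-tensor kernel and a per-tensor slow loop:
//
//   check_foreach_api_restrictions(tensors1, tensors2);
//   if (can_use_fast_route(tensors1, tensors2)) {
//     // launch the fused multi-tensor kernel
//   } else {
//     // fall back to looping over the tensors one by one
//   }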
} // anonymous namespace
} // namespace native
} // namespace at