SerialStackImpl.h
// Copyright 2004-present Facebook. All Rights Reserved.
#pragma once

#include <ATen/core/Tensor.h>

#include <ATen/MemoryOverlap.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/irange.h>

#include <vector>

namespace at { namespace native { namespace detail {

struct InputMeta {
  void* data_ptr;
  int64_t inner_size;

  InputMeta(const Tensor& t, int64_t dim, int64_t inner)
      : data_ptr(t.data_ptr()), inner_size(t.sizes()[dim] * inner) {}
};

// This kernel is used by two TensorList types:
// 1. stack_serial_kernel uses at::ArrayRef<Tensor>
// 2. Static runtime calls this kernel directly (csrc/jit/runtime/static/ops.cpp) with
//    ProcessedNodeInputWrapper.
// When making changes, make sure that they are compatible with both types!
template <typename scalar_t, typename TensorListType>
void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t dim) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      dim >= 0 && dim <= result.dim(),
      "dim out of range in stack_serial_kernel_impl");
  int64_t outer =
      result.numel() / (result.sizes()[dim] * result.strides()[dim]);
  scalar_t* result_data = result.data_ptr<scalar_t>();
  int64_t ninputs = tensors.size();
  std::vector<InputMeta> inputs;
  inputs.reserve(ninputs);
  for (const auto& tensor : tensors) {
    inputs.emplace_back(tensor, dim, tensor.strides()[dim]);
  }

  using Vec = vec::Vectorized<scalar_t>;
  scalar_t* result_ptr = result_data;
  for (const auto i : c10::irange(outer)) {
    for (const auto j : c10::irange(ninputs)) {
      int64_t local_inner = inputs[j].inner_size;
      scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;

      if (local_inner < Vec::size()) {
        // Inner block is smaller than one SIMD vector; fall back to a scalar copy.
        for (const auto k : c10::irange(local_inner)) {
          result_ptr[k] = input_ptr[k];
        }
      } else {
        // Identity lambda over Vectorized<scalar_t> chunks, i.e. a vectorized copy.
        vec::map(
            [](Vec x) { return x; }, result_ptr, input_ptr, local_inner);
      }
      result_ptr += local_inner;
    }
  }
}
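
// Illustrative usage sketch (not part of this header): a caller would typically
// resize `result` to the stacked shape and then dispatch on the Float/Double dtype
// before invoking the kernel. The wrapper name `stack_serial_example` is
// hypothetical; AT_DISPATCH_FLOATING_TYPES comes from <ATen/Dispatch.h>.
//
//   void stack_serial_example(Tensor& result, at::ArrayRef<Tensor> tensors, int64_t dim) {
//     AT_DISPATCH_FLOATING_TYPES(
//         result.scalar_type(), "stack_serial_example", [&]() {
//           detail::stack_serial_kernel_impl<scalar_t, at::ArrayRef<Tensor>>(
//               result, tensors, dim);
//         });
//   }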

// Checks whether the native stack fast path can be invoked under these conditions:
// - result and input tensors are contiguous
// - only one thread is used
// - no type promotion has to occur
// - tensors' dtype is Double or Float
template <typename TensorListType>
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
  TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");

  const Tensor& first_tensor = tensors[0];
  // stack dimension should be in range [0, first_tensor.dim())
  // dim == first_tensor.dim() is a valid input, but it is handled by the default
  // code path that uses unsqueeze
  if (dim >= first_tensor.dim()) return false;
  // Native stack doesn't apply if any tensor is skipped (empty 1-D tensors are skipped).
  if (first_tensor.numel() == 0 && first_tensor.dim() == 1) return false;
  // there should be no type promotion
  if (result.dtype() != first_tensor.dtype()) return false;

  auto first_tensor_mem_format = first_tensor.suggest_memory_format();
  ScalarType dtype = first_tensor.scalar_type();

  if (!result.is_contiguous(first_tensor_mem_format)) {
    return false;
  }

  // fast path only works for Double and Float
  if (dtype != ScalarType::Double && dtype != ScalarType::Float) {
    return false;
  }

  // check remainder of inputs
  auto const& first_tensor_shape = first_tensor.sizes();
  for (const auto i : c10::irange(1, tensors.size())) {
    auto const& tensor = tensors[i];
    TORCH_CHECK(tensors[i].sizes() == first_tensor.sizes(),
        "stack expects each tensor to be equal size, but got ", first_tensor_shape,
        " at entry 0 and ", tensor.sizes(), " at entry ", i);

    // every tensor must be contiguous
    // tensor sizes and strides must be the same
    // there should be no type promotion
    if (!tensor.is_contiguous(first_tensor_mem_format) ||
        tensor.strides() != first_tensor.strides() ||
        tensor.dtype() != dtype) {
      return false;
    }
  }

  // fast native stack should only be used when it is not worth using multiple threads
  // or there is only one thread. Note that we aren't checking result.numel() here
  // because it may not have been resized and we want to defer that cost till later.
  int64_t numel_in_stack = first_tensor.numel() * tensors.size();
  return numel_in_stack < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
}

template <typename TensorListType, bool should_skip_overlap_check>
struct CanUseNativeSerialStack;

template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, false> {
  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
    // Inputs cannot alias the output tensor
    for (const auto i : c10::irange(tensors.size())) {
      auto lap = at::get_overlap_status(result, tensors[i]);
      TORCH_CHECK(lap != at::MemOverlapStatus::Partial &&
          lap != at::MemOverlapStatus::Full,
          "unsupported operation: the input tensors cannot refer to any of the "
          "output memory locations. Found overlap in input tensor ", i);
    }

    return can_use_native_serial_stack_impl(result, tensors, dim);
  }
};

template <typename TensorListType>
struct CanUseNativeSerialStack<TensorListType, true> {
  static bool call(Tensor& result, TensorListType tensors, int64_t dim) {
    return can_use_native_serial_stack_impl(result, tensors, dim);
  }
};
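
// Illustrative only: the boolean template parameter selects whether the overlap
// check runs. A caller that has just allocated `result` itself (so no input can
// alias it) might use the `true` specialization; otherwise the `false`
// specialization verifies non-overlap before running the same shape/dtype checks.
//
//   bool fast_path = detail::CanUseNativeSerialStack<
//       at::ArrayRef<Tensor>, /*should_skip_overlap_check=*/false>::call(
//       result, tensors, dim);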

}}} // namespace at::native::detail