#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <string>

#include <c10/util/SmallVector.h>
#include <c10/util/variant.h>

#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>

namespace at {
namespace native {

#define AT_FOR_8_CASES(_)  \
  _(1)                     \
  _(2)                     \
  _(3)                     \
  _(4)                     \
  _(5)                     \
  _(6)                     \
  _(7)                     \
  _(8)

#define AT_FOR_8_CASES_WITH_COMMA(_)  \
  _(1) ,                              \
  _(2) ,                              \
  _(3) ,                              \
  _(4) ,                              \
  _(5) ,                              \
  _(6) ,                              \
  _(7) ,                              \
  _(8)

c10::SmallVector<std::string> get_extra_args_typenames(
    const c10::SmallVector<at::Scalar>& extra_args) {
  c10::SmallVector<std::string> args_typenames(extra_args.size());
  for (auto i = 0; i < extra_args.size(); ++i) {
    args_typenames[i] = at::cuda::jit::typeName(extra_args[i].type());
  }
  return args_typenames;
}

int can_vectorize_up_to(at::ScalarType type, char* pointer) {
  switch(type) {
#define DEFINE_CASE(ctype, scalartype)  \
    case ScalarType::scalartype : return memory::can_vectorize_up_to<ctype>(pointer);

    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}

// jitted version of the above
// See Note [Jiterator], this relies on the assumptions enumerated there
int jitted_can_vectorize_up_to(const TensorIteratorBase& iter) {
  const at::ScalarType common_dtype = iter.common_dtype();
  const at::ScalarType result_dtype = common_dtype;

  // Deals with the output
  int result = can_vectorize_up_to(result_dtype, static_cast<char*>(iter.data_ptr(0)));

  // Incorporates the input(s)
  for (auto i = 1; i < iter.ntensors(); ++i) {
    result = std::min(result, can_vectorize_up_to(common_dtype, static_cast<char*>(iter.data_ptr(i))));
  }

  return result;
}

template<bool IS_INPUT, int N>
static std::unique_ptr<OffsetCalculator<N>> make_unique_offset_calculator(
    const TensorIteratorBase& iter) {
  // The array size cannot be 0 (which happens when N == 0), so pad to at least 1.
  constexpr int array_size = std::max(N, 1);
  TORCH_INTERNAL_ASSERT(N == (IS_INPUT ? iter.ninputs() : iter.noutputs()));

  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    // Inputs follow the outputs in TensorIterator's operand order.
    int index = IS_INPUT ? i + iter.noutputs() : i;
    strides[i] = iter.strides(index).data();
    element_sizes[i] = iter.element_size(index);
  }
  return std::make_unique<OffsetCalculator<N>>(
      iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}
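// The c10::variant-based helpers below bridge a compile-time/run-time gap:
// OffsetCalculator, at::detail::Array, memory::LoadWithCast, and
// memory::StoreWithCast are all templated on the number of tensors, but the
// jiterator only learns a kernel's arity at run time. Each *Variant holds one
// instantiation per supported arity (1 through 8, enumerated by
// AT_FOR_8_CASES), selects the matching alternative in its constructor, and
// exposes it through a type-erased data_ptr() that can go into the
// kernel-argument buffer of a JIT-compiled launch. A minimal usage sketch
// (illustrative only, not part of this header):
//
//   TrivialOffsetCalculatorVariant input_offset_calculator(iter.ninputs());
//   ArrayVariant data(iter);
//   std::vector<void*> args;
//   args.push_back(data.data_ptr());
//   args.push_back(input_offset_calculator.data_ptr());
//   // args is then forwarded to the JIT-compiled kernel launch.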
template <bool IS_INPUT>
struct OffsetCalculatorVariant {
#define DEFINE_CASE(index) std::unique_ptr<OffsetCalculator<index>>
  using OffsetCalculatorTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  OffsetCalculatorVariant(const TensorIteratorBase& iter) {
    int num = IS_INPUT ? iter.ninputs() : iter.noutputs();

    switch(num) {
#define DEFINE_CASE(index)  \
      case index : v = make_unique_offset_calculator<IS_INPUT, index>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "OffsetCalculatorVariant is not implemented for num_tensor = ", num);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  OffsetCalculatorTypes v;
};

struct ArrayVariant {
  // Works for up to 8 inputs + 8 outputs.
#define DEFINE_CASE(index) at::detail::Array<char*, index>, at::detail::Array<char*, index+8>
  using ArrayTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  ArrayVariant(const TensorIteratorBase& iter) {
    int ntensors = iter.ntensors();
    switch(ntensors) {
#define DEFINE_CASE(index)                                            \
      case index: array = at::detail::Array<char*, index>{}; break;   \
      case index+8: array = at::detail::Array<char*, index+8>{}; break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "ArrayVariant is not implemented for ntensors = ", ntensors);
    }

    c10::visit([&](auto& a) {
      for (auto i = 0; i < ntensors; ++i) {
        a[i] = (char*)iter.data_ptr(i);
      }
    }, array);
  }

  void* data_ptr() {
    return c10::visit([](auto & a){ return static_cast<void*>(&a); }, array);
  }

 private:
  ArrayTypes array;
};

struct TrivialOffsetCalculatorVariant {
#define DEFINE_CASE(index) TrivialOffsetCalculator<index>
  using TrivialOffsetCalculatorTypes = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  TrivialOffsetCalculatorVariant(int num) {
    switch(num) {
#define DEFINE_CASE(index)  \
      case index: v = TrivialOffsetCalculator<index>(); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "TrivialOffsetCalculatorVariant is not implemented for num_tensors = ", num);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(&v); }, v);
  }

 private:
  TrivialOffsetCalculatorTypes v;
};

struct LoadWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::LoadWithCast<index>>
  using LoadWithCastPtr = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  LoadWithCastVariant(const TensorIteratorBase& iter) {
    int arity = iter.ninputs();
    switch(arity) {
#define DEFINE_CASE(index)  \
      case index: v = std::make_unique<memory::LoadWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "LoadWithCastVariant is not implemented for ninputs = ", arity);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  LoadWithCastPtr v;
};

struct StoreWithCastVariant {
#define DEFINE_CASE(index) std::unique_ptr<memory::StoreWithCast<index>>
  using StoreWithCastPtr = c10::variant<
    AT_FOR_8_CASES_WITH_COMMA(DEFINE_CASE)
  >;
#undef DEFINE_CASE

  StoreWithCastVariant(const TensorIteratorBase& iter) {
    int num = iter.noutputs();
    switch(num) {
#define DEFINE_CASE(index)  \
      case index: v = std::make_unique<memory::StoreWithCast<index>>(iter); break;

      AT_FOR_8_CASES(DEFINE_CASE)
#undef DEFINE_CASE

      default:
        TORCH_CHECK(false, "StoreWithCastVariant is not implemented for noutputs = ", num);
    }
  }

  void* data_ptr() {
    return c10::visit([](auto & v){ return static_cast<void*>(v.get()); }, v);
  }

 private:
  StoreWithCastPtr v;
};

}}  // namespace at::native

#endif // AT_USE_JITERATOR()
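// Sketch of the dynamic-casting path on the consumer side (illustrative only;
// the caller shown here is assumed, not defined in this header): when the
// operands' dtypes differ from iter.common_dtype(), the jitted kernel is
// handed type-erased casting loaders/storers instead of raw pointers, e.g.
//
//   LoadWithCastVariant loader(iter);    // one caster per input
//   StoreWithCastVariant storer(iter);   // one caster per output
//   args.push_back(loader.data_ptr());
//   args.push_back(storer.data_ptr());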