C++17.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. #pragma once
  2. #ifndef C10_UTIL_CPP17_H_
  3. #define C10_UTIL_CPP17_H_
  4. #include <c10/macros/Macros.h>
  5. #include <cstdlib>
  6. #include <functional>
  7. #include <memory>
  8. #include <sstream>
  9. #include <string>
  10. #include <type_traits>
  11. #include <utility>
  // Toolchain gate: refuse to compile with compilers or language modes this
  // header does not support, and catch a min/max macro clash early.

  // GCC proper (Clang and MSVC are excluded) must be version 5 or newer.
  #if defined(__GNUC__) && !defined(__clang__) && !defined(_MSC_VER)
  #if __GNUC__ < 5
  #error "You're trying to build PyTorch with a too old version of GCC. We need GCC 5 or later."
  #endif
  #endif

  // Clang must be version 4 or newer.
  #if defined(__clang__)
  #if __clang_major__ < 4
  #error "You're trying to build PyTorch with a too old version of Clang. We need Clang 4 or later."
  #endif
  #endif

  // C++14 or newer is required. On MSVC the language mode is read from
  // _MSVC_LANG; on every other compiler from __cplusplus.
  #if defined(_MSC_VER)
  #if !defined(_MSVC_LANG) || _MSVC_LANG < 201402L
  #error You need C++14 to compile PyTorch
  #endif
  #elif __cplusplus < 201402L
  #error You need C++14 to compile PyTorch
  #endif

  // On Windows, function-like min/max macros would break the min/max names
  // used in this header; require NOMINMAX instead.
  #if defined(_WIN32)
  #if defined(min) || defined(max)
  #error Macro clash with min and max -- define NOMINMAX when compiling your program on Windows
  #endif
  #endif
  28. /*
  29. * This header adds some polyfills with C++17 functionality
  30. */
  31. namespace c10 {
  // Result type of calling Fn with Args. std::result_of was superseded by
  // std::invoke_result in C++17 and removed in C++20, so use the newer trait
  // whenever the standard library advertises it.
  #if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703L
  template <typename Fn, typename... Args>
  using invoke_result = std::invoke_result<Fn, Args...>;
  #else
  template <typename Fn, typename... Args>
  using invoke_result = std::result_of<Fn && (Args && ...)>;
  #endif

  // Shorthand for the nested ::type of invoke_result.
  template <typename Fn, typename... Args>
  using invoke_result_t = typename invoke_result<Fn, Args...>::type;
  // "Plain old data" trait that avoids the C++20-deprecated std::is_pod when
  // possible: with the C++17 logical traits available it is spelled as
  // standard-layout AND trivial (both C++11 traits); otherwise it falls back
  // to std::is_pod.
  #if defined(__cpp_lib_logical_traits) && __cpp_lib_logical_traits >= 201510L
  template <typename Ty>
  using is_pod =
      std::conjunction<std::is_standard_layout<Ty>, std::is_trivial<Ty>>;
  #else
  template <typename Ty>
  using is_pod = std::is_pod<Ty>;
  #endif

  // Variable-template shorthand for is_pod<Ty>::value.
  template <typename Ty>
  constexpr bool is_pod_v = is_pod<Ty>::value;
  53. namespace guts {
  /// Constructs a `Child` from `args` and hands it back owned by a
  /// `std::unique_ptr<Base>`. Only participates in overload resolution for
  /// non-array types where `Base` is a base of (or the same class as) `Child`.
  ///
  /// NOTE(review): the object is later deleted through `Base*`, so `Base`
  /// presumably needs a virtual destructor unless it is `Child` itself.
  template <typename Base, typename Child, typename... Args>
  std::enable_if_t<
      !std::is_array<Base>::value && !std::is_array<Child>::value &&
          std::is_base_of<Base, Child>::value,
      std::unique_ptr<Base>>
  make_unique_base(Args&&... args) {
    // unique_ptr<Child> converts implicitly to unique_ptr<Base>.
    return std::make_unique<Child>(std::forward<Args>(args)...);
  }
  // Logical metafunctions. When the standard library ships them (and we are
  // not on MSVC older than 1920) these are plain aliases; otherwise use the
  // reference implementations from cppreference:
  //   http://en.cppreference.com/w/cpp/types/conjunction
  //   http://en.cppreference.com/w/cpp/types/disjunction
  //   http://en.cppreference.com/w/cpp/types/integral_constant
  //   http://en.cppreference.com/w/cpp/types/negation
  #if defined(__cpp_lib_logical_traits) && !(defined(_MSC_VER) && _MSC_VER < 1920)
  template <class... Traits>
  using conjunction = std::conjunction<Traits...>;
  template <class... Traits>
  using disjunction = std::disjunction<Traits...>;
  template <bool Value>
  using bool_constant = std::bool_constant<Value>;
  template <class Trait>
  using negation = std::negation<Trait>;
  #else
  template <bool Value>
  using bool_constant = std::integral_constant<bool, Value>;

  // Empty pack is true. Otherwise derive from the first trait whose value is
  // false, or from the last trait when every one is true.
  template <class...>
  struct conjunction : std::true_type {};
  template <class Last>
  struct conjunction<Last> : Last {};
  template <class Head, class... Tail>
  struct conjunction<Head, Tail...>
      : std::conditional_t<bool(Head::value), conjunction<Tail...>, Head> {};

  // Empty pack is false. Otherwise derive from the first trait whose value is
  // true, or from the last trait when every one is false.
  template <class...>
  struct disjunction : std::false_type {};
  template <class Last>
  struct disjunction<Last> : Last {};
  template <class Head, class... Tail>
  struct disjunction<Head, Tail...>
      : std::conditional_t<bool(Head::value), Head, disjunction<Tail...>> {};

  // Boolean complement of Trait::value.
  template <class Trait>
  struct negation : bool_constant<!bool(Trait::value)> {};
  #endif
  // void_t<Ts...> maps any list of well-formed types to void. It is the usual
  // SFINAE detection helper, e.g. as the second argument of a partial
  // specialization that should only match when some expression is valid.
  #ifdef __cpp_lib_void_t
  // std::void_t is variadic. Forward the whole pack so void_t<A, B> (and
  // void_t<>) behave the same here as with the fallback implementation below;
  // a single-parameter alias would reject those uses on C++17 only.
  template <class... Ts>
  using void_t = std::void_t<Ts...>;
  #else
  // Implementation taken from http://en.cppreference.com/w/cpp/types/void_t
  // (it takes CWG1558 into account and also works for older compilers)
  template <typename... Ts>
  struct make_void {
    using type = void;
  };
  template <typename... Ts>
  using void_t = typename make_void<Ts...>::type;
  #endif
  109. #if defined(USE_ROCM)
  110. // rocm doesn't like the C10_HOST_DEVICE
  111. #define CUDA_HOST_DEVICE
  112. #else
  113. #define CUDA_HOST_DEVICE C10_HOST_DEVICE
  114. #endif
  115. #if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__)
  116. template <class F, class Tuple>
  117. CUDA_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) {
  118. return std::apply(std::forward<F>(f), std::forward<Tuple>(t));
  119. }
  120. #else
  121. // Implementation from http://en.cppreference.com/w/cpp/utility/apply (but
  122. // modified)
  123. // TODO This is an incomplete implementation of std::apply, not working for
  124. // member functions.
  125. namespace detail {
  126. template <class F, class Tuple, std::size_t... INDEX>
  127. #if defined(_MSC_VER)
  128. // MSVC has a problem with the decltype() return type, but it also doesn't need
  129. // it
  130. C10_HOST_DEVICE constexpr auto apply_impl(
  131. F&& f,
  132. Tuple&& t,
  133. std::index_sequence<INDEX...>)
  134. #else
  135. // GCC/Clang need the decltype() return type
  136. CUDA_HOST_DEVICE constexpr decltype(auto) apply_impl(
  137. F&& f,
  138. Tuple&& t,
  139. std::index_sequence<INDEX...>)
  140. #endif
  141. {
  142. return std::forward<F>(f)(std::get<INDEX>(std::forward<Tuple>(t))...);
  143. }
  144. } // namespace detail
  145. template <class F, class Tuple>
  146. CUDA_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) {
  147. return detail::apply_impl(
  148. std::forward<F>(f),
  149. std::forward<Tuple>(t),
  150. std::make_index_sequence<
  151. std::tuple_size<std::remove_reference_t<Tuple>>::value>{});
  152. }
  153. #endif
  154. #undef CUDA_HOST_DEVICE
  155. template <typename Functor, typename... Args>
  156. typename std::enable_if<
  157. std::is_member_pointer<typename std::decay<Functor>::type>::value,
  158. typename c10::invoke_result_t<Functor, Args...>>::type
  159. invoke(Functor&& f, Args&&... args) {
  160. return std::mem_fn(std::forward<Functor>(f))(std::forward<Args>(args)...);
  161. }
  162. template <typename Functor, typename... Args>
  163. typename std::enable_if<
  164. !std::is_member_pointer<typename std::decay<Functor>::type>::value,
  165. typename c10::invoke_result_t<Functor, Args...>>::type
  166. invoke(Functor&& f, Args&&... args) {
  167. return std::forward<Functor>(f)(std::forward<Args>(args)...);
  168. }
  169. namespace detail {
  // Identity function object handed to if_constexpr callbacks that accept an
  // argument. Because the callback cannot know it is only the identity,
  // expressions routed through it (e.g. `_(t).value`) have their type checking
  // delayed until the callback is actually instantiated.
  struct _identity final {
    template <class T>
    using type_identity = T;

    // Returns `arg` unchanged, preserving its value category.
    // const-qualified so callbacks may take the identity as `const auto&`
    // (or call it through any const reference). constexpr so it is usable in
    // constant expressions. Both changes are backward compatible.
    template <class T>
    constexpr decltype(auto) operator()(T&& arg) const {
      return std::forward<T>(arg);
    }
  };
  178. template <class Func, class Enable = void>
  179. struct function_takes_identity_argument : std::false_type {};
  180. #if defined(_MSC_VER)
  181. // For some weird reason, MSVC shows a compiler error when using guts::void_t
  182. // instead of std::void_t. But we're only building on MSVC versions that have
  183. // std::void_t, so let's just use that one.
  184. template <class Func>
  185. struct function_takes_identity_argument<
  186. Func,
  187. std::void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {
  188. };
  189. #else
  190. template <class Func>
  191. struct function_takes_identity_argument<
  192. Func,
  193. void_t<decltype(std::declval<Func>()(_identity()))>> : std::true_type {};
  194. #endif
  195. template <bool Condition>
  196. struct _if_constexpr;
  197. template <>
  198. struct _if_constexpr<true> final {
  199. template <
  200. class ThenCallback,
  201. class ElseCallback,
  202. std::enable_if_t<
  203. function_takes_identity_argument<ThenCallback>::value,
  204. void*> = nullptr>
  205. static decltype(auto) call(
  206. ThenCallback&& thenCallback,
  207. ElseCallback&& /* elseCallback */) {
  208. // The _identity instance passed in can be used to delay evaluation of an
  209. // expression, because the compiler can't know that it's just the identity
  210. // we're passing in.
  211. return thenCallback(_identity());
  212. }
  213. template <
  214. class ThenCallback,
  215. class ElseCallback,
  216. std::enable_if_t<
  217. !function_takes_identity_argument<ThenCallback>::value,
  218. void*> = nullptr>
  219. static decltype(auto) call(
  220. ThenCallback&& thenCallback,
  221. ElseCallback&& /* elseCallback */) {
  222. return thenCallback();
  223. }
  224. };
  225. template <>
  226. struct _if_constexpr<false> final {
  227. template <
  228. class ThenCallback,
  229. class ElseCallback,
  230. std::enable_if_t<
  231. function_takes_identity_argument<ElseCallback>::value,
  232. void*> = nullptr>
  233. static decltype(auto) call(
  234. ThenCallback&& /* thenCallback */,
  235. ElseCallback&& elseCallback) {
  236. // The _identity instance passed in can be used to delay evaluation of an
  237. // expression, because the compiler can't know that it's just the identity
  238. // we're passing in.
  239. return elseCallback(_identity());
  240. }
  241. template <
  242. class ThenCallback,
  243. class ElseCallback,
  244. std::enable_if_t<
  245. !function_takes_identity_argument<ElseCallback>::value,
  246. void*> = nullptr>
  247. static decltype(auto) call(
  248. ThenCallback&& /* thenCallback */,
  249. ElseCallback&& elseCallback) {
  250. return elseCallback();
  251. }
  252. };
  253. } // namespace detail
  254. /*
  255. * Get something like C++17 if constexpr in C++14.
  256. *
  257. * Example 1: simple constexpr if/then/else
  258. * template<int arg> int increment_absolute_value() {
  259. * int result = arg;
  260. * if_constexpr<(arg > 0)>(
  261. * [&] { ++result; } // then-case
  262. * [&] { --result; } // else-case
  263. * );
  264. * return result;
  265. * }
  266. *
  267. * Example 2: without else case (i.e. conditionally prune code from assembly)
  268. * template<int arg> int decrement_if_positive() {
  269. * int result = arg;
  270. * if_constexpr<(arg > 0)>(
  271. * // This decrement operation is only present in the assembly for
  272. * // template instances with arg > 0.
  273. * [&] { --result; }
  274. * );
  275. * return result;
  276. * }
  277. *
  278. * Example 3: branch based on type (i.e. replacement for SFINAE)
  279. * struct MyClass1 {int value;};
  280. * struct MyClass2 {int val};
  281. * template <class T>
  282. * int func(T t) {
  283. * return if_constexpr<std::is_same<T, MyClass1>::value>(
  284. * [&](auto _) { return _(t).value; }, // this code is invalid for T ==
  285. * MyClass2, so a regular non-constexpr if statement wouldn't compile
  286. * [&](auto _) { return _(t).val; } // this code is invalid for T ==
  287. * MyClass1
  288. * );
  289. * }
  290. *
  291. * Note: The _ argument passed in Example 3 is the identity function, i.e. it
  292. * does nothing. It is used to force the compiler to delay type checking,
  293. * because the compiler doesn't know what kind of _ is passed in. Without it,
  294. * the compiler would fail when you try to access t.value but the member doesn't
  295. * exist.
  296. *
  297. * Note: In Example 3, both branches return int, so func() returns int. This is
  298. * not necessary. If func() had a return type of "auto", then both branches
  299. * could return different types, say func<MyClass1>() could return int and
  300. * func<MyClass2>() could return string.
  301. *
  302. * Note: if_constexpr<cond, t, f> is *eager* w.r.t. template expansion - meaning
  303. * this polyfill does not behave like a true "if statement at compilation time".
  304. * The `_` trick above only defers typechecking, which happens after
  305. * templates have been expanded. (Of course this is all that's necessary for
  306. * many use cases).
  307. */
  308. template <bool Condition, class ThenCallback, class ElseCallback>
  309. decltype(auto) if_constexpr(
  310. ThenCallback&& thenCallback,
  311. ElseCallback&& elseCallback) {
  312. #if defined(__cpp_if_constexpr)
  313. // If we have C++17, just use it's "if constexpr" feature instead of wrapping
  314. // it. This will give us better error messages.
  315. if constexpr (Condition) {
  316. if constexpr (detail::function_takes_identity_argument<
  317. ThenCallback>::value) {
  318. // Note that we use static_cast<T&&>(t) instead of std::forward (or
  319. // ::std::forward) because using the latter produces some compilation
  320. // errors about ambiguous `std` on MSVC when using C++17. This static_cast
  321. // is just what std::forward is doing under the hood, and is equivalent.
  322. return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
  323. } else {
  324. return static_cast<ThenCallback&&>(thenCallback)();
  325. }
  326. } else {
  327. if constexpr (detail::function_takes_identity_argument<
  328. ElseCallback>::value) {
  329. return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
  330. } else {
  331. return static_cast<ElseCallback&&>(elseCallback)();
  332. }
  333. }
  334. #else
  335. // C++14 implementation of if constexpr
  336. return detail::_if_constexpr<Condition>::call(
  337. static_cast<ThenCallback&&>(thenCallback),
  338. static_cast<ElseCallback&&>(elseCallback));
  339. #endif
  340. }
  341. template <bool Condition, class ThenCallback>
  342. decltype(auto) if_constexpr(ThenCallback&& thenCallback) {
  343. #if defined(__cpp_if_constexpr)
  344. // If we have C++17, just use it's "if constexpr" feature instead of wrapping
  345. // it. This will give us better error messages.
  346. if constexpr (Condition) {
  347. if constexpr (detail::function_takes_identity_argument<
  348. ThenCallback>::value) {
  349. // Note that we use static_cast<T&&>(t) instead of std::forward (or
  350. // ::std::forward) because using the latter produces some compilation
  351. // errors about ambiguous `std` on MSVC when using C++17. This static_cast
  352. // is just what std::forward is doing under the hood, and is equivalent.
  353. return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
  354. } else {
  355. return static_cast<ThenCallback&&>(thenCallback)();
  356. }
  357. }
  358. #else
  359. // C++14 implementation of if constexpr
  360. return if_constexpr<Condition>(
  361. static_cast<ThenCallback&&>(thenCallback), [](auto) {});
  362. #endif
  363. }
  // Empty tag type whose only purpose is to give std::to_string one extra
  // overload, so the name std::to_string is guaranteed to be declared for the
  // SFINAE probe in detail::to_string_. This was added because GCC 4.8 doesn't
  // define std::to_string, even though that's in C++11.
  namespace detail {
  class DummyClassForToString final {
  };
  } // namespace detail
  369. } // namespace guts
  370. } // namespace c10
  371. namespace std {
  372. // We use SFINAE to detect if std::to_string exists for a type, but that only
  373. // works if the function name is defined. So let's define a std::to_string for a
  374. // dummy type. If you're getting an error here saying that this overload doesn't
  375. // match your std::to_string() call, then you're calling std::to_string() but
  376. // should be calling c10::guts::to_string().
  377. inline std::string to_string(c10::guts::detail::DummyClassForToString) {
  378. return "";
  379. }
  380. } // namespace std
  381. namespace c10 {
  382. namespace guts {
  383. namespace detail {
  384. template <class T, class Enable = void>
  385. struct to_string_ final {
  386. static std::string call(T value) {
  387. std::ostringstream str;
  388. str << value;
  389. return str.str();
  390. }
  391. };
  392. // If a std::to_string exists, use that instead
  393. template <class T>
  394. struct to_string_<T, void_t<decltype(std::to_string(std::declval<T>()))>>
  395. final {
  396. static std::string call(T value) {
  397. return std::to_string(value);
  398. }
  399. };
  400. } // namespace detail
  401. template <class T>
  402. inline std::string to_string(T value) {
  403. return detail::to_string_<T>::call(value);
  404. }
  /// Returns the smaller of `a` and `b`, compared with operator< only. When
  /// neither is less than the other, `a` is returned.
  template <class T>
  constexpr const T& min(const T& a, const T& b) {
    return !(b < a) ? a : b;
  }
  /// Returns the larger of `a` and `b`, compared with operator< only. When
  /// neither is less than the other, `a` is returned.
  template <class T>
  constexpr const T& max(const T& a, const T& b) {
    return !(a < b) ? a : b;
  }
  413. } // namespace guts
  414. } // namespace c10
  415. #endif // C10_UTIL_CPP17_H_