123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500 |
- #pragma once
- // DO NOT DEFINE STATIC DATA IN THIS HEADER!
- // See Note [Do not compile initializers with AVX]
- #include <ATen/cpu/vec/vec.h>
- namespace at { namespace vec {
- // BFloat16 specification
- template <typename scalar_t> struct VecScalarType { using type = scalar_t; };
- template <> struct VecScalarType<BFloat16> { using type = float; };
- // This is different from at::acc_type since we only need to specialize BFloat16
- template <typename scalar_t>
- using vec_scalar_t = typename VecScalarType<scalar_t>::type;
// Note that Vectorized<scalar_t> is already specialized for BFloat16,
// so code like the following works as-is:
//   using Vec = Vectorized<BFloat16>;
//   Vec one = Vec(BFloat16(1));
//   vec::map([one](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
//
// Then why do we still need to specialize the "functional" helpers?
// If we specialize only at the Vectorized<> level, the example above needs 3 pairs
// of bf16->fp32/fp32->bf16 conversions, one each for ".exp()", "+" and "/".
// If we specialize at the vec::map<>() level, we need only 1 pair of conversions:
// bf16->fp32 for the input vector and fp32->bf16 for the output vector.
//
// The BFloat16 functions below therefore convert only the input and output
// vectors (the reduce functions convert only the final scalar back to bf16).
// Compared to a Vectorized<>-level specialization this gives:
// 1. better performance, since fewer data type conversions are performed;
// 2. less rounding error, since intermediate results are kept in fp32;
// 3. accumulation in fp32.
//
// If you plan to extend this file, please make sure to add unit tests to
// aten/src/ATen/test/vec_test_all_types.cpp
//
- template <typename scalar_t = BFloat16, typename Op>
- inline BFloat16 reduce_all(const Op& vec_fun, const BFloat16* data, int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- if (size < bVec::size()) {
- bVec data_bvec = bVec::loadu(data, size);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size > fVec::size()) {
- data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size());
- return vec_reduce_all<float>(vec_fun, data_fvec0, fVec::size());
- } else {
- return vec_reduce_all<float>(vec_fun, data_fvec0, size);
- }
- }
- int64_t d = bVec::size();
- bVec acc_bvec = bVec::loadu(data);
- fVec acc_fvec0, acc_fvec1;
- std::tie(acc_fvec0, acc_fvec1) = convert_bfloat16_float(acc_bvec);
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = vec_fun(acc_fvec1, data_fvec1);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size - d > fVec::size()) {
- acc_fvec0 = vec_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
- } else {
- acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d);
- }
- }
- acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1);
- return vec_reduce_all<float>(vec_fun, acc_fvec0);
- }
- template <typename scalar_t = BFloat16, typename Op1, typename Op2>
- inline std::pair<BFloat16, BFloat16> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
- const BFloat16* data, int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- if (size < bVec::size()) {
- bVec data_bvec = bVec::loadu(data, size);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size > fVec::size()) {
- fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size());
- fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size());
- return std::pair<BFloat16, BFloat16>(
- vec_reduce_all<float>(vec_fun1, acc1_fvec, fVec::size()),
- vec_reduce_all<float>(vec_fun2, acc2_fvec, fVec::size()));
- } else {
- return std::pair<BFloat16, BFloat16>(
- vec_reduce_all<float>(vec_fun1, data_fvec0, size),
- vec_reduce_all<float>(vec_fun2, data_fvec0, size));
- }
- }
- int64_t d = bVec::size();
- bVec acc_bvec = bVec::loadu(data);
- fVec acc1_fvec0, acc1_fvec1;
- std::tie(acc1_fvec0, acc1_fvec1) = convert_bfloat16_float(acc_bvec);
- fVec acc2_fvec0, acc2_fvec1;
- std::tie(acc2_fvec0, acc2_fvec1) = convert_bfloat16_float(acc_bvec);
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
- acc1_fvec1 = vec_fun1(acc1_fvec1, data_fvec1);
- acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
- acc2_fvec1 = vec_fun2(acc2_fvec1, data_fvec1);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size - d > fVec::size()) {
- acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0);
- acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size());
- acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0);
- acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size());
- } else {
- acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d);
- acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d);
- }
- }
- acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1);
- acc2_fvec0 = vec_fun2(acc2_fvec0, acc2_fvec1);
- return std::pair<BFloat16, BFloat16>(
- vec_reduce_all<float>(vec_fun1, acc1_fvec0),
- vec_reduce_all<float>(vec_fun2, acc2_fvec0));
- }
- template <typename scalar_t = BFloat16, typename MapOp, typename ReduceOp>
- inline BFloat16 map_reduce_all(
- const MapOp& map_fun,
- const ReduceOp& red_fun,
- const BFloat16* data,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- if (size < bVec::size()) {
- bVec data_bvec = bVec::loadu(data, size);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0);
- data_fvec1 = map_fun(data_fvec1);
- data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
- return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0);
- return vec_reduce_all<float>(red_fun, data_fvec0, size);
- }
- }
- int64_t d = bVec::size();
- bVec acc_bvec = bVec::loadu(data);
- fVec acc_fvec0, acc_fvec1;
- std::tie(acc_fvec0, acc_fvec1) = convert_bfloat16_float(acc_bvec);
- acc_fvec0 = map_fun(acc_fvec0);
- acc_fvec1 = map_fun(acc_fvec1);
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- data_fvec0 = map_fun(data_fvec0);
- data_fvec1 = map_fun(data_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- if (size - d > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0);
- data_fvec1 = map_fun(data_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0);
- acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
- }
- }
- acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
- return vec_reduce_all<float>(red_fun, acc_fvec0);
- }
- template <typename scalar_t = BFloat16, typename MapOp, typename ReduceOp>
- inline BFloat16 map2_reduce_all(
- const MapOp& map_fun,
- const ReduceOp& red_fun,
- const BFloat16* data,
- const BFloat16* data2,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- if (size < bVec::size()) {
- bVec data_bvec = bVec::loadu(data, size);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2, size);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- if (size > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1);
- data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
- return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0);
- return vec_reduce_all<float>(red_fun, data_fvec0, size);
- }
- }
- int64_t d = bVec::size();
- bVec acc_bvec = bVec::loadu(data);
- fVec acc_fvec0, acc_fvec1;
- std::tie(acc_fvec0, acc_fvec1) = convert_bfloat16_float(acc_bvec);
- bVec acc2_bvec = bVec::loadu(data2);
- fVec acc2_fvec0, acc2_fvec1;
- std::tie(acc2_fvec0, acc2_fvec1) = convert_bfloat16_float(acc2_bvec);
- acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0);
- acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1);
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2 + d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- data_fvec0 = map_fun(data_fvec0, data2_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2 + d, size - d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- if (size - d > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0);
- acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
- }
- }
- acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
- return vec_reduce_all<float>(red_fun, acc_fvec0);
- }
- template <typename scalar_t = BFloat16, typename MapOp, typename ReduceOp>
- inline BFloat16 map3_reduce_all(
- const MapOp& map_fun,
- const ReduceOp& red_fun,
- const BFloat16* data,
- const BFloat16* data2,
- const BFloat16* data3,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- if (size < bVec::size()) {
- bVec data_bvec = bVec::loadu(data, size);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2, size);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(data3, size);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- if (size > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
- data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size());
- return vec_reduce_all<float>(red_fun, data_fvec0, fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
- return vec_reduce_all<float>(red_fun, data_fvec0, size);
- }
- }
- int64_t d = bVec::size();
- bVec acc_bvec = bVec::loadu(data);
- fVec acc_fvec0, acc_fvec1;
- std::tie(acc_fvec0, acc_fvec1) = convert_bfloat16_float(acc_bvec);
- bVec acc2_bvec = bVec::loadu(data2);
- fVec acc2_fvec0, acc2_fvec1;
- std::tie(acc2_fvec0, acc2_fvec1) = convert_bfloat16_float(acc2_bvec);
- bVec acc3_bvec = bVec::loadu(data3);
- fVec acc3_fvec0, acc3_fvec1;
- std::tie(acc3_fvec0, acc3_fvec1) = convert_bfloat16_float(acc3_bvec);
- acc_fvec0 = map_fun(acc_fvec0, acc2_fvec0, acc3_fvec0);
- acc_fvec1 = map_fun(acc_fvec1, acc2_fvec1, acc3_fvec1);
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2 + d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(data3 + d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = red_fun(acc_fvec1, data_fvec1);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(data2 + d, size - d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(data3 + d, size - d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- if (size - d > fVec::size()) {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
- data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1);
- acc_fvec0 = red_fun(acc_fvec0, data_fvec0);
- acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size());
- } else {
- data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0);
- acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d);
- }
- }
- acc_fvec0 = red_fun(acc_fvec0, acc_fvec1);
- return vec_reduce_all<float>(red_fun, acc_fvec0);
- }
- template <typename scalar_t = BFloat16, typename Op>
- inline void map(
- const Op& vec_fun,
- BFloat16* output_data,
- const BFloat16* input_data,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- int64_t d = 0;
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(input_data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- fVec output_fvec0 = vec_fun(data_fvec0);
- fVec output_fvec1 = vec_fun(data_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(input_data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- fVec output_fvec0 = vec_fun(data_fvec0);
- fVec output_fvec1 = vec_fun(data_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d, size - d);
- }
- }
- template <typename scalar_t = BFloat16, typename Op>
- inline void map2(
- const Op& vec_fun,
- BFloat16* output_data,
- const BFloat16* input_data,
- const BFloat16* input_data2,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- int64_t d = 0;
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data_bvec = bVec::loadu(input_data + d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
- fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d);
- }
- if (size - d > 0) {
- bVec data_bvec = bVec::loadu(input_data + d, size - d);
- fVec data_fvec0, data_fvec1;
- std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- fVec output_fvec0 = vec_fun(data_fvec0, data2_fvec0);
- fVec output_fvec1 = vec_fun(data_fvec1, data2_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d, size - d);
- }
- }
- template <typename scalar_t = BFloat16, typename Op>
- inline void map3(
- const Op& vec_fun,
- BFloat16* output_data,
- const BFloat16* input_data1,
- const BFloat16* input_data2,
- const BFloat16* input_data3,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- int64_t d = 0;
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data1_bvec = bVec::loadu(input_data1 + d);
- fVec data1_fvec0, data1_fvec1;
- std::tie(data1_fvec0, data1_fvec1) = convert_bfloat16_float(data1_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(input_data3 + d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
- fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d);
- }
- if (size - d > 0) {
- bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
- fVec data1_fvec0, data1_fvec1;
- std::tie(data1_fvec0, data1_fvec1) = convert_bfloat16_float(data1_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0);
- fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d, size - d);
- }
- }
- template <typename scalar_t = BFloat16, typename Op>
- inline void map4(
- const Op& vec_fun,
- BFloat16* output_data,
- const BFloat16* input_data1,
- const BFloat16* input_data2,
- const BFloat16* input_data3,
- const BFloat16* input_data4,
- int64_t size) {
- using bVec = vec::Vectorized<BFloat16>;
- using fVec = vec::Vectorized<float>;
- int64_t d = 0;
- for (; d < size - (size % bVec::size()); d += bVec::size()) {
- bVec data1_bvec = bVec::loadu(input_data1 + d);
- fVec data1_fvec0, data1_fvec1;
- std::tie(data1_fvec0, data1_fvec1) = convert_bfloat16_float(data1_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(input_data3 + d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- bVec data4_bvec = bVec::loadu(input_data4 + d);
- fVec data4_fvec0, data4_fvec1;
- std::tie(data4_fvec0, data4_fvec1) = convert_bfloat16_float(data4_bvec);
- fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
- fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d);
- }
- if (size - d > 0) {
- bVec data1_bvec = bVec::loadu(input_data1 + d, size - d);
- fVec data1_fvec0, data1_fvec1;
- std::tie(data1_fvec0, data1_fvec1) = convert_bfloat16_float(data1_bvec);
- bVec data2_bvec = bVec::loadu(input_data2 + d, size - d);
- fVec data2_fvec0, data2_fvec1;
- std::tie(data2_fvec0, data2_fvec1) = convert_bfloat16_float(data2_bvec);
- bVec data3_bvec = bVec::loadu(input_data3 + d, size - d);
- fVec data3_fvec0, data3_fvec1;
- std::tie(data3_fvec0, data3_fvec1) = convert_bfloat16_float(data3_bvec);
- bVec data4_bvec = bVec::loadu(input_data4 + d, size - d);
- fVec data4_fvec0, data4_fvec1;
- std::tie(data4_fvec0, data4_fvec1) = convert_bfloat16_float(data4_bvec);
- fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0);
- fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1);
- bVec output_bvec = convert_float_bfloat16(output_fvec0, output_fvec1);
- output_bvec.store(output_data + d, size - d);
- }
- }
- }} // namespace at::vec
|