vec.h 958 B

123456789101112131415161718192021222324252627282930313233343536
  1. #pragma once
  2. #if defined(CPU_CAPABILITY_AVX512)
  3. #include <ATen/cpu/vec/vec512/vec512.h>
  4. #else
  5. #include <ATen/cpu/vec/vec256/vec256.h>
  6. #endif
  7. namespace at {
  8. namespace vec {
  9. // See Note [CPU_CAPABILITY namespace]
  10. inline namespace CPU_CAPABILITY {
  11. inline Vectorized<bool> convert_to_bool(Vectorized<int8_t> x) {
  12. __at_align__ bool buffer[x.size()];
  13. x.ne(Vectorized<int8_t>(0)).store(buffer);
  14. Vectorized<bool> ret;
  15. static_assert(x.size() == ret.size(), "");
  16. std::memcpy(ret, buffer, ret.size() * sizeof(bool));
  17. return ret;
  18. }
  19. template <>
  20. inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr) {
  21. // See NOTE [Loading boolean values]
  22. return convert_to_bool(Vectorized<int8_t>::loadu(ptr));
  23. }
  24. template <>
  25. inline Vectorized<bool> Vectorized<bool>::loadu(const void* ptr, int64_t count) {
  26. // See NOTE [Loading boolean values]
  27. return convert_to_bool(Vectorized<int8_t>::loadu(ptr, count));
  28. }
  29. }}} // namespace at::vec::CPU_CAPABILITY