Skip to content

Commit b40d7cd

Browse files
committed
[libc++] Explicitly convert to masks
1 parent 1c334de commit b40d7cd

File tree

2 files changed

+50
-31
lines changed

2 files changed

+50
-31
lines changed

libcxx/include/__algorithm/mismatch.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ __mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
7777
}
7878

7979
for (size_t __i = 0; __i != __unroll_count; ++__i) {
80-
if (auto __cmp_res = __lhs[__i] == __rhs[__i]; !std::__all_of(__cmp_res)) {
80+
if (auto __cmp_res = std::__as_mask(__lhs[__i] == __rhs[__i]); !std::__all_of(__cmp_res)) {
8181
auto __offset = __i * __vec_size + std::__find_first_not_set(__cmp_res);
8282
return {__first1 + __offset, __first2 + __offset};
8383
}
@@ -89,7 +89,7 @@ __mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
8989

9090
// check the remaining 0-3 vectors
9191
while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
92-
if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
92+
if (auto __cmp_res = std::__as_mask(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
9393
!std::__all_of(__cmp_res)) {
9494
auto __offset = std::__find_first_not_set(__cmp_res);
9595
return {__first1 + __offset, __first2 + __offset};
@@ -106,8 +106,8 @@ __mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
106106
if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
107107
__first1 = __last1 - __vec_size;
108108
__first2 = __last2 - __vec_size;
109-
auto __offset =
110-
std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
109+
auto __offset = std::__find_first_not_set(
110+
std::__as_mask(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2)));
111111
return {__first1 + __offset, __first2 + __offset};
112112
} // else loop over the elements individually
113113
}

libcxx/include/__algorithm/simd_utils.h

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -116,42 +116,61 @@ _LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vecto
116116
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
117117
}
118118

119-
template <class _Tp, size_t _Np>
120-
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<_Tp, _Np> __vec) noexcept {
121-
return __builtin_reduce_and(__builtin_convertvector(__vec, __simd_vector<bool, _Np>));
119+
// Reports whether every lane of a boolean mask vector is set.
//
// __vec: mask of _Np bool lanes, taken by value.
// Returns the AND-reduction over all lanes, computed by __builtin_reduce_and.
//
// Only __simd_vector<bool, _Np> is accepted here; the call sites visible in
// mismatch.h wrap their comparison results in std::__as_mask before calling this.
template <size_t _Np>
120+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI bool __all_of(__simd_vector<bool, _Np> __vec) noexcept {
121+
return __builtin_reduce_and(__vec);
122122
}
123123

124124
template <class _Tp, size_t _Np>
125-
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_set(__simd_vector<_Tp, _Np> __vec) noexcept {
126-
using __mask_vec = __simd_vector<bool, _Np>;
125+
// __as_mask: converts a non-bool SIMD vector into a bool mask vector with the same lane count.
//
// __vec: vector of _Np lanes of element type _Tp, taken by value.
// Returns a __simd_vector<bool, _Np> produced by __builtin_convertvector.
//
// The call sites visible in mismatch.h pass the result of a vector `==` comparison.
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI auto __as_mask(__simd_vector<_Tp, _Np> __vec) noexcept {
126+
// Passing something that is already a bool mask is rejected at compile time.
static_assert(!is_same<_Tp, bool>::value, "vector type should not be a bool!");
127+
return __builtin_convertvector(__vec, __simd_vector<bool, _Np>);
128+
}
127129

128-
// This has MSan disabled due to https://github.com/llvm/llvm-project/issues/85876
129-
auto __impl = [&]<class _MaskT>(_MaskT) _LIBCPP_NO_SANITIZE("memory") noexcept {
130-
# if defined(_LIBCPP_BIG_ENDIAN)
131-
return std::min<size_t>(
132-
_Np, std::__countl_zero(__builtin_bit_cast(_MaskT, __builtin_convertvector(__vec, __mask_vec))));
133-
# else
134-
return std::min<size_t>(
135-
_Np, std::__countr_zero(__builtin_bit_cast(_MaskT, __builtin_convertvector(__vec, __mask_vec))));
136-
# endif
137-
};
138-
139-
if constexpr (sizeof(__mask_vec) == sizeof(uint8_t)) {
140-
return __impl(uint8_t{});
141-
} else if constexpr (sizeof(__mask_vec) == sizeof(uint16_t)) {
142-
return __impl(uint16_t{});
143-
} else if constexpr (sizeof(__mask_vec) == sizeof(uint32_t)) {
144-
return __impl(uint32_t{});
145-
} else if constexpr (sizeof(__mask_vec) == sizeof(uint64_t)) {
146-
return __impl(uint64_t{});
130+
// Widens a 2- or 4-lane bool mask to an 8-lane bool mask.
//
// __vec: mask of _Np bool lanes; only _Np == 2 and _Np == 4 are handled, any
//        other lane count fails to compile.
// Returns a __simd_vector<bool, 8> whose first _Np lanes come from __vec; the
// remaining lanes are selected from a value-initialized _VecT{}.
//
// This uses __builtin_convertvector around the __builtin_shufflevector to work around #107981.
131+
template <size_t _Np>
132+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI auto __extend_vector(__simd_vector<bool, _Np> __vec) noexcept {
133+
using _VecT = __simd_vector<bool, _Np>;
134+
if constexpr (_Np == 4) {
135+
// Indices 0-3 select the lanes of __vec, indices 4-7 select the lanes of _VecT{}.
return __builtin_convertvector(
136+
__builtin_shufflevector(__vec, _VecT{}, 0, 1, 2, 3, 4, 5, 6, 7), __simd_vector<bool, 8>);
137+
} else if constexpr (_Np == 2) {
138+
// Widen 2 -> 4 lanes first, then recurse so the _Np == 4 branch finishes the job.
return std::__extend_vector(
139+
__builtin_convertvector(__builtin_shufflevector(__vec, _VecT{}, 0, 1, 2, 3), __simd_vector<bool, 4>));
147140
} else {
148-
static_assert(sizeof(__mask_vec) == 0, "unexpected required size for mask integer type");
141+
// The condition depends on _Np, so it only fires when this branch is instantiated.
static_assert(sizeof(_VecT) == 0, "Unexpected vector size");
142+
}
143+
}
144+
145+
// Reinterprets a bool mask vector as an unsigned integer via std::__bit_cast.
//
// __vec: mask of _Np bool lanes, taken by value.
// Returns uint8_t for _Np <= 8, uint16_t for 16, uint32_t for 32 and uint64_t
// for 64 lanes. Masks with fewer than 8 lanes are first widened to 8 lanes
// with std::__extend_vector. Any other lane count fails to compile.
//
// NOTE(review): unlike the other helpers in this hunk, this function is not
// declared noexcept - confirm whether that is intentional.
template <size_t _Np>
146+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI auto __to_int_mask(__simd_vector<bool, _Np> __vec) {
147+
if constexpr (_Np < 8) {
148+
// Widen to 8 lanes before casting to the smallest integer type used here.
return std::__bit_cast<uint8_t>(std::__extend_vector(__vec));
149+
} else if constexpr (_Np == 8) {
150+
return std::__bit_cast<uint8_t>(__vec);
151+
} else if constexpr (_Np == 16) {
152+
return std::__bit_cast<uint16_t>(__vec);
153+
} else if constexpr (_Np == 32) {
154+
return std::__bit_cast<uint32_t>(__vec);
155+
} else if constexpr (_Np == 64) {
156+
return std::__bit_cast<uint64_t>(__vec);
157+
} else {
158+
// The condition depends on _Np, so it only fires when this branch is instantiated.
static_assert(sizeof(__simd_vector<bool, _Np>) == 0, "Unexpected vector size");
149159
// Presumably only here so the rejected branch still has a return statement - TODO confirm.
return 0;
150160
}
151161
}
152162

153-
template <class _Tp, size_t _Np>
154-
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_not_set(__simd_vector<_Tp, _Np> __vec) noexcept {
163+
// Locates the first set lane of a bool mask vector.
//
// __vec: mask of _Np bool lanes, taken by value.
// Returns the zero-bit count of the integer form of the mask (see
// std::__to_int_mask), clamped to _Np. Presumably this is the index of the
// first set lane, and _Np when no lane is set - confirm against
// std::__countl_zero / std::__countr_zero.
template <size_t _Np>
164+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_set(__simd_vector<bool, _Np> __vec) noexcept {
165+
# if defined(_LIBCPP_BIG_ENDIAN)
166+
// Big-endian targets count zero bits from the most significant end.
return std::min<size_t>(_Np, std::__countl_zero(std::__to_int_mask(__vec)));
167+
# else
168+
// All other targets count zero bits from the least significant end.
return std::min<size_t>(_Np, std::__countr_zero(std::__to_int_mask(__vec)));
169+
# endif
170+
}
171+
172+
// Locates the first lane of a bool mask vector that is NOT set.
//
// __vec: mask of _Np bool lanes, taken by value.
// Returns whatever std::__find_first_set returns for the lane-wise complement
// of __vec.
template <size_t _Np>
173+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI size_t __find_first_not_set(__simd_vector<bool, _Np> __vec) noexcept {
155174
// Invert the mask so that unset lanes become the ones being searched for.
return std::__find_first_set(~__vec);
156175
}
157176

0 commit comments

Comments
 (0)