Skip to content

Commit dc09211

Browse files
committed
[libc++] Vectorize trivially equality comparable types
1 parent f5960c1 commit dc09211

File tree

3 files changed

+86
-11
lines changed

3 files changed

+86
-11
lines changed

libcxx/include/__algorithm/mismatch.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <__algorithm/unwrap_iter.h>
1717
#include <__config>
1818
#include <__functional/identity.h>
19+
#include <__type_traits/copy_cv.h>
1920
#include <__type_traits/desugars_to.h>
2021
#include <__type_traits/invoke.h>
2122
#include <__type_traits/is_constant_evaluated.h>
@@ -119,6 +120,31 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
119120
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
120121
}
121122

123+
// __mismatch overload for contiguous ranges of non-integral element types that can be compared as raw integers.
//
// Participates in overload resolution only when all of the following hold:
//   - _Tp is not an integral type,
//   - _Pred desugars to an equality comparison of two _Tp (per __desugars_to_v<__equal_tag, ...>),
//   - both projections are identity projections,
//   - _Tp satisfies __can_map_to_integer_v, and
//   - _Tp is trivially equality comparable with itself.
//
// During constant evaluation the plain __mismatch_loop is used. Otherwise the three pointers are reinterpreted as
// pointers to the same-sized unsigned integer type (cv-qualifiers of _Tp preserved via __copy_cv_t), the call is
// forwarded to std::__mismatch on those integer pointers, and the resulting pair is cast back to _Tp*.
template <class _Tp,
124+
class _Pred,
125+
class _Proj1,
126+
class _Proj2,
127+
__enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
128+
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
129+
__can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
130+
int> = 0>
131+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
132+
__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
133+
// reinterpret_cast is not usable in a constant expression, so fall back to the element-wise loop there.
if (__libcpp_is_constant_evaluated()) {
134+
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
135+
} else {
136+
// Same-sized unsigned integer type, carrying over const/volatile from _Tp.
using __integer_t = __copy_cv_t<_Tp, __get_as_integer_type<_Tp>>;
137+
// This is valid because we disable TBAA when loading vectors. Alignment requirements still have to be fulfilled.
138+
// NOTE(review): presumably this resolves to the integral overload of __mismatch declared above - confirm.
auto __ret = std::__mismatch(
139+
reinterpret_cast<__integer_t*>(__first1),
140+
reinterpret_cast<__integer_t*>(__last1),
141+
reinterpret_cast<__integer_t*>(__first2),
142+
__pred,
143+
__proj1,
144+
__proj2);
145+
// Convert the integer-typed result positions back to pointers into the original ranges.
return {reinterpret_cast<_Tp*>(__ret.first), reinterpret_cast<_Tp*>(__ret.second)};
146+
}
147+
}
122148
#endif // _LIBCPP_VECTORIZE_ALGORITHMS
123149

124150
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>

libcxx/include/__algorithm/simd_utils.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,27 @@ _LIBCPP_PUSH_MACROS
4343

4444
_LIBCPP_BEGIN_NAMESPACE_STD
4545

46+
// True when _Tp is 1, 2, 4 or 8 bytes large and its alignment is equal to its size.
// NOTE(review): the alignment check presumably ensures a _Tp* is suitably aligned to be accessed as the same-sized
// unsigned integer type - confirm.
template <class _Tp>
47+
inline constexpr bool __can_map_to_integer_v =
48+
sizeof(_Tp) == alignof(_Tp) && (sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8);
49+
50+
// Returns a value-initialized unsigned integer whose size equals sizeof(_Tp): uint8_t, uint16_t, uint32_t or
// uint64_t for sizes 1, 2, 4 and 8 respectively. Any other size is a compile-time error.
//
// The assertion in the final branch is written so that it depends on _Tp. A plain `static_assert(false, ...)` is
// rejected at template definition time by compilers that do not implement CWG2518 / P2593R1, even when the branch
// is discarded for every instantiation.
template <class _Tp>
51+
_LIBCPP_HIDE_FROM_ABI auto __get_as_integer_type_impl() {
52+
if constexpr (sizeof(_Tp) == 1)
53+
return uint8_t{};
54+
else if constexpr (sizeof(_Tp) == 2)
55+
return uint16_t{};
56+
else if constexpr (sizeof(_Tp) == 4)
57+
return uint32_t{};
58+
else if constexpr (sizeof(_Tp) == 8)
59+
return uint64_t{};
60+
else
61+
static_assert(sizeof(_Tp) == 0, "Unexpected size type");
62+
}
63+
64+
// The type returned by __get_as_integer_type_impl<_Tp>(), i.e. the unsigned integer type selected from sizeof(_Tp).
// The call appears only inside decltype, so it is never evaluated.
template <class _Tp>
65+
using __get_as_integer_type = decltype(std::__get_as_integer_type_impl<_Tp>());
66+
4667
// This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
4768
// in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
4869
// as far as benchmarks are concerned.
@@ -83,7 +104,8 @@ using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_t
83104
// Builds a _VecT from __simd_vector_size_v<_VecT> consecutive elements starting at __ptr.
// The elements are read by expanding an index pack inside an immediately-invoked lambda. In the updated version the
// reads go through a local pointer declared with the may_alias attribute.
// NOTE(review): presumably this is what exempts the loads from type-based alias analysis - confirm.
template <class _VecT, class _Tp>
84105
_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
85106
return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
86-
return _VecT{__ptr[_Indices]...};
107+
[[__gnu__::__may_alias__]] const _Tp* __aliasing_ptr = __ptr;
108+
return _VecT{__aliasing_ptr[_Indices]...};
87109
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
88110
}
89111

libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
6666
#endif
6767
}
6868

69-
struct NonTrivial {
69+
// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
// Two objects holding different ints that are congruent modulo 4 compare equal, so operator== does not agree with a
// comparison of the stored values. The user-provided move constructor also resets the moved-from object to 0.
70+
struct NonTrivialMod4Comp {
7071
int i_;
7172

72-
TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
73-
TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
73+
TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
74+
TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
7475

75-
TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
76+
TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
77+
return lhs.i_ % 4 == rhs.i_ % 4;
78+
}
79+
};
80+
81+
#if TEST_STD_VER >= 20
82+
// Class type wrapping a single int whose operator== is defaulted. It is used by the tests to exercise forwarding to
// the vectorized version with a class type. Guarded by TEST_STD_VER >= 20 because it relies on a defaulted operator==.
struct TriviallyEqualityComparable {
83+
int i_;
84+
85+
TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
86+
87+
TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
7688
};
89+
#endif // TEST_STD_VER >= 20
7790

7891
struct ModTwoComp {
7992
TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
136149
types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
137150

138151
{ // use a non-integer type to also test the general case - all elements match
139-
std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
140-
std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
141-
check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
152+
std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
153+
std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
154+
check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
142155
}
143156

144157
{ // use a non-integer type to also test the general case - not all elements match
145-
std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
146-
std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
147-
check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
158+
std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
159+
std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
160+
check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
161+
}
162+
163+
#if TEST_STD_VER >= 20
164+
{ // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
165+
std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
166+
std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
167+
check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
168+
}
169+
170+
{ // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
171+
std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
172+
std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
173+
check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
148174
}
175+
#endif // TEST_STD_VER >= 20
149176

150177
return true;
151178
}

0 commit comments

Comments
 (0)