Skip to content

Commit 6034ac5

Browse files
committed
[libc++] Vectorize trivially equality comparable types
1 parent f905935 commit 6034ac5

File tree

8 files changed

+286
-25
lines changed

8 files changed

+286
-25
lines changed

libcxx/include/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ set(files
456456
__ios/fpos.h
457457
__iterator/access.h
458458
__iterator/advance.h
459+
__iterator/aliasing_iterator.h
459460
__iterator/back_insert_iterator.h
460461
__iterator/bounded_iter.h
461462
__iterator/common_iterator.h

libcxx/include/__algorithm/mismatch.h

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <__algorithm/unwrap_iter.h>
1717
#include <__config>
1818
#include <__functional/identity.h>
19+
#include <__iterator/aliasing_iterator.h>
1920
#include <__type_traits/desugars_to.h>
2021
#include <__type_traits/invoke.h>
2122
#include <__type_traits/is_constant_evaluated.h>
@@ -55,18 +56,13 @@ __mismatch(_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Pred& __pred, _Pro
5556

5657
#if _LIBCPP_VECTORIZE_ALGORITHMS
5758

58-
template <class _Tp,
59-
class _Pred,
60-
class _Proj1,
61-
class _Proj2,
62-
__enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
63-
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
64-
int> = 0>
65-
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
66-
__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
59+
template <class _Iter>
60+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter, _Iter>
61+
__mismatch_vectorized(_Iter __first1, _Iter __last1, _Iter __first2) {
62+
using __value_type = __iter_value_type<_Iter>;
6763
constexpr size_t __unroll_count = 4;
68-
constexpr size_t __vec_size = __native_vector_size<_Tp>;
69-
using __vec = __simd_vector<_Tp, __vec_size>;
64+
constexpr size_t __vec_size = __native_vector_size<__value_type>;
65+
using __vec = __simd_vector<__value_type, __vec_size>;
7066

7167
if (!__libcpp_is_constant_evaluated()) {
7268
auto __orig_first1 = __first1;
@@ -116,9 +112,41 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
116112
} // else loop over the elements individually
117113
}
118114

119-
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
115+
__equal_to __pred;
116+
__identity __proj;
117+
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj, __proj);
118+
}
119+
120+
template <class _Tp,
121+
class _Pred,
122+
class _Proj1,
123+
class _Proj2,
124+
__enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
125+
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
126+
int> = 0>
127+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
128+
__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred&, _Proj1&, _Proj2&) {
129+
return std::__mismatch_vectorized(__first1, __last1, __first2);
120130
}
121131

132+
template <class _Tp,
133+
class _Pred,
134+
class _Proj1,
135+
class _Proj2,
136+
__enable_if_t<!is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
137+
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value &&
138+
__can_map_to_integer_v<_Tp> && __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
139+
int> = 0>
140+
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
141+
__mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __proj1, _Proj2& __proj2) {
142+
if (__libcpp_is_constant_evaluated()) {
143+
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
144+
} else {
145+
using _Iter = __aliasing_iterator<_Tp*, __get_as_integer_type_t<_Tp>>;
146+
auto __ret = std::__mismatch_vectorized(_Iter(__first1), _Iter(__last1), _Iter(__first2));
147+
return {__ret.first.__base(), __ret.second.__base()};
148+
}
149+
}
122150
#endif // _LIBCPP_VECTORIZE_ALGORITHMS
123151

124152
template <class _InputIterator1, class _InputIterator2, class _BinaryPredicate>

libcxx/include/__algorithm/simd_utils.h

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,34 @@ _LIBCPP_PUSH_MACROS
4343

4444
_LIBCPP_BEGIN_NAMESPACE_STD
4545

46+
template <class _Tp>
47+
inline constexpr bool __can_map_to_integer_v =
48+
sizeof(_Tp) == alignof(_Tp) && (sizeof(_Tp) == 1 || sizeof(_Tp) == 2 || sizeof(_Tp) == 4 || sizeof(_Tp) == 8);
49+
50+
template <size_t _TypeSize>
51+
struct __get_as_integer_type_impl;
52+
53+
template <>
54+
struct __get_as_integer_type_impl<1> {
55+
using type = uint8_t;
56+
};
57+
58+
template <>
59+
struct __get_as_integer_type_impl<2> {
60+
using type = uint16_t;
61+
};
62+
template <>
63+
struct __get_as_integer_type_impl<4> {
64+
using type = uint32_t;
65+
};
66+
template <>
67+
struct __get_as_integer_type_impl<8> {
68+
using type = uint64_t;
69+
};
70+
71+
template <class _Tp>
72+
using __get_as_integer_type_t = typename __get_as_integer_type_impl<sizeof(_Tp)>::type;
73+
4674
// This isn't specialized for 64 byte vectors on purpose. They have the potential to significantly reduce performance
4775
// in mixed simd/non-simd workloads and don't provide any performance improvement for currently vectorized algorithms
4876
// as far as benchmarks are concerned.
@@ -80,10 +108,10 @@ template <class _VecT>
80108
using __simd_vector_underlying_type_t = decltype(std::__simd_vector_underlying_type_impl(_VecT{}));
81109

82110
// This isn't inlined without always_inline when loading chars.
83-
template <class _VecT, class _Tp>
84-
_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(const _Tp* __ptr) noexcept {
111+
template <class _VecT, class _Iter>
112+
_LIBCPP_NODISCARD _LIBCPP_ALWAYS_INLINE _LIBCPP_HIDE_FROM_ABI _VecT __load_vector(_Iter __iter) noexcept {
85113
return [=]<size_t... _Indices>(index_sequence<_Indices...>) _LIBCPP_ALWAYS_INLINE noexcept {
86-
return _VecT{__ptr[_Indices]...};
114+
return _VecT{__iter[_Indices]...};
87115
}(make_index_sequence<__simd_vector_size_v<_VecT>>{});
88116
}
89117

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
10+
#define _LIBCPP___ITERATOR_ALIASING_ITERATOR_H
11+
12+
#include <__compare/strong_order.h>
13+
#include <__config>
14+
#include <__iterator/iterator_traits.h>
15+
#include <__memory/pointer_traits.h>
16+
#include <__type_traits/is_trivial.h>
17+
#include <cstddef>
18+
19+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
20+
# pragma GCC system_header
21+
#endif
22+
23+
// This iterator wrapper is used to type-pun an iterator to return a different type. This is done without UB by not
24+
// actually punning the type, but instead inspecting the object representation of the base type and copying that into
25+
// an instance of the alias type. For that reason the alias type has to be trivial. The alias is returned as a prvalue
26+
// when dereferencing the iterator, since it is temporary storage. This wrapper is used to vectorize some algorithms.
27+
28+
_LIBCPP_BEGIN_NAMESPACE_STD
29+
30+
template <class _BaseIter, class _Alias>
31+
struct __aliasing_iterator_wrapper {
32+
class __iterator {
33+
_BaseIter __base_ = nullptr;
34+
35+
using __iter_traits = iterator_traits<_BaseIter>;
36+
using __base_value_type = typename __iter_traits::value_type;
37+
38+
static_assert(__has_random_access_iterator_category<_BaseIter>::value,
39+
"The base iterator has to be a random access iterator!");
40+
41+
public:
42+
using iterator_category = random_access_iterator_tag;
43+
using value_type = _Alias;
44+
using difference_type = ptrdiff_t;
45+
46+
static_assert(is_trivial<value_type>::value);
47+
static_assert(sizeof(__base_value_type) == sizeof(value_type));
48+
49+
_LIBCPP_HIDE_FROM_ABI __iterator() = default;
50+
_LIBCPP_HIDE_FROM_ABI __iterator(_BaseIter __base) _NOEXCEPT : __base_(__base) {}
51+
52+
_LIBCPP_HIDE_FROM_ABI __iterator& operator++() _NOEXCEPT {
53+
++__base_;
54+
return *this;
55+
}
56+
57+
_LIBCPP_HIDE_FROM_ABI __iterator operator++(int) _NOEXCEPT {
58+
__iterator __tmp(*this);
59+
++__base_;
60+
return __tmp;
61+
}
62+
63+
_LIBCPP_HIDE_FROM_ABI __iterator& operator--() _NOEXCEPT {
64+
--__base_;
65+
return *this;
66+
}
67+
68+
_LIBCPP_HIDE_FROM_ABI __iterator operator--(int) _NOEXCEPT {
69+
__iterator __tmp(*this);
70+
--__base_;
71+
return __tmp;
72+
}
73+
74+
_LIBCPP_HIDE_FROM_ABI friend __iterator operator+(__iterator __iter, difference_type __n) _NOEXCEPT {
75+
return __iterator(__iter.__base_ + __n);
76+
}
77+
78+
_LIBCPP_HIDE_FROM_ABI friend __iterator operator+(difference_type __n, __iterator __iter) _NOEXCEPT {
79+
return __iterator(__n + __iter.__base_);
80+
}
81+
82+
_LIBCPP_HIDE_FROM_ABI __iterator& operator+=(difference_type __n) _NOEXCEPT {
83+
__base_ += __n;
84+
return *this;
85+
}
86+
87+
_LIBCPP_HIDE_FROM_ABI friend __iterator operator-(__iterator __iter, difference_type __n) _NOEXCEPT {
88+
return __iterator(__iter.__base_ - __n);
89+
}
90+
91+
_LIBCPP_HIDE_FROM_ABI friend difference_type operator-(__iterator __lhs, __iterator __rhs) _NOEXCEPT {
92+
return __lhs.__base_ - __rhs.__base_;
93+
}
94+
95+
_LIBCPP_HIDE_FROM_ABI __iterator& operator-=(difference_type __n) _NOEXCEPT {
96+
__base_ -= __n;
97+
return *this;
98+
}
99+
100+
_LIBCPP_HIDE_FROM_ABI _BaseIter __base() const _NOEXCEPT { return __base_; }
101+
102+
_LIBCPP_HIDE_FROM_ABI _Alias operator*() const _NOEXCEPT {
103+
_Alias __val;
104+
__builtin_memcpy(&__val, std::__to_address(__base_), sizeof(value_type));
105+
return __val;
106+
}
107+
108+
_LIBCPP_HIDE_FROM_ABI value_type operator[](difference_type __n) const _NOEXCEPT { return *(*this + __n); }
109+
110+
_LIBCPP_HIDE_FROM_ABI friend bool operator==(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
111+
return __lhs.__base_ == __rhs.__base_;
112+
}
113+
114+
_LIBCPP_HIDE_FROM_ABI friend bool operator!=(const __iterator& __lhs, const __iterator& __rhs) _NOEXCEPT {
115+
return __lhs.__base_ != __rhs.__base_;
116+
}
117+
118+
#if _LIBCPP_STD_VER >= 20
119+
_LIBCPP_HIDE_FROM_ABI friend auto operator<=>(__iterator, __iterator) noexcept = default;
120+
#endif
121+
};
122+
};
123+
124+
// This is required to avoid ADL instantiations on _BaseT
125+
template <class _BaseT, class _Alias>
126+
using __aliasing_iterator = __aliasing_iterator_wrapper<_BaseT, _Alias>::__iterator;
127+
128+
_LIBCPP_END_NAMESPACE_STD
129+
130+
#endif // _LIBCPP___ITERATOR_ALIASING_ITERATOR_H

libcxx/include/libcxx.imp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@
450450
{ include: [ "<__ios/fpos.h>", "private", "<ios>", "public" ] },
451451
{ include: [ "<__iterator/access.h>", "private", "<iterator>", "public" ] },
452452
{ include: [ "<__iterator/advance.h>", "private", "<iterator>", "public" ] },
453+
{ include: [ "<__iterator/aliasing_iterator.h>", "private", "<iterator>", "public" ] },
453454
{ include: [ "<__iterator/back_insert_iterator.h>", "private", "<iterator>", "public" ] },
454455
{ include: [ "<__iterator/bounded_iter.h>", "private", "<iterator>", "public" ] },
455456
{ include: [ "<__iterator/common_iterator.h>", "private", "<iterator>", "public" ] },

libcxx/include/module.modulemap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,7 @@ module std_private_iosfwd_streambuf_fwd [system] { header "__fwd/streambuf.h" }
14061406

14071407
module std_private_iterator_access [system] { header "__iterator/access.h" }
14081408
module std_private_iterator_advance [system] { header "__iterator/advance.h" }
1409+
module std_private_iterator_aliasing_iterator [system] { header "__iterator/aliasing_iterator.h" }
14091410
module std_private_iterator_back_insert_iterator [system] { header "__iterator/back_insert_iterator.h" }
14101411
module std_private_iterator_bounded_iter [system] { header "__iterator/bounded_iter.h" }
14111412
module std_private_iterator_common_iterator [system] { header "__iterator/common_iterator.h" }
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// ADDITIONAL_COMPILE_FLAGS: -Wprivate-header
10+
11+
#include <__iterator/aliasing_iterator.h>
12+
#include <cassert>
13+
14+
struct NonTrivial {
15+
int i_;
16+
17+
NonTrivial(int i) : i_(i) {}
18+
NonTrivial(const NonTrivial& other) : i_(other.i_) {}
19+
20+
NonTrivial& operator=(const NonTrivial& other) {
21+
i_ = other.i_;
22+
return *this;
23+
}
24+
25+
~NonTrivial() {}
26+
};
27+
28+
int main(int, char**) {
29+
{
30+
NonTrivial arr[] = {1, 2, 3, 4};
31+
std::__aliasing_iterator<NonTrivial*, int> iter(arr);
32+
33+
assert(*iter == 1);
34+
assert(iter[0] == 1);
35+
assert(iter[1] == 2);
36+
++iter;
37+
assert(*iter == 2);
38+
assert(iter[-1] == 1);
39+
assert(iter.__base() == arr + 1);
40+
assert(iter == iter);
41+
assert(iter != (iter + 1));
42+
}
43+
44+
return 0;
45+
}

libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,27 @@ TEST_CONSTEXPR_CXX20 void check(Container1 lhs, Container2 rhs, size_t offset) {
6666
#endif
6767
}
6868

69-
struct NonTrivial {
69+
// Compares modulo 4 to make sure we only forward to the vectorized version if we are trivially equality comparable
70+
struct NonTrivialMod4Comp {
7071
int i_;
7172

72-
TEST_CONSTEXPR_CXX20 NonTrivial(int i) : i_(i) {}
73-
TEST_CONSTEXPR_CXX20 NonTrivial(NonTrivial&& other) : i_(other.i_) { other.i_ = 0; }
73+
TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(int i) : i_(i) {}
74+
TEST_CONSTEXPR_CXX20 NonTrivialMod4Comp(NonTrivialMod4Comp&& other) : i_(other.i_) { other.i_ = 0; }
7475

75-
TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivial& lhs, const NonTrivial& rhs) { return lhs.i_ == rhs.i_; }
76+
TEST_CONSTEXPR_CXX20 friend bool operator==(const NonTrivialMod4Comp& lhs, const NonTrivialMod4Comp& rhs) {
77+
return lhs.i_ % 4 == rhs.i_ % 4;
78+
}
79+
};
80+
81+
#if TEST_STD_VER >= 20
82+
struct TriviallyEqualityComparable {
83+
int i_;
84+
85+
TEST_CONSTEXPR_CXX20 TriviallyEqualityComparable(int i) : i_(i) {}
86+
87+
TEST_CONSTEXPR_CXX20 friend bool operator==(TriviallyEqualityComparable, TriviallyEqualityComparable) = default;
7688
};
89+
#endif // TEST_STD_VER >= 20
7790

7891
struct ModTwoComp {
7992
TEST_CONSTEXPR_CXX20 bool operator()(int lhs, int rhs) { return lhs % 2 == rhs % 2; }
@@ -136,16 +149,30 @@ TEST_CONSTEXPR_CXX20 bool test() {
136149
types::for_each(types::cpp17_input_iterator_list<int*>(), Test());
137150

138151
{ // use a non-integer type to also test the general case - all elements match
139-
std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
140-
std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
141-
check<NonTrivial*>(std::move(lhs), std::move(rhs), 8);
152+
std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
153+
std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 1, 6, 7, 8};
154+
check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 8);
142155
}
143156

144157
{ // use a non-integer type to also test the general case - not all elements match
145-
std::array<NonTrivial, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
146-
std::array<NonTrivial, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
147-
check<NonTrivial*>(std::move(lhs), std::move(rhs), 4);
158+
std::array<NonTrivialMod4Comp, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
159+
std::array<NonTrivialMod4Comp, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
160+
check<NonTrivialMod4Comp*>(std::move(lhs), std::move(rhs), 4);
161+
}
162+
163+
#if TEST_STD_VER >= 20
164+
{ // trivially equality comparable class type to test forwarding to the vectorized version - all elements match
165+
std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 5, 6, 7, 8};
166+
std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
167+
check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 8);
168+
}
169+
170+
{ // trivially equality comparable class type to test forwarding to the vectorized version - not all elements match
171+
std::array<TriviallyEqualityComparable, 8> lhs = {1, 2, 3, 4, 7, 6, 7, 8};
172+
std::array<TriviallyEqualityComparable, 8> rhs = {1, 2, 3, 4, 5, 6, 7, 8};
173+
check<TriviallyEqualityComparable*>(std::move(lhs), std::move(rhs), 4);
148174
}
175+
#endif // TEST_STD_VER >= 20
149176

150177
return true;
151178
}

0 commit comments

Comments
 (0)