Skip to content

Commit beaff78

Browse files
authored
[libc++] Optimize the std::mismatch tail (#83440)
This adds vectorization to the last 0-3 vectors and, if the range is large enough, the remaining elements that don't fill a vector completely. ``` ----------------------------------------------------------------------- Benchmark old full vectors partial vector ----------------------------------------------------------------------- bm_mismatch<char>/1 1.40 ns 1.62 ns 2.09 ns bm_mismatch<char>/2 1.88 ns 2.10 ns 2.33 ns bm_mismatch<char>/3 2.67 ns 2.56 ns 2.72 ns bm_mismatch<char>/4 3.01 ns 3.20 ns 3.70 ns bm_mismatch<char>/5 3.51 ns 3.73 ns 3.64 ns bm_mismatch<char>/6 4.71 ns 4.85 ns 4.37 ns bm_mismatch<char>/7 5.12 ns 5.33 ns 4.37 ns bm_mismatch<char>/8 5.79 ns 6.02 ns 4.75 ns bm_mismatch<char>/15 9.20 ns 10.5 ns 7.23 ns bm_mismatch<char>/16 10.2 ns 10.1 ns 7.46 ns bm_mismatch<char>/17 10.2 ns 10.8 ns 7.57 ns bm_mismatch<char>/31 17.6 ns 17.1 ns 10.8 ns bm_mismatch<char>/32 17.4 ns 1.64 ns 1.64 ns bm_mismatch<char>/33 23.3 ns 2.10 ns 2.33 ns bm_mismatch<char>/63 31.8 ns 16.9 ns 2.33 ns bm_mismatch<char>/64 32.6 ns 2.10 ns 2.10 ns bm_mismatch<char>/65 33.6 ns 2.57 ns 2.80 ns bm_mismatch<char>/127 67.3 ns 18.1 ns 3.27 ns bm_mismatch<char>/128 2.17 ns 2.14 ns 2.57 ns bm_mismatch<char>/129 2.36 ns 2.80 ns 3.27 ns bm_mismatch<char>/255 67.5 ns 19.6 ns 4.68 ns bm_mismatch<char>/256 3.76 ns 3.71 ns 3.97 ns bm_mismatch<char>/257 3.77 ns 4.04 ns 4.43 ns bm_mismatch<char>/511 70.8 ns 22.1 ns 7.47 ns bm_mismatch<char>/512 7.27 ns 7.30 ns 6.95 ns bm_mismatch<char>/513 7.11 ns 7.05 ns 6.96 ns bm_mismatch<char>/1023 75.9 ns 27.4 ns 13.3 ns bm_mismatch<char>/1024 13.9 ns 13.8 ns 12.4 ns bm_mismatch<char>/1025 13.6 ns 13.6 ns 12.8 ns bm_mismatch<char>/2047 87.3 ns 37.5 ns 25.4 ns bm_mismatch<char>/2048 26.8 ns 27.4 ns 24.0 ns bm_mismatch<char>/2049 26.7 ns 27.3 ns 25.5 ns bm_mismatch<char>/4095 112 ns 64.7 ns 48.7 ns bm_mismatch<char>/4096 53.0 ns 54.2 ns 46.8 ns bm_mismatch<char>/4097 52.7 ns 54.2 ns 48.4 ns bm_mismatch<char>/8191 160 ns 118 ns 98.4 ns bm_mismatch<char>/8192 107 ns 108 ns 96.0 ns bm_mismatch<char>/8193 106 ns 108 ns 97.2 ns bm_mismatch<char>/16383 283 ns 234 ns 215 ns bm_mismatch<char>/16384 227 ns 223 ns 217 ns bm_mismatch<char>/16385 221 ns 221 ns 215 ns bm_mismatch<char>/32767 547 ns 499 ns 488 ns bm_mismatch<char>/32768 495 ns 492 ns 492 ns bm_mismatch<char>/32769 491 ns 489 ns 488 ns bm_mismatch<char>/65535 1028 ns 979 ns 971 ns bm_mismatch<char>/65536 976 ns 970 ns 974 ns bm_mismatch<char>/65537 970 ns 965 ns 971 ns bm_mismatch<char>/131071 2031 ns 1948 ns 2005 ns bm_mismatch<char>/131072 1973 ns 1955 ns 1974 ns bm_mismatch<char>/131073 1989 ns 1932 ns 2001 ns bm_mismatch<char>/262143 4469 ns 4244 ns 4223 ns bm_mismatch<char>/262144 4443 ns 4183 ns 4243 ns bm_mismatch<char>/262145 4400 ns 4232 ns 4246 ns bm_mismatch<char>/524287 10169 ns 9733 ns 9592 ns bm_mismatch<char>/524288 10154 ns 9664 ns 9843 ns bm_mismatch<char>/524289 10113 ns 9641 ns 10003 ns bm_mismatch<short>/1 1.86 ns 2.53 ns 2.32 ns bm_mismatch<short>/2 2.57 ns 2.77 ns 2.55 ns bm_mismatch<short>/3 3.26 ns 3.00 ns 2.79 ns bm_mismatch<short>/4 3.95 ns 3.39 ns 3.15 ns bm_mismatch<short>/5 4.83 ns 3.97 ns 3.72 ns bm_mismatch<short>/6 5.43 ns 4.34 ns 4.03 ns bm_mismatch<short>/7 6.11 ns 4.73 ns 4.44 ns bm_mismatch<short>/8 6.84 ns 5.02 ns 4.79 ns bm_mismatch<short>/15 11.5 ns 7.12 ns 6.50 ns bm_mismatch<short>/16 13.9 ns 1.87 ns 2.11 ns bm_mismatch<short>/17 14.0 ns 3.00 ns 2.47 ns bm_mismatch<short>/31 23.1 ns 7.87 ns 2.47 ns bm_mismatch<short>/32 23.8 ns 2.57 ns 2.81 ns bm_mismatch<short>/33 24.5 ns 3.70 ns 2.94 ns bm_mismatch<short>/63 44.8 ns 9.37 ns 3.46 ns bm_mismatch<short>/64 2.32 ns 2.57 ns 2.64 ns bm_mismatch<short>/65 2.52 ns 3.02 ns 3.51 ns bm_mismatch<short>/127 45.6 ns 9.97 ns 5.18 ns bm_mismatch<short>/128 3.85 ns 3.93 ns 3.94 ns bm_mismatch<short>/129 3.82 ns 4.20 ns 4.70 ns bm_mismatch<short>/255 50.4 ns 12.6 ns 8.07 ns bm_mismatch<short>/256 7.23 ns 6.91 ns 6.98 ns bm_mismatch<short>/257 7.24 ns 7.19 ns 7.55 ns bm_mismatch<short>/511 52.3 ns 17.8 ns 14.0 ns bm_mismatch<short>/512 13.6 ns 13.7 ns 13.6 ns bm_mismatch<short>/513 13.9 ns 13.8 ns 18.5 ns bm_mismatch<short>/1023 60.9 ns 30.9 ns 26.3 ns bm_mismatch<short>/1024 26.7 ns 27.7 ns 25.7 ns bm_mismatch<short>/1025 27.7 ns 27.6 ns 25.3 ns bm_mismatch<short>/2047 88.4 ns 58.0 ns 51.6 ns bm_mismatch<short>/2048 52.8 ns 55.3 ns 50.6 ns bm_mismatch<short>/2049 55.2 ns 54.8 ns 48.7 ns bm_mismatch<short>/4095 153 ns 113 ns 102 ns bm_mismatch<short>/4096 105 ns 110 ns 101 ns bm_mismatch<short>/4097 110 ns 110 ns 99.1 ns bm_mismatch<short>/8191 277 ns 219 ns 206 ns bm_mismatch<short>/8192 226 ns 214 ns 250 ns bm_mismatch<short>/8193 226 ns 207 ns 208 ns bm_mismatch<short>/16383 519 ns 492 ns 488 ns bm_mismatch<short>/16384 494 ns 492 ns 492 ns bm_mismatch<short>/16385 492 ns 488 ns 489 ns bm_mismatch<short>/32767 1007 ns 968 ns 964 ns bm_mismatch<short>/32768 977 ns 972 ns 970 ns bm_mismatch<short>/32769 972 ns 962 ns 967 ns bm_mismatch<short>/65535 1978 ns 1918 ns 1956 ns bm_mismatch<short>/65536 1940 ns 1927 ns 1970 ns bm_mismatch<short>/65537 1937 ns 1922 ns 1959 ns bm_mismatch<short>/131071 4524 ns 4193 ns 4304 ns bm_mismatch<short>/131072 4445 ns 4196 ns 4306 ns bm_mismatch<short>/131073 4452 ns 4278 ns 4311 ns bm_mismatch<short>/262143 9801 ns 10188 ns 9634 ns bm_mismatch<short>/262144 9738 ns 10151 ns 9651 ns bm_mismatch<short>/262145 9716 ns 10171 ns 9715 ns bm_mismatch<short>/524287 19944 ns 20718 ns 20044 ns bm_mismatch<short>/524288 21139 ns 20647 ns 20008 ns bm_mismatch<short>/524289 21162 ns 19512 ns 20068 ns bm_mismatch<int>/1 1.40 ns 1.84 ns 1.87 ns bm_mismatch<int>/2 1.87 ns 2.08 ns 2.09 ns bm_mismatch<int>/3 2.36 ns 2.31 ns 2.87 ns bm_mismatch<int>/4 3.06 ns 2.72 ns 2.95 ns bm_mismatch<int>/5 3.66 ns 3.37 ns 3.42 ns bm_mismatch<int>/6 4.55 ns 3.65 ns 3.73 ns bm_mismatch<int>/7 5.03 ns 3.93 ns 3.94 ns bm_mismatch<int>/8 5.67 ns 1.86 ns 1.87 ns bm_mismatch<int>/15 9.89 ns 4.41 ns 2.34 ns bm_mismatch<int>/16 10.1 ns 2.33 ns 2.34 ns bm_mismatch<int>/17 10.2 ns 3.34 ns 2.86 ns bm_mismatch<int>/31 17.2 ns 5.54 ns 3.28 ns bm_mismatch<int>/32 2.16 ns 2.15 ns 2.58 ns bm_mismatch<int>/33 2.36 ns 3.01 ns 3.28 ns bm_mismatch<int>/63 17.7 ns 6.50 ns 4.93 ns bm_mismatch<int>/64 3.81 ns 3.58 ns 3.90 ns bm_mismatch<int>/65 3.74 ns 4.36 ns 4.45 ns bm_mismatch<int>/127 19.5 ns 9.56 ns 7.74 ns bm_mismatch<int>/128 7.30 ns 6.41 ns 6.85 ns bm_mismatch<int>/129 7.09 ns 7.04 ns 7.06 ns bm_mismatch<int>/255 24.7 ns 14.8 ns 13.3 ns bm_mismatch<int>/256 14.0 ns 12.1 ns 12.3 ns bm_mismatch<int>/257 13.8 ns 12.7 ns 12.8 ns bm_mismatch<int>/511 34.3 ns 26.3 ns 24.8 ns bm_mismatch<int>/512 27.6 ns 23.6 ns 23.9 ns bm_mismatch<int>/513 27.3 ns 24.4 ns 25.1 ns bm_mismatch<int>/1023 62.5 ns 50.9 ns 48.3 ns bm_mismatch<int>/1024 54.4 ns 46.1 ns 46.6 ns bm_mismatch<int>/1025 54.2 ns 48.4 ns 47.5 ns bm_mismatch<int>/2047 116 ns 97.8 ns 94.1 ns bm_mismatch<int>/2048 108 ns 92.6 ns 92.4 ns bm_mismatch<int>/2049 108 ns 104 ns 94.0 ns bm_mismatch<int>/4095 233 ns 222 ns 205 ns bm_mismatch<int>/4096 226 ns 223 ns 225 ns bm_mismatch<int>/4097 221 ns 219 ns 210 ns bm_mismatch<int>/8191 499 ns 485 ns 488 ns bm_mismatch<int>/8192 496 ns 490 ns 495 ns bm_mismatch<int>/8193 491 ns 485 ns 488 ns bm_mismatch<int>/16383 982 ns 962 ns 964 ns bm_mismatch<int>/16384 974 ns 971 ns 971 ns bm_mismatch<int>/16385 971 ns 961 ns 968 ns bm_mismatch<int>/32767 2003 ns 1959 ns 1920 ns bm_mismatch<int>/32768 1996 ns 1947 ns 1928 ns bm_mismatch<int>/32769 1990 ns 1945 ns 1926 ns bm_mismatch<int>/65535 4434 ns 4275 ns 4312 ns bm_mismatch<int>/65536 4437 ns 4267 ns 4321 ns bm_mismatch<int>/65537 4442 ns 4261 ns 4321 ns bm_mismatch<int>/131071 9673 ns 9648 ns 9465 ns bm_mismatch<int>/131072 9667 ns 9671 ns 9465 ns bm_mismatch<int>/131073 9661 ns 9653 ns 9464 ns bm_mismatch<int>/262143 20595 ns 19605 ns 19064 ns bm_mismatch<int>/262144 19894 ns 19572 ns 19009 ns bm_mismatch<int>/262145 19851 ns 19656 ns 18999 ns bm_mismatch<int>/524287 39556 ns 39364 ns 38131 ns bm_mismatch<int>/524288 39678 ns 39573 ns 38183 ns bm_mismatch<int>/524289 40168 ns 39301 ns 38121 ns ```
1 parent 407a2f2 commit beaff78

File tree

3 files changed

+68
-4
lines changed

3 files changed

+68
-4
lines changed

libcxx/benchmarks/algorithms/mismatch.bench.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010
#include <benchmark/benchmark.h>
1111
#include <random>
1212

13+
void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) {
14+
Benchmark->DenseRange(1, 8);
15+
for (size_t i = 16; i != 1 << 20; i *= 2) {
16+
Benchmark->Arg(i - 1);
17+
Benchmark->Arg(i);
18+
Benchmark->Arg(i + 1);
19+
}
20+
}
21+
1322
// TODO: Look into benchmarking aligned and unaligned memory explicitly
1423
// (currently things happen to be aligned because they are malloced that way)
1524
template <class T>
@@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) {
2433
benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
2534
}
2635
}
27-
BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
28-
BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
29-
BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
36+
BENCHMARK(bm_mismatch<char>)->Apply(BenchmarkSizes);
37+
BENCHMARK(bm_mismatch<short>)->Apply(BenchmarkSizes);
38+
BENCHMARK(bm_mismatch<int>)->Apply(BenchmarkSizes);
3039

3140
BENCHMARK_MAIN();

libcxx/include/__algorithm/mismatch.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
6464
constexpr size_t __unroll_count = 4;
6565
constexpr size_t __vec_size = __native_vector_size<_Tp>;
6666
using __vec = __simd_vector<_Tp, __vec_size>;
67+
6768
if (!__libcpp_is_constant_evaluated()) {
69+
auto __orig_first1 = __first1;
70+
auto __last2 = __first2 + (__last1 - __first1);
6871
while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
6972
__vec __lhs[__unroll_count];
7073
__vec __rhs[__unroll_count];
@@ -84,8 +87,32 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
8487
__first1 += __unroll_count * __vec_size;
8588
__first2 += __unroll_count * __vec_size;
8689
}
90+
91+
// check the remaining 0-3 vectors
92+
while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
93+
if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
94+
!std::__all_of(__cmp_res)) {
95+
auto __offset = std::__find_first_not_set(__cmp_res);
96+
return {__first1 + __offset, __first2 + __offset};
97+
}
98+
__first1 += __vec_size;
99+
__first2 += __vec_size;
100+
}
101+
102+
if (__last1 - __first1 == 0)
103+
return {__first1, __first2};
104+
105+
// Check if we can load elements in front of the current pointer. If that's the case load a vector at
106+
// (last - vector_size) to check the remaining elements
107+
if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
108+
__first1 = __last1 - __vec_size;
109+
__first2 = __last2 - __vec_size;
110+
auto __offset =
111+
std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
112+
return {__first1 + __offset, __first2 + __offset};
113+
} // else loop over the elements individually
87114
}
88-
// TODO: Consider vectorizing the tail
115+
89116
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
90117
}
91118

libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,5 +184,33 @@ int main(int, char**) {
184184
}
185185
}
186186

187+
{ // check the tail of the vectorized loop
188+
for (size_t vec_size = 1; vec_size != 256; ++vec_size) {
189+
{
190+
std::vector<char> lhs(256);
191+
std::vector<char> rhs(256);
192+
193+
check<char*>(lhs, rhs, lhs.size());
194+
lhs.back() = 1;
195+
check<char*>(lhs, rhs, lhs.size() - 1);
196+
lhs.back() = 0;
197+
rhs.back() = 1;
198+
check<char*>(lhs, rhs, lhs.size() - 1);
199+
rhs.back() = 0;
200+
}
201+
{
202+
std::vector<int> lhs(256);
203+
std::vector<int> rhs(256);
204+
205+
check<int*>(lhs, rhs, lhs.size());
206+
lhs.back() = 1;
207+
check<int*>(lhs, rhs, lhs.size() - 1);
208+
lhs.back() = 0;
209+
rhs.back() = 1;
210+
check<int*>(lhs, rhs, lhs.size() - 1);
211+
rhs.back() = 0;
212+
}
213+
}
214+
}
187215
return 0;
188216
}

0 commit comments

Comments
 (0)