Skip to content

Commit 8ddae7d

Browse files
committed
[libc++] Optimize mismatch tail
1 parent b68e2eb commit 8ddae7d

File tree

3 files changed

+72
-4
lines changed

3 files changed

+72
-4
lines changed

libcxx/benchmarks/algorithms/mismatch.bench.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010
#include <benchmark/benchmark.h>
1111
#include <random>
1212

13+
// Registers the argument values the benchmark is run with: every value from
// 1 through 8, and then, for each power of two starting at 16 and stopping
// before 1 << 20, the values one below, at, and one above that power of two.
void BenchmarkSizes(benchmark::internal::Benchmark* Benchmark) {
  Benchmark->DenseRange(1, 8);

  const size_t limit = size_t(1) << 20; // exclusive upper bound
  for (size_t power = 16; power < limit; power <<= 1) {
    // Register power - 1, power, power + 1 in ascending order.
    for (size_t arg = power - 1; arg <= power + 1; ++arg)
      Benchmark->Arg(arg);
  }
}
21+
1322
// TODO: Look into benchmarking aligned and unaligned memory explicitly
1423
// (currently things happen to be aligned because they are malloced that way)
1524
template <class T>
@@ -24,8 +33,8 @@ static void bm_mismatch(benchmark::State& state) {
2433
benchmark::DoNotOptimize(std::mismatch(vec1.begin(), vec1.end(), vec2.begin()));
2534
}
2635
}
27-
BENCHMARK(bm_mismatch<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
28-
BENCHMARK(bm_mismatch<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
29-
BENCHMARK(bm_mismatch<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
36+
BENCHMARK(bm_mismatch<char>)->Apply(BenchmarkSizes);
37+
BENCHMARK(bm_mismatch<short>)->Apply(BenchmarkSizes);
38+
BENCHMARK(bm_mismatch<int>)->Apply(BenchmarkSizes);
3039

3140
BENCHMARK_MAIN();

libcxx/include/__algorithm/mismatch.h

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
6464
constexpr size_t __unroll_count = 4;
6565
constexpr size_t __vec_size = __native_vector_size<_Tp>;
6666
using __vec = __simd_vector<_Tp, __vec_size>;
67+
6768
if (!__libcpp_is_constant_evaluated()) {
69+
auto __orig_first1 = __first1;
70+
auto __last2 = __first2 + (__last1 - __first1);
6871
while (static_cast<size_t>(__last1 - __first1) >= __unroll_count * __vec_size) [[__unlikely__]] {
6972
__vec __lhs[__unroll_count];
7073
__vec __rhs[__unroll_count];
@@ -84,8 +87,36 @@ __mismatch(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Pred& __pred, _Proj1& __
8487
__first1 += __unroll_count * __vec_size;
8588
__first2 += __unroll_count * __vec_size;
8689
}
90+
91+
// check the remaining 0-3 vectors
92+
while (static_cast<size_t>(__last1 - __first1) >= __vec_size) {
93+
if (auto __cmp_res = std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2);
94+
!std::__all_of(__cmp_res)) {
95+
auto __offset = std::__find_first_not_set(__cmp_res);
96+
return {__first1 + __offset, __first2 + __offset};
97+
}
98+
__first1 += __vec_size;
99+
__first2 += __vec_size;
100+
}
101+
102+
if (__last1 - __first1 == 0)
103+
return {__first1, __first2};
104+
105+
// Check if we can load elements in front of the current pointer. If that's the case load a vector at
106+
// (last - vector_size) to check the remaining elements
107+
if (static_cast<size_t>(__first1 - __orig_first1) >= __vec_size) {
108+
__first1 = __last1 - __vec_size;
109+
__first2 = __last2 - __vec_size;
110+
auto __offset =
111+
std::__find_first_not_set(std::__load_vector<__vec>(__first1) == std::__load_vector<__vec>(__first2));
112+
return {__first1 + __offset, __first2 + __offset};
113+
} // else loop over the elements individually
114+
115+
// TODO: Consider vectorizing the loop tail further with
116+
// - smaller vectors
117+
// - loading bytes out of range if it's known to be safe
87118
}
88-
// TODO: Consider vectorizing the tail
119+
89120
return std::__mismatch_loop(__first1, __last1, __first2, __pred, __proj1, __proj2);
90121
}
91122

libcxx/test/std/algorithms/alg.nonmodifying/mismatch/mismatch.pass.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,5 +184,33 @@ int main(int, char**) {
184184
}
185185
}
186186

187+
{ // check the tail of the vectorized loop
  // Run the checks for every range length in [1, 256) so that lengths which
  // are not a whole multiple of the vector width are covered as well. The
  // containers must be sized with vec_size: with a fixed length of 256 every
  // iteration would repeat the same input and the leftover-element handling
  // would never be reached.
  for (size_t vec_size = 1; vec_size != 256; ++vec_size) {
    {
      std::vector<char> lhs(vec_size);
      std::vector<char> rhs(vec_size);

      // NOTE(review): the third argument to check() looks like the expected
      // mismatch offset -- confirm against check()'s definition.
      // Ranges are identical (all value-initialized elements).
      check<char*>(lhs, rhs, lhs.size());
      // Only the last element differs, changed on the left-hand side.
      lhs.back() = 1;
      check<char*>(lhs, rhs, lhs.size() - 1);
      // Only the last element differs, changed on the right-hand side.
      lhs.back() = 0;
      rhs.back() = 1;
      check<char*>(lhs, rhs, lhs.size() - 1);
      rhs.back() = 0;
    }
    {
      std::vector<int> lhs(vec_size);
      std::vector<int> rhs(vec_size);

      // Same three scenarios as above, with int elements.
      check<int*>(lhs, rhs, lhs.size());
      lhs.back() = 1;
      check<int*>(lhs, rhs, lhs.size() - 1);
      lhs.back() = 0;
      rhs.back() = 1;
      check<int*>(lhs, rhs, lhs.size() - 1);
      rhs.back() = 0;
    }
  }
}
187215
return 0;
188216
}

0 commit comments

Comments
 (0)