Skip to content

[libc++] Optimize lexicographical_compare #65279

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion libcxx/docs/ReleaseNotes/20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ Implemented Papers
Improvements and New Features
-----------------------------

- TODO
- The ``lexicographical_compare`` and ``ranges::lexicographical_compare`` algorithms have been optimized for trivially
equality comparable types, resulting in a performance improvement of up to 40x.


Deprecations and Removals
Expand Down
3 changes: 2 additions & 1 deletion libcxx/include/__algorithm/comp.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <__config>
#include <__type_traits/desugars_to.h>
#include <__type_traits/is_integral.h>

#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
Expand Down Expand Up @@ -42,7 +43,7 @@ struct __less<void, void> {
};

template <class _Tp>
inline const bool __desugars_to_v<__less_tag, __less<>, _Tp, _Tp> = true;
inline const bool __desugars_to_v<__totally_ordered_less_tag, __less<>, _Tp, _Tp> = is_integral<_Tp>::value;

_LIBCPP_END_NAMESPACE_STD

Expand Down
93 changes: 82 additions & 11 deletions libcxx/include/__algorithm/lexicographical_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,109 @@
#define _LIBCPP___ALGORITHM_LEXICOGRAPHICAL_COMPARE_H

#include <__algorithm/comp.h>
#include <__algorithm/comp_ref_type.h>
#include <__algorithm/min.h>
#include <__algorithm/mismatch.h>
#include <__algorithm/simd_utils.h>
#include <__algorithm/unwrap_iter.h>
#include <__config>
#include <__functional/identity.h>
#include <__iterator/iterator_traits.h>
#include <__string/constexpr_c_functions.h>
#include <__type_traits/desugars_to.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_equality_comparable.h>
#include <__type_traits/is_integral.h>
#include <__type_traits/is_trivially_lexicographically_comparable.h>
#include <__type_traits/is_volatile.h>

#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
# include <cwchar>
#endif

#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
#endif

_LIBCPP_PUSH_MACROS
#include <__undef_macros>

_LIBCPP_BEGIN_NAMESPACE_STD

template <class _Compare, class _InputIterator1, class _InputIterator2>
template <class _Iter1, class _Sent1, class _Iter2, class _Sent2, class _Proj1, class _Proj2, class _Comp>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool __lexicographical_compare(
_InputIterator1 __first1,
_InputIterator1 __last1,
_InputIterator2 __first2,
_InputIterator2 __last2,
_Compare __comp) {
for (; __first2 != __last2; ++__first1, (void)++__first2) {
if (__first1 == __last1 || __comp(*__first1, *__first2))
_Iter1 __first1, _Sent1 __last1, _Iter2 __first2, _Sent2 __last2, _Comp& __comp, _Proj1& __proj1, _Proj2& __proj2) {
while (__first2 != __last2) {
if (__first1 == __last1 ||
std::__invoke(__comp, std::__invoke(__proj1, *__first1), std::__invoke(__proj2, *__first2)))
return true;
if (__comp(*__first2, *__first1))
if (std::__invoke(__comp, std::__invoke(__proj2, *__first2), std::__invoke(__proj1, *__first1)))
return false;
++__first1;
++__first2;
}
return false;
}

#if _LIBCPP_STD_VER >= 14

// If the comparison operation is equivalent to < and that is a total order, we know that we can use equality comparison
// on that type instead to extract some information. Furthermore, if equality comparison on that type is trivial, the
// user can't observe that we're calling it. So instead of using the user-provided total order, we use std::mismatch,
// which uses equality comparison (and is vertorized). Additionally, if the type is trivially lexicographically
// comparable, we can go one step further and use std::memcmp directly instead of calling std::mismatch.
template <class _Tp,
class _Proj1,
class _Proj2,
class _Comp,
__enable_if_t<__desugars_to_v<__totally_ordered_less_tag, _Comp, _Tp, _Tp> && !is_volatile<_Tp>::value &&
__libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value &&
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
int> = 0>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
__lexicographical_compare(_Tp* __first1, _Tp* __last1, _Tp* __first2, _Tp* __last2, _Comp&, _Proj1&, _Proj2&) {
if constexpr (__is_trivially_lexicographically_comparable_v<_Tp, _Tp>) {
auto __res =
std::__constexpr_memcmp(__first1, __first2, __element_count(std::min(__last1 - __first1, __last2 - __first2)));
if (__res == 0)
return __last1 - __first1 < __last2 - __first2;
return __res < 0;
}
# ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
else if constexpr (is_same<__remove_cv_t<_Tp>, wchar_t>::value) {
auto __res = std::__constexpr_wmemcmp(__first1, __first2, std::min(__last1 - __first1, __last2 - __first2));
if (__res == 0)
return __last1 - __first1 < __last2 - __first2;
return __res < 0;
}
# endif // _LIBCPP_HAS_NO_WIDE_CHARACTERS
else {
auto __res = std::mismatch(__first1, __last1, __first2, __last2);
if (__res.second == __last2)
return false;
if (__res.first == __last1)
return true;
return *__res.first < *__res.second;
}
}

#endif // _LIBCPP_STD_VER >= 14

template <class _InputIterator1, class _InputIterator2, class _Compare>
_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool lexicographical_compare(
_InputIterator1 __first1,
_InputIterator1 __last1,
_InputIterator2 __first2,
_InputIterator2 __last2,
_Compare __comp) {
return std::__lexicographical_compare<__comp_ref_type<_Compare> >(__first1, __last1, __first2, __last2, __comp);
__identity __proj;
return std::__lexicographical_compare(
std::__unwrap_iter(__first1),
std::__unwrap_iter(__last1),
std::__unwrap_iter(__first2),
std::__unwrap_iter(__last2),
__comp,
__proj,
__proj);
}

template <class _InputIterator1, class _InputIterator2>
Expand All @@ -54,4 +123,6 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 boo

_LIBCPP_END_NAMESPACE_STD

_LIBCPP_POP_MACROS

#endif // _LIBCPP___ALGORITHM_LEXICOGRAPHICAL_COMPARE_H
27 changes: 15 additions & 12 deletions libcxx/include/__algorithm/ranges_lexicographical_compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#ifndef _LIBCPP___ALGORITHM_RANGES_LEXICOGRAPHICAL_COMPARE_H
#define _LIBCPP___ALGORITHM_RANGES_LEXICOGRAPHICAL_COMPARE_H

#include <__algorithm/lexicographical_compare.h>
#include <__algorithm/unwrap_range.h>
#include <__config>
#include <__functional/identity.h>
#include <__functional/invoke.h>
Expand All @@ -34,23 +36,24 @@ namespace ranges {
namespace __lexicographical_compare {
struct __fn {
template <class _Iter1, class _Sent1, class _Iter2, class _Sent2, class _Proj1, class _Proj2, class _Comp>
_LIBCPP_HIDE_FROM_ABI constexpr static bool __lexicographical_compare_impl(
static _LIBCPP_HIDE_FROM_ABI constexpr bool __lexicographical_compare_unwrap(
_Iter1 __first1,
_Sent1 __last1,
_Iter2 __first2,
_Sent2 __last2,
_Comp& __comp,
_Proj1& __proj1,
_Proj2& __proj2) {
while (__first2 != __last2) {
if (__first1 == __last1 || std::invoke(__comp, std::invoke(__proj1, *__first1), std::invoke(__proj2, *__first2)))
return true;
if (std::invoke(__comp, std::invoke(__proj2, *__first2), std::invoke(__proj1, *__first1)))
return false;
++__first1;
++__first2;
}
return false;
auto [__first1_un, __last1_un] = std::__unwrap_range(std::move(__first1), std::move(__last1));
auto [__first2_un, __last2_un] = std::__unwrap_range(std::move(__first2), std::move(__last2));
return std::__lexicographical_compare(
std::move(__first1_un),
std::move(__last1_un),
std::move(__first2_un),
std::move(__last2_un),
__comp,
__proj1,
__proj2);
}

template <input_iterator _Iter1,
Expand All @@ -68,7 +71,7 @@ struct __fn {
_Comp __comp = {},
_Proj1 __proj1 = {},
_Proj2 __proj2 = {}) const {
return __lexicographical_compare_impl(
return __lexicographical_compare_unwrap(
std::move(__first1), std::move(__last1), std::move(__first2), std::move(__last2), __comp, __proj1, __proj2);
}

Expand All @@ -80,7 +83,7 @@ struct __fn {
_Comp = ranges::less>
[[nodiscard]] _LIBCPP_HIDE_FROM_ABI constexpr bool operator()(
_Range1&& __range1, _Range2&& __range2, _Comp __comp = {}, _Proj1 __proj1 = {}, _Proj2 __proj2 = {}) const {
return __lexicographical_compare_impl(
return __lexicographical_compare_unwrap(
ranges::begin(__range1),
ranges::end(__range1),
ranges::begin(__range2),
Expand Down
2 changes: 1 addition & 1 deletion libcxx/include/__algorithm/ranges_minmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct __fn {
// vectorize the code.
if constexpr (contiguous_range<_Range> && is_integral_v<_ValueT> &&
__is_cheap_to_copy<_ValueT> & __is_identity<_Proj>::value &&
__desugars_to_v<__less_tag, _Comp, _ValueT, _ValueT>) {
__desugars_to_v<__totally_ordered_less_tag, _Comp, _ValueT, _ValueT>) {
minmax_result<_ValueT> __result = {__r[0], __r[0]};
for (auto __e : __r) {
if (__e < __result.min)
Expand Down
5 changes: 3 additions & 2 deletions libcxx/include/__functional/operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <__functional/binary_function.h>
#include <__functional/unary_function.h>
#include <__type_traits/desugars_to.h>
#include <__type_traits/is_integral.h>
#include <__utility/forward.h>

#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
Expand Down Expand Up @@ -362,7 +363,7 @@ struct _LIBCPP_TEMPLATE_VIS less : __binary_function<_Tp, _Tp, bool> {
_LIBCPP_CTAD_SUPPORTED_FOR_TYPE(less);

template <class _Tp>
inline const bool __desugars_to_v<__less_tag, less<_Tp>, _Tp, _Tp> = true;
inline const bool __desugars_to_v<__totally_ordered_less_tag, less<_Tp>, _Tp, _Tp> = is_integral<_Tp>::value;

#if _LIBCPP_STD_VER >= 14
template <>
Expand All @@ -377,7 +378,7 @@ struct _LIBCPP_TEMPLATE_VIS less<void> {
};

template <class _Tp>
inline const bool __desugars_to_v<__less_tag, less<>, _Tp, _Tp> = true;
inline const bool __desugars_to_v<__totally_ordered_less_tag, less<>, _Tp, _Tp> = is_integral<_Tp>::value;
#endif

#if _LIBCPP_STD_VER >= 14
Expand Down
2 changes: 1 addition & 1 deletion libcxx/include/__functional/ranges_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ template <class _Tp, class _Up>
inline const bool __desugars_to_v<__equal_tag, ranges::equal_to, _Tp, _Up> = true;

template <class _Tp, class _Up>
inline const bool __desugars_to_v<__less_tag, ranges::less, _Tp, _Up> = true;
inline const bool __desugars_to_v<__totally_ordered_less_tag, ranges::less, _Tp, _Up> = true;

#endif // _LIBCPP_STD_VER >= 20

Expand Down
4 changes: 2 additions & 2 deletions libcxx/include/__string/constexpr_c_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 size_t __constexpr_st
return __builtin_strlen(reinterpret_cast<const char*>(__str));
}

// Because of __libcpp_is_trivially_lexicographically_comparable we know that comparing the object representations is
// Because of __is_trivially_lexicographically_comparable_v we know that comparing the object representations is
// equivalent to a std::memcmp. Since we have multiple objects contiguously in memory, we can call memcmp once instead
// of invoking it on every object individually.
template <class _Tp, class _Up>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 int
__constexpr_memcmp(const _Tp* __lhs, const _Up* __rhs, __element_count __n) {
static_assert(__libcpp_is_trivially_lexicographically_comparable<_Tp, _Up>::value,
static_assert(__is_trivially_lexicographically_comparable_v<_Tp, _Up>,
"_Tp and _Up have to be trivially lexicographically comparable");

auto __count = static_cast<size_t>(__n);
Expand Down
15 changes: 13 additions & 2 deletions libcxx/include/__type_traits/desugars_to.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@

_LIBCPP_BEGIN_NAMESPACE_STD

// Tags to represent the canonical operations
// Tags to represent the canonical operations.

// syntactically, the operation is equivalent to calling `a == b`
struct __equal_tag {};

// syntactically, the operation is equivalent to calling `a + b`
struct __plus_tag {};
struct __less_tag {};

// syntactically, the operation is equivalent to calling `a < b`, and these expressions
// have to be true for any `a` and `b`:
// - `(a < b) == (b > a)`
// - `(!(a < b) && !(b < a)) == (a == b)`
// For example, this is satisfied for std::less on integral types, but also for ranges::less on all types due to
// additional semantic requirements on that operation.
struct __totally_ordered_less_tag {};

// This class template is used to determine whether an operation "desugars"
// (or boils down) to a given canonical operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <__type_traits/remove_cv.h>
#include <__type_traits/void_t.h>
#include <__utility/declval.h>
#include <cstddef>

#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
Expand All @@ -40,13 +41,22 @@ _LIBCPP_BEGIN_NAMESPACE_STD
// unsigned integer types with sizeof(T) > 1: depending on the endianness, the LSB might be the first byte to be
// compared. This means that when comparing unsigned(129) and unsigned(2)
// using memcmp(), the result would be that 2 > 129.
// TODO: Do we want to enable this on big-endian systems?

template <class _Tp>
inline const bool __is_std_byte_v = false;

#if _LIBCPP_STD_VER >= 17
template <>
inline const bool __is_std_byte_v<byte> = true;
#endif

template <class _Tp, class _Up>
struct __libcpp_is_trivially_lexicographically_comparable
: integral_constant<bool,
is_same<__remove_cv_t<_Tp>, __remove_cv_t<_Up> >::value && sizeof(_Tp) == 1 &&
is_unsigned<_Tp>::value> {};
inline const bool __is_trivially_lexicographically_comparable_v =
is_same<__remove_cv_t<_Tp>, __remove_cv_t<_Up> >::value &&
#ifdef _LIBCPP_LITTLE_ENDIAN
sizeof(_Tp) == 1 &&
#endif
(is_unsigned<_Tp>::value || __is_std_byte_v<_Tp>);

_LIBCPP_END_NAMESPACE_STD

Expand Down
1 change: 1 addition & 0 deletions libcxx/test/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ set(BENCHMARK_TESTS
algorithms/find.bench.cpp
algorithms/fill.bench.cpp
algorithms/for_each.bench.cpp
algorithms/lexicographical_compare.bench.cpp
algorithms/lower_bound.bench.cpp
algorithms/make_heap.bench.cpp
algorithms/make_heap_then_sort_heap.bench.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <algorithm>
#include <benchmark/benchmark.h>
#include <vector>

// Benchmarks the worst case: check the whole range just to find out that they compare equal
template <class T>
static void bm_lexicographical_compare(benchmark::State& state) {
std::vector<T> vec1(state.range(), '1');
std::vector<T> vec2(state.range(), '1');

for (auto _ : state) {
benchmark::DoNotOptimize(vec1);
benchmark::DoNotOptimize(vec2);
benchmark::DoNotOptimize(std::lexicographical_compare(vec1.begin(), vec1.end(), vec2.begin(), vec2.end()));
}
}
BENCHMARK(bm_lexicographical_compare<unsigned char>)->DenseRange(1, 8)->Range(16, 1 << 20);
BENCHMARK(bm_lexicographical_compare<signed char>)->DenseRange(1, 8)->Range(16, 1 << 20);
BENCHMARK(bm_lexicographical_compare<int>)->DenseRange(1, 8)->Range(16, 1 << 20);

template <class T>
static void bm_ranges_lexicographical_compare(benchmark::State& state) {
std::vector<T> vec1(state.range(), '1');
std::vector<T> vec2(state.range(), '1');

for (auto _ : state) {
benchmark::DoNotOptimize(vec1);
benchmark::DoNotOptimize(vec2);
benchmark::DoNotOptimize(std::ranges::lexicographical_compare(vec1.begin(), vec1.end(), vec2.begin(), vec2.end()));
}
}
BENCHMARK(bm_ranges_lexicographical_compare<unsigned char>)->DenseRange(1, 8)->Range(16, 1 << 20);
BENCHMARK(bm_ranges_lexicographical_compare<signed char>)->DenseRange(1, 8)->Range(16, 1 << 20);
BENCHMARK(bm_ranges_lexicographical_compare<int>)->DenseRange(1, 8)->Range(16, 1 << 20);

BENCHMARK_MAIN();
Loading
Loading