Skip to content

Commit 7e1ee1e

Browse files
danlark1 authored and mordante committed
[libcxx] Add strict weak ordering checks to sorting algorithms
This is the implementation of the first proposal of strict weak ordering checks described in
https://discourse.llvm.org/t/rfc-strict-weak-ordering-checks-in-the-debug-libc/70217

This targets the most vulnerable algorithms, like std::sort.

Reviewed By: philnik, #libc

Differential Revision: https://reviews.llvm.org/D150264
1 parent b0525f6 commit 7e1ee1e

File tree

12 files changed

+193
-24
lines changed

12 files changed

+193
-24
lines changed

libcxx/include/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ set(files
324324
__coroutine/trivial_awaitables.h
325325
__debug
326326
__debug_utils/randomize_range.h
327+
__debug_utils/strict_weak_ordering_check.h
327328
__exception/exception.h
328329
__exception/exception_ptr.h
329330
__exception/nested_exception.h

libcxx/include/__algorithm/sort.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <__config>
2424
#include <__debug>
2525
#include <__debug_utils/randomize_range.h>
26+
#include <__debug_utils/strict_weak_ordering_check.h>
2627
#include <__functional/operations.h>
2728
#include <__functional/ranges_operations.h>
2829
#include <__iterator/iterator_traits.h>
@@ -921,6 +922,7 @@ void __sort_impl(_RandomAccessIterator __first, _RandomAccessIterator __last, _C
921922
} else {
922923
std::__sort_dispatch<_AlgPolicy>(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __comp);
923924
}
925+
std::__check_strict_weak_ordering_sorted(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __comp);
924926
}
925927

926928
template <class _RandomAccessIterator, class _Comp>

libcxx/include/__algorithm/sort_heap.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <__algorithm/iterator_operations.h>
1515
#include <__algorithm/pop_heap.h>
1616
#include <__config>
17+
#include <__debug_utils/strict_weak_ordering_check.h>
1718
#include <__iterator/iterator_traits.h>
1819
#include <__type_traits/is_copy_assignable.h>
1920
#include <__type_traits/is_copy_constructible.h>
@@ -28,11 +29,13 @@ _LIBCPP_BEGIN_NAMESPACE_STD
2829
template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
2930
inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14
3031
void __sort_heap(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compare&& __comp) {
32+
_RandomAccessIterator __saved_last = __last;
3133
__comp_ref_type<_Compare> __comp_ref = __comp;
3234

3335
using difference_type = typename iterator_traits<_RandomAccessIterator>::difference_type;
3436
for (difference_type __n = __last - __first; __n > 1; --__last, (void) --__n)
3537
std::__pop_heap<_AlgPolicy>(__first, __last, __comp_ref, __n);
38+
std::__check_strict_weak_ordering_sorted(__first, __saved_last, __comp_ref);
3639
}
3740

3841
template <class _RandomAccessIterator, class _Compare>

libcxx/include/__algorithm/stable_sort.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <__algorithm/iterator_operations.h>
1616
#include <__algorithm/sort.h>
1717
#include <__config>
18+
#include <__debug_utils/strict_weak_ordering_check.h>
1819
#include <__iterator/iterator_traits.h>
1920
#include <__memory/destruct_n.h>
2021
#include <__memory/temporary_buffer.h>
@@ -259,6 +260,7 @@ _LIBCPP_SUPPRESS_DEPRECATED_POP
259260
}
260261

261262
std::__stable_sort<_AlgPolicy, __comp_ref_type<_Compare> >(__first, __last, __comp, __len, __buf.first, __buf.second);
263+
std::__check_strict_weak_ordering_sorted(__first, __last, __comp);
262264
}
263265

264266
template <class _RandomAccessIterator, class _Compare>

libcxx/include/__debug

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
# define _LIBCPP_DEBUG_RANDOMIZE_UNSPECIFIED_STABILITY
2424
#endif
2525

26+
#if defined(_LIBCPP_ENABLE_DEBUG_MODE) && !defined(_LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK)
27+
# define _LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
28+
#endif
29+
2630
#if defined(_LIBCPP_ENABLE_DEBUG_MODE) && !defined(_LIBCPP_DEBUG_ITERATOR_BOUNDS_CHECKING)
2731
# define _LIBCPP_DEBUG_ITERATOR_BOUNDS_CHECKING
2832
#endif

libcxx/include/__debug_utils/strict_weak_ordering_check.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK
10+
#define _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK
11+
12+
#include <__config>
13+
14+
#include <__algorithm/comp_ref_type.h>
15+
#include <__algorithm/is_sorted.h>
16+
#include <__assert>
17+
#include <__iterator/iterator_traits.h>
18+
#include <__type_traits/is_constant_evaluated.h>
19+
20+
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
21+
# pragma GCC system_header
22+
#endif
23+
24+
_LIBCPP_BEGIN_NAMESPACE_STD
25+
26+
template <class _RandomAccessIterator, class _Comp>
27+
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
28+
__check_strict_weak_ordering_sorted(_RandomAccessIterator __first, _RandomAccessIterator __last, _Comp& __comp) {
29+
#ifdef _LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
30+
using __diff_t = __iter_diff_t<_RandomAccessIterator>;
31+
using _Comp_ref = __comp_ref_type<_Comp>;
32+
if (!__libcpp_is_constant_evaluated()) {
33+
// Check if the range is actually sorted.
34+
_LIBCPP_ASSERT((std::is_sorted<_RandomAccessIterator, _Comp_ref>(__first, __last, _Comp_ref(__comp))),
35+
"The range is not sorted after the sort, your comparator is not a valid strict-weak ordering");
36+
// Limit the number of elements we need to check.
37+
__diff_t __size = __last - __first > __diff_t(100) ? __diff_t(100) : __last - __first;
38+
__diff_t __p = 0;
39+
while (__p < __size) {
40+
__diff_t __q = __p + __diff_t(1);
41+
// Find first element that is greater than *(__first+__p).
42+
while (__q < __size && !__comp(*(__first + __p), *(__first + __q))) {
43+
++__q;
44+
}
45+
// Check that the elements in [__p, __q) are all equivalent to each other.
46+
for (__diff_t __b = __p; __b < __q; ++__b) {
47+
for (__diff_t __a = __p; __a <= __b; ++__a) {
48+
_LIBCPP_ASSERT(
49+
!__comp(*(__first + __a), *(__first + __b)), "Your comparator is not a valid strict-weak ordering");
50+
_LIBCPP_ASSERT(
51+
!__comp(*(__first + __b), *(__first + __a)), "Your comparator is not a valid strict-weak ordering");
52+
}
53+
}
54+
// Check that every element in [__p, __q) is less than every element in [__q, __size).
55+
for (__diff_t __a = __p; __a < __q; ++__a) {
56+
for (__diff_t __b = __q; __b < __size; ++__b) {
57+
_LIBCPP_ASSERT(
58+
__comp(*(__first + __a), *(__first + __b)), "Your comparator is not a valid strict-weak ordering");
59+
_LIBCPP_ASSERT(
60+
!__comp(*(__first + __b), *(__first + __a)), "Your comparator is not a valid strict-weak ordering");
61+
}
62+
}
63+
// Skip these equal elements.
64+
__p = __q;
65+
}
66+
}
67+
#else
68+
(void)__first;
69+
(void)__last;
70+
(void)__comp;
71+
#endif
72+
}
73+
74+
_LIBCPP_END_NAMESPACE_STD
75+
76+
#endif // _LIBCPP___LIBCXX_DEBUG_STRICT_WEAK_ORDERING_CHECK

libcxx/include/module.modulemap.in

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,8 @@ module std [system] {
11411141
}
11421142

11431143
module __debug_utils {
1144-
module randomize_range { private header "__debug_utils/randomize_range.h" }
1144+
module randomize_range { private header "__debug_utils/randomize_range.h" }
1145+
module strict_weak_ordering_check { private header "__debug_utils/strict_weak_ordering_check.h" }
11451146
}
11461147

11471148
module limits {

libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp

Lines changed: 91 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,41 +11,46 @@
1111
// REQUIRES: has-unix-headers
1212
// UNSUPPORTED: c++03, c++11, c++14, c++17
1313
// XFAIL: availability-verbose_abort-missing
14-
// ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_ENABLE_ASSERTIONS=1
14+
// ADDITIONAL_COMPILE_FLAGS: -D_LIBCPP_ENABLE_ASSERTIONS=1 -D_LIBCPP_DEBUG_STRICT_WEAK_ORDERING_CHECK
1515

1616
// This test uses a specific combination of an invalid comparator and sequence of values to
17-
// ensure that our sorting functions do not go out-of-bounds in that case. Instead, we should
18-
// fail loud with an assertion. The specific issue we're looking for here is when the comparator
19-
// does not satisfy the following property:
17+
// ensure that our sorting functions do not go out-of-bounds and satisfy strict weak ordering in that case.
18+
// Instead, we should fail loud with an assertion. The specific issue we're looking for here is when the comparator
19+
// does not satisfy the strict weak ordering:
2020
//
21-
// comp(a, b) implies that !comp(b, a)
22-
//
23-
// In other words,
24-
//
25-
// a < b implies that !(b < a)
21+
// Irreflexivity: comp(a, a) is false
22+
// Asymmetry: comp(a, b) implies that !comp(b, a)
23+
// Transitivity: comp(a, b), comp(b, c) imply comp(a, c)
24+
// Transitivity of equivalence: !comp(a, b), !comp(b, a), !comp(b, c), !comp(c, b) imply !comp(a, c), !comp(c, a)
2625
//
2726
// If this is not satisfied, we have seen issues in the past where the std::sort implementation
28-
// would proceed to do OOB reads (rdar://106897934).
27+
// would proceed to do OOB reads. (rdar://106897934).
28+
// Other algorithms like std::stable_sort, std::sort_heap do not go out of bounds but can produce
29+
// incorrect results, we also want to assert on that.
30+
// Sometimes std::sort does not go out of bounds as well, for example, right now if transitivity
31+
// of equivalence is not met, std::sort can only produce incorrect result but would not fail.
2932

30-
// When the debug mode is enabled, this test fails because we actually catch that the comparator
33+
// When the debug mode is enabled, this test fails because we actually catch on the fly that the comparator
3134
// is not a strict-weak ordering before we catch that we'd dereference out-of-bounds inside std::sort,
3235
// which leads to different errors than the ones tested below.
3336
// XFAIL: libcpp-has-debug-mode
3437

3538
#include <algorithm>
3639
#include <cassert>
3740
#include <cstddef>
41+
#include <limits>
3842
#include <map>
3943
#include <memory>
4044
#include <ranges>
45+
#include <random>
4146
#include <set>
4247
#include <string>
4348
#include <vector>
4449

4550
#include "bad_comparator_values.h"
4651
#include "check_assertion.h"
4752

48-
int main(int, char**) {
53+
void check_oob_sort_read() {
4954
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
5055
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
5156
auto values = std::views::split(line, ' ');
@@ -90,20 +95,27 @@ int main(int, char**) {
9095
std::vector<std::size_t*> copy;
9196
for (auto const& e : elements)
9297
copy.push_back(e.get());
93-
std::stable_sort(copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
98+
TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(copy.begin(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
99+
}
100+
{
101+
std::vector<std::size_t*> copy;
102+
for (auto const& e : elements)
103+
copy.push_back(e.get());
104+
std::make_heap(copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
105+
TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(copy.begin(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
94106
}
95107
{
96108
std::vector<std::size_t*> copy;
97109
for (auto const& e : elements)
98110
copy.push_back(e.get());
99-
std::partial_sort(copy.begin(), copy.begin(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
111+
TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort(copy.begin(), copy.end(), copy.end(), checked_predicate), "not a valid strict-weak ordering");
100112
}
101113
{
102114
std::vector<std::size_t*> copy;
103115
for (auto const& e : elements)
104116
copy.push_back(e.get());
105117
std::vector<std::size_t*> results(copy.size(), nullptr);
106-
std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate); // doesn't go OOB even with invalid comparator
118+
TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
107119
}
108120
{
109121
std::vector<std::size_t*> copy;
@@ -123,27 +135,88 @@ int main(int, char**) {
123135
std::vector<std::size_t*> copy;
124136
for (auto const& e : elements)
125137
copy.push_back(e.get());
126-
std::ranges::stable_sort(copy, checked_predicate); // doesn't go OOB even with invalid comparator
138+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(copy, checked_predicate), "not a valid strict-weak ordering");
127139
}
128140
{
129141
std::vector<std::size_t*> copy;
130142
for (auto const& e : elements)
131143
copy.push_back(e.get());
132-
std::ranges::partial_sort(copy, copy.begin(), checked_predicate); // doesn't go OOB even with invalid comparator
144+
std::ranges::make_heap(copy, checked_predicate); // doesn't go OOB even with invalid comparator
145+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(copy, checked_predicate), "not a valid strict-weak ordering");
146+
}
147+
{
148+
std::vector<std::size_t*> copy;
149+
for (auto const& e : elements)
150+
copy.push_back(e.get());
151+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort(copy, copy.end(), checked_predicate), "not a valid strict-weak ordering");
133152
}
134153
{
135154
std::vector<std::size_t*> copy;
136155
for (auto const& e : elements)
137156
copy.push_back(e.get());
138157
std::vector<std::size_t*> results(copy.size(), nullptr);
139-
std::ranges::partial_sort_copy(copy, results, checked_predicate); // doesn't go OOB even with invalid comparator
158+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
140159
}
141160
{
142161
std::vector<std::size_t*> copy;
143162
for (auto const& e : elements)
144163
copy.push_back(e.get());
145164
std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
146165
}
166+
}
167+
168+
struct FloatContainer {
169+
float value;
170+
bool operator<(const FloatContainer& other) const {
171+
return value < other.value;
172+
}
173+
};
174+
175+
// NaNs in floats violate strict weak ordering by breaking transitivity of equivalence.
176+
std::vector<FloatContainer> generate_float_data() {
177+
std::vector<FloatContainer> floats(50);
178+
for (int i = 0; i < 50; ++i) {
179+
floats[i].value = static_cast<float>(i);
180+
}
181+
floats.push_back(FloatContainer{std::numeric_limits<float>::quiet_NaN()});
182+
std::shuffle(floats.begin(), floats.end(), std::default_random_engine());
183+
return floats;
184+
}
185+
186+
void check_nan_floats() {
187+
auto floats = generate_float_data();
188+
TEST_LIBCPP_ASSERT_FAILURE(std::sort(floats.begin(), floats.end()), "not a valid strict-weak ordering");
189+
floats = generate_float_data();
190+
TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(floats.begin(), floats.end()), "not a valid strict-weak ordering");
191+
floats = generate_float_data();
192+
std::make_heap(floats.begin(), floats.end());
193+
TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(floats.begin(), floats.end()), "not a valid strict-weak ordering");
194+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort(generate_float_data(), std::less()), "not a valid strict-weak ordering");
195+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(generate_float_data(), std::less()), "not a valid strict-weak ordering");
196+
floats = generate_float_data();
197+
std::ranges::make_heap(floats, std::less());
198+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(floats, std::less()), "not a valid strict-weak ordering");
199+
}
200+
201+
void check_irreflexive() {
202+
std::vector<int> v(1);
203+
TEST_LIBCPP_ASSERT_FAILURE(std::sort(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
204+
TEST_LIBCPP_ASSERT_FAILURE(std::stable_sort(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
205+
std::make_heap(v.begin(), v.end(), std::greater_equal<int>());
206+
TEST_LIBCPP_ASSERT_FAILURE(std::sort_heap(v.begin(), v.end(), std::greater_equal<int>()), "not a valid strict-weak ordering");
207+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
208+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::stable_sort(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
209+
std::ranges::make_heap(v, std::greater_equal<int>());
210+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::sort_heap(v, std::greater_equal<int>()), "not a valid strict-weak ordering");
211+
}
212+
213+
int main(int, char**) {
214+
215+
check_oob_sort_read();
216+
217+
check_nan_floats();
218+
219+
check_irreflexive();
147220

148221
return 0;
149222
}

libcxx/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/complexity.pass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ int main(int, char**) {
5858
const int n = (1 << logn);
5959
auto first = v.begin();
6060
auto last = v.begin() + n;
61+
const int debug_elements = std::min(100, n);
62+
// Multiplier 2 because both comp(a, b) and comp(b, a) are checked.
63+
const int debug_comparisons = 2 * (debug_elements + 1) * debug_elements;
6164
std::shuffle(first, last, g);
6265
std::make_heap(first, last);
6366
// The exact stats of our current implementation are recorded here.
@@ -69,7 +72,7 @@ int main(int, char**) {
6972
LIBCPP_ASSERT(stats.compared <= n * logn);
7073
#endif
7174
LIBCPP_ASSERT(std::is_sorted(first, last));
72-
LIBCPP_ASSERT(stats.compared <= 2 * n * logn);
75+
LIBCPP_ASSERT(stats.compared <= 2 * n * logn + debug_comparisons);
7376
}
7477
return 0;
7578
}

libcxx/test/std/algorithms/alg.sorting/alg.heap.operations/sort.heap/ranges_sort_heap.pass.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ constexpr bool test() {
207207

208208
{ // `std::ranges::dangling` is returned.
209209
[[maybe_unused]] std::same_as<std::ranges::dangling> decltype(auto) result =
210-
std::ranges::sort_heap(std::array{2, 1, 3});
210+
std::ranges::sort_heap(std::array{3, 1, 2});
211211
}
212212

213213
return true;
@@ -252,6 +252,9 @@ void test_complexity() {
252252
const int n = (1 << logn);
253253
auto first = v.begin();
254254
auto last = v.begin() + n;
255+
const int debug_elements = std::min(100, n);
256+
// Multiplier 2 because both comp(a, b) and comp(b, a) are checked.
257+
const int debug_comparisons = 2 * (debug_elements + 1) * debug_elements;
255258
std::shuffle(first, last, g);
256259
std::make_heap(first, last, &MyInt::Comp);
257260
// The exact stats of our current implementation are recorded here.
@@ -263,7 +266,7 @@ void test_complexity() {
263266
LIBCPP_ASSERT(stats.compared <= n * logn);
264267
#endif
265268
LIBCPP_ASSERT(std::is_sorted(first, last, &MyInt::Comp));
266-
LIBCPP_ASSERT(stats.compared <= 2 * n * logn);
269+
LIBCPP_ASSERT(stats.compared <= 2 * n * logn + debug_comparisons);
267270
}
268271
}
269272

libcxx/test/std/algorithms/ranges_robust_against_dangling.pass.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ constexpr bool test_all() {
201201
dangling_1st(std::ranges::make_heap, in);
202202
dangling_1st(std::ranges::push_heap, in);
203203
dangling_1st(std::ranges::pop_heap, in);
204+
dangling_1st(std::ranges::make_heap, in);
204205
dangling_1st(std::ranges::sort_heap, in);
205206
dangling_1st<prev_permutation_result<dangling>>(std::ranges::prev_permutation, in);
206207
dangling_1st<next_permutation_result<dangling>>(std::ranges::next_permutation, in);

0 commit comments

Comments
 (0)