Skip to content

Commit ea9af5e

Browse files
authored
[libc++] Add assertions for potential OOB reads in std::nth_element (#67023)
Same as https://reviews.llvm.org/D147089 but for std::nth_element
1 parent a574242 commit ea9af5e

File tree

3 files changed

+149
-31
lines changed

3 files changed

+149
-31
lines changed

libcxx/include/__algorithm/nth_element.h

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <__algorithm/comp_ref_type.h>
1414
#include <__algorithm/iterator_operations.h>
1515
#include <__algorithm/sort.h>
16+
#include <__assert>
1617
#include <__config>
1718
#include <__debug_utils/randomize_range.h>
1819
#include <__iterator/iterator_traits.h>
@@ -42,6 +43,7 @@ __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,
4243

4344
template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
4445
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
46+
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
4547
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
4648
{
4749
using _Ops = _IterOps<_AlgPolicy>;
@@ -116,10 +118,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
116118
return;
117119
}
118120
while (true) {
119-
while (!__comp(*__first, *__i))
121+
while (!__comp(*__first, *__i)) {
120122
++__i;
121-
while (__comp(*__first, *--__j))
122-
;
123+
_LIBCPP_ASSERT_UNCATEGORIZED(
124+
__i != __last,
125+
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
126+
}
127+
do {
128+
_LIBCPP_ASSERT_UNCATEGORIZED(
129+
__j != __first,
130+
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
131+
--__j;
132+
} while (__comp(*__first, *__j));
123133
if (__i >= __j)
124134
break;
125135
_Ops::iter_swap(__i, __j);
@@ -146,11 +156,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
146156
while (true)
147157
{
148158
// __m still guards upward moving __i
149-
while (__comp(*__i, *__m))
159+
while (__comp(*__i, *__m)) {
150160
++__i;
161+
_LIBCPP_ASSERT_UNCATEGORIZED(
162+
__i != __last,
163+
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
164+
}
151165
// It is now known that a guard exists for downward moving __j
152-
while (!__comp(*--__j, *__m))
153-
;
166+
do {
167+
_LIBCPP_ASSERT_UNCATEGORIZED(
168+
__j != __first,
169+
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
170+
--__j;
171+
} while (!__comp(*__j, *__m));
154172
if (__i >= __j)
155173
break;
156174
_Ops::iter_swap(__i, __j);

libcxx/test/libcxx/algorithms/alg.sorting/assert.sort.invalid_comparator.pass.cpp

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,34 @@
5050
#include "bad_comparator_values.h"
5151
#include "check_assertion.h"
5252

53-
void check_oob_sort_read() {
54-
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
55-
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
56-
auto values = std::views::split(line, ' ');
57-
auto it = values.begin();
58-
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
59-
it = std::next(it);
60-
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
61-
it = std::next(it);
62-
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
63-
comparison_results[left][right] = result;
64-
}
65-
auto predicate = [&](std::size_t* left, std::size_t* right) {
53+
class ComparisonResults {
54+
public:
55+
explicit ComparisonResults(std::string_view data) {
56+
for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
57+
auto values = std::views::split(line, ' ');
58+
auto it = values.begin();
59+
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
60+
it = std::next(it);
61+
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
62+
it = std::next(it);
63+
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
64+
comparison_results[left][right] = result;
65+
}
66+
}
67+
68+
bool compare(size_t* left, size_t* right) const {
6669
assert(left != nullptr && right != nullptr && "something is wrong with the test");
67-
assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
68-
return comparison_results[*left][*right];
69-
};
70+
assert(comparison_results.contains(*left) && comparison_results.at(*left).contains(*right) && "malformed input data?");
71+
return comparison_results.at(*left).at(*right);
72+
}
7073

74+
size_t size() const { return comparison_results.size(); }
75+
private:
76+
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
77+
};
78+
79+
void check_oob_sort_read() {
80+
ComparisonResults comparison_results(SORT_DATA);
7181
std::vector<std::unique_ptr<std::size_t>> elements;
7282
std::set<std::size_t*> valid_ptrs;
7383
for (std::size_t i = 0; i != comparison_results.size(); ++i) {
@@ -81,7 +91,7 @@ void check_oob_sort_read() {
8191
// because we're reading OOB.
8292
assert(valid_ptrs.contains(left));
8393
assert(valid_ptrs.contains(right));
84-
return predicate(left, right);
94+
return comparison_results.compare(left, right);
8595
};
8696

8797
// Check the classic sorting algorithms
@@ -117,12 +127,6 @@ void check_oob_sort_read() {
117127
std::vector<std::size_t*> results(copy.size(), nullptr);
118128
TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
119129
}
120-
{
121-
std::vector<std::size_t*> copy;
122-
for (auto const& e : elements)
123-
copy.push_back(e.get());
124-
std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
125-
}
126130

127131
// Check the Ranges sorting algorithms
128132
{
@@ -157,11 +161,38 @@ void check_oob_sort_read() {
157161
std::vector<std::size_t*> results(copy.size(), nullptr);
158162
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
159163
}
164+
}
165+
166+
void check_oob_nth_element_read() {
167+
ComparisonResults results(NTH_ELEMENT_DATA);
168+
std::vector<std::unique_ptr<std::size_t>> elements;
169+
std::set<std::size_t*> valid_ptrs;
170+
for (std::size_t i = 0; i != results.size(); ++i) {
171+
elements.push_back(std::make_unique<std::size_t>(i));
172+
valid_ptrs.insert(elements.back().get());
173+
}
174+
175+
auto checked_predicate = [&](size_t* left, size_t* right) {
176+
// If the pointers passed to the comparator are not in the set of pointers we
177+
// set up above, then we're being passed garbage values from the algorithm
178+
// because we're reading OOB.
179+
assert(valid_ptrs.contains(left));
180+
assert(valid_ptrs.contains(right));
181+
return results.compare(left, right);
182+
};
183+
160184
{
161185
std::vector<std::size_t*> copy;
162186
for (auto const& e : elements)
163187
copy.push_back(e.get());
164-
std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
188+
TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
189+
}
190+
191+
{
192+
std::vector<std::size_t*> copy;
193+
for (auto const& e : elements)
194+
copy.push_back(e.get());
195+
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
165196
}
166197
}
167198

@@ -214,6 +245,8 @@ int main(int, char**) {
214245

215246
check_oob_sort_read();
216247

248+
check_oob_nth_element_read();
249+
217250
check_nan_floats();
218251

219252
check_irreflexive();

libcxx/test/libcxx/algorithms/alg.sorting/bad_comparator_values.h

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,74 @@
1111

1212
#include <string_view>
1313

14-
inline constexpr std::string_view DATA = R"(
14+
inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
15+
0 0 0
16+
0 1 0
17+
0 2 0
18+
0 3 0
19+
0 4 1
20+
0 5 0
21+
0 6 0
22+
0 7 0
23+
1 0 0
24+
1 1 0
25+
1 2 0
26+
1 3 1
27+
1 4 1
28+
1 5 1
29+
1 6 1
30+
1 7 1
31+
2 0 1
32+
2 1 1
33+
2 2 1
34+
2 3 1
35+
2 4 1
36+
2 5 1
37+
2 6 1
38+
2 7 1
39+
3 0 1
40+
3 1 1
41+
3 2 1
42+
3 3 1
43+
3 4 1
44+
3 5 1
45+
3 6 1
46+
3 7 1
47+
4 0 1
48+
4 1 1
49+
4 2 1
50+
4 3 1
51+
4 4 1
52+
4 5 1
53+
4 6 1
54+
4 7 1
55+
5 0 1
56+
5 1 1
57+
5 2 1
58+
5 3 1
59+
5 4 1
60+
5 5 1
61+
5 6 1
62+
5 7 1
63+
6 0 1
64+
6 1 1
65+
6 2 1
66+
6 3 1
67+
6 4 1
68+
6 5 1
69+
6 6 1
70+
6 7 1
71+
7 0 1
72+
7 1 1
73+
7 2 1
74+
7 3 1
75+
7 4 1
76+
7 5 1
77+
7 6 1
78+
7 7 1
79+
)";
80+
81+
inline constexpr std::string_view SORT_DATA = R"(
1582
0 0 0
1683
0 1 1
1784
0 2 1

0 commit comments

Comments
 (0)