Skip to content

[libc++] Add assertions for potential OOB reads in std::nth_element #67023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions libcxx/include/__algorithm/nth_element.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <__algorithm/comp_ref_type.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/sort.h>
#include <__assert>
#include <__config>
#include <__debug_utils/randomize_range.h>
#include <__iterator/iterator_traits.h>
Expand Down Expand Up @@ -42,6 +43,7 @@ __nth_element_find_guard(_RandomAccessIterator& __i, _RandomAccessIterator& __j,

template <class _AlgPolicy, class _Compare, class _RandomAccessIterator>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 void
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
__nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _RandomAccessIterator __last, _Compare __comp)
{
using _Ops = _IterOps<_AlgPolicy>;
Expand Down Expand Up @@ -116,10 +118,18 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
return;
}
while (true) {
while (!__comp(*__first, *__i))
while (!__comp(*__first, *__i)) {
++__i;
while (__comp(*__first, *--__j))
;
_LIBCPP_ASSERT_UNCATEGORIZED(
__i != __last,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
}
do {
_LIBCPP_ASSERT_UNCATEGORIZED(
__j != __first,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
--__j;
} while (__comp(*__first, *__j));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
Expand All @@ -146,11 +156,19 @@ __nth_element(_RandomAccessIterator __first, _RandomAccessIterator __nth, _Rando
while (true)
{
// __m still guards upward moving __i
while (__comp(*__i, *__m))
while (__comp(*__i, *__m)) {
++__i;
_LIBCPP_ASSERT_UNCATEGORIZED(
__i != __last,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
}
// It is now known that a guard exists for downward moving __j
while (!__comp(*--__j, *__m))
;
do {
_LIBCPP_ASSERT_UNCATEGORIZED(
__j != __first,
"Would read out of bounds, does your comparator satisfy the strict-weak ordering requirement?");
--__j;
} while (!__comp(*__j, *__m));
if (__i >= __j)
break;
_Ops::iter_swap(__i, __j);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,34 @@
#include "bad_comparator_values.h"
#include "check_assertion.h"

void check_oob_sort_read() {
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
for (auto line : std::views::split(DATA, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
auto values = std::views::split(line, ' ');
auto it = values.begin();
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
comparison_results[left][right] = result;
}
auto predicate = [&](std::size_t* left, std::size_t* right) {
class ComparisonResults {
public:
explicit ComparisonResults(std::string_view data) {
for (auto line : std::views::split(data, '\n') | std::views::filter([](auto const& line) { return !line.empty(); })) {
auto values = std::views::split(line, ' ');
auto it = values.begin();
std::size_t left = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
std::size_t right = std::stol(std::string((*it).data(), (*it).size()));
it = std::next(it);
bool result = static_cast<bool>(std::stol(std::string((*it).data(), (*it).size())));
comparison_results[left][right] = result;
}
}

bool compare(size_t* left, size_t* right) const {
assert(left != nullptr && right != nullptr && "something is wrong with the test");
assert(comparison_results.contains(*left) && comparison_results[*left].contains(*right) && "malformed input data?");
return comparison_results[*left][*right];
};
assert(comparison_results.contains(*left) && comparison_results.at(*left).contains(*right) && "malformed input data?");
return comparison_results.at(*left).at(*right);
}

size_t size() const { return comparison_results.size(); }
private:
std::map<std::size_t, std::map<std::size_t, bool>> comparison_results; // terrible for performance, but really convenient
};

void check_oob_sort_read() {
ComparisonResults comparison_results(SORT_DATA);
std::vector<std::unique_ptr<std::size_t>> elements;
std::set<std::size_t*> valid_ptrs;
for (std::size_t i = 0; i != comparison_results.size(); ++i) {
Expand All @@ -81,7 +91,7 @@ void check_oob_sort_read() {
// because we're reading OOB.
assert(valid_ptrs.contains(left));
assert(valid_ptrs.contains(right));
return predicate(left, right);
return comparison_results.compare(left, right);
};

// Check the classic sorting algorithms
Expand Down Expand Up @@ -117,12 +127,6 @@ void check_oob_sort_read() {
std::vector<std::size_t*> results(copy.size(), nullptr);
TEST_LIBCPP_ASSERT_FAILURE(std::partial_sort_copy(copy.begin(), copy.end(), results.begin(), results.end(), checked_predicate), "not a valid strict-weak ordering");
}
{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
std::nth_element(copy.begin(), copy.end(), copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
}

// Check the Ranges sorting algorithms
{
Expand Down Expand Up @@ -157,11 +161,38 @@ void check_oob_sort_read() {
std::vector<std::size_t*> results(copy.size(), nullptr);
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::partial_sort_copy(copy, results, checked_predicate), "not a valid strict-weak ordering");
}
}

void check_oob_nth_element_read() {
ComparisonResults results(NTH_ELEMENT_DATA);
std::vector<std::unique_ptr<std::size_t>> elements;
std::set<std::size_t*> valid_ptrs;
for (std::size_t i = 0; i != results.size(); ++i) {
elements.push_back(std::make_unique<std::size_t>(i));
valid_ptrs.insert(elements.back().get());
}

auto checked_predicate = [&](size_t* left, size_t* right) {
// If the pointers passed to the comparator are not in the set of pointers we
// set up above, then we're being passed garbage values from the algorithm
// because we're reading OOB.
assert(valid_ptrs.contains(left));
assert(valid_ptrs.contains(right));
return results.compare(left, right);
};

{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
std::ranges::nth_element(copy, copy.end(), checked_predicate); // doesn't go OOB even with invalid comparator
TEST_LIBCPP_ASSERT_FAILURE(std::nth_element(copy.begin(), copy.begin(), copy.end(), checked_predicate), "Would read out of bounds");
}

{
std::vector<std::size_t*> copy;
for (auto const& e : elements)
copy.push_back(e.get());
TEST_LIBCPP_ASSERT_FAILURE(std::ranges::nth_element(copy, copy.begin(), checked_predicate), "Would read out of bounds");
}
}

Expand Down Expand Up @@ -214,6 +245,8 @@ int main(int, char**) {

check_oob_sort_read();

check_oob_nth_element_read();

check_nan_floats();

check_irreflexive();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,74 @@

#include <string_view>

inline constexpr std::string_view DATA = R"(
inline constexpr std::string_view NTH_ELEMENT_DATA = R"(
0 0 0
0 1 0
0 2 0
0 3 0
0 4 1
0 5 0
0 6 0
0 7 0
1 0 0
1 1 0
1 2 0
1 3 1
1 4 1
1 5 1
1 6 1
1 7 1
2 0 1
2 1 1
2 2 1
2 3 1
2 4 1
2 5 1
2 6 1
2 7 1
3 0 1
3 1 1
3 2 1
3 3 1
3 4 1
3 5 1
3 6 1
3 7 1
4 0 1
4 1 1
4 2 1
4 3 1
4 4 1
4 5 1
4 6 1
4 7 1
5 0 1
5 1 1
5 2 1
5 3 1
5 4 1
5 5 1
5 6 1
5 7 1
6 0 1
6 1 1
6 2 1
6 3 1
6 4 1
6 5 1
6 6 1
6 7 1
7 0 1
7 1 1
7 2 1
7 3 1
7 4 1
7 5 1
7 6 1
7 7 1
)";

inline constexpr std::string_view SORT_DATA = R"(
0 0 0
0 1 1
0 2 1
Expand Down