Skip to content

[libc++] Add input validation for set_intersection() in debug mode. #101508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libcxx/docs/ReleaseNotes/20.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ Improvements and New Features
- The ``_LIBCPP_ABI_BOUNDED_ITERATORS_IN_STD_ARRAY`` ABI configuration was added, which allows storing valid bounds
in ``std::array::iterator`` and detecting OOB accesses when the appropriate hardening mode is enabled.

- ``std::set_intersection`` and ``std::ranges::set_intersection`` will now validate that inputs are sorted when compiled
in :ref:`debug hardening mode mode <using-hardening-modes>`. Results from these functions were changed in LLVM 19
with this class of invalid inputs, the new validation helps identify the source of undefined behavior.

Deprecations and Removals
-------------------------

Expand Down
16 changes: 8 additions & 8 deletions libcxx/include/__algorithm/is_sorted_until.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@

_LIBCPP_BEGIN_NAMESPACE_STD

template <class _Compare, class _ForwardIterator>
template <class _Compare, class _ForwardIterator, class _Sent>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator
__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) {
__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) {
if (__first != __last) {
_ForwardIterator __i = __first;
while (++__i != __last) {
if (__comp(*__i, *__first))
return __i;
__first = __i;
_ForwardIterator __prev = __first;
while (++__first != __last) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: what is the reason to swap first and i here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because we might skip the loop altogether if (__first == __last), and with the sentinel-friendly interface we can no longer return __last. In the old version __first would always be behind __i, so a fully-sorted non-empty input would have us returning the one-before-last position.

if (__comp(*__first, *__prev))
return __first;
__prev = __first;
}
}
return __last;
return __first;
}

template <class _ForwardIterator, class _Compare>
Expand Down
9 changes: 9 additions & 0 deletions libcxx/include/__algorithm/set_intersection.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@

#include <__algorithm/comp.h>
#include <__algorithm/comp_ref_type.h>
#include <__algorithm/is_sorted_until.h>
#include <__algorithm/iterator_operations.h>
#include <__algorithm/lower_bound.h>
#include <__assert>
#include <__config>
#include <__functional/identity.h>
#include <__iterator/iterator_traits.h>
#include <__iterator/next.h>
#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_same.h>
#include <__utility/exchange.h>
#include <__utility/forward.h>
Expand Down Expand Up @@ -96,6 +99,12 @@ __set_intersection(
_Compare&& __comp,
std::forward_iterator_tag,
std::forward_iterator_tag) {
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
_LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch should also add tests to trigger these assertions. You can see an example of how that's done in e.g. libcxx/test/libcxx/algorithms/alg.sorting/assert.min.max.pass.cpp. You should be able to just purposefully call the algorithm with something that's badly sorted.

std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted");
_LIBCPP_ASSERT_SEMANTIC_REQUIREMENT(
std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted");
#endif
_LIBCPP_CONSTEXPR std::__identity __proj;
bool __prev_may_be_equal = false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,31 @@

#include "test_iterators.h"

namespace {

// __debug_less will perform an additional comparison in an assertion
static constexpr unsigned std_less_comparison_count_multiplier() noexcept {
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG
return 2;
// We don't check number of operations in Debug mode because they are not stable enough due to additional validations.
#if defined(_LIBCPP_HARDENING_MODE_DEBUG) && _LIBCPP_HARDENING_MODE_DEBUG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#if defined(_LIBCPP_HARDENING_MODE_DEBUG) && _LIBCPP_HARDENING_MODE_DEBUG
#if defined(_LIBCPP_VERSION) && _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG

# define ASSERT_COMPLEXITY(expression) (void)(expression)
#else
return 1;
# define ASSERT_COMPLEXITY(expression) assert(expression)
#endif
}

namespace {

struct [[nodiscard]] OperationCounts {
std::size_t comparisons{};
struct PerInput {
std::size_t proj{};
IteratorOpCounts iterops;

[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) {
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept {
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >=
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves;
}
};
std::array<PerInput, 2> in;

[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) {
return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons &&
in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]);
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept {
return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) &&
in[1].isNotBetterThan(expect.in[1]);
}
};

Expand All @@ -80,16 +78,17 @@ struct counted_set_intersection_result {

constexpr counted_set_intersection_result() = default;

constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {}
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept
: result{contents} {}

constexpr void assertNotBetterThan(const counted_set_intersection_result& other) {
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why add noexcept?

assert(result == other.result);
assert(opcounts.isNotBetterThan(other.opcounts));
ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts));
}
};

template <std::size_t ResultSize>
counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>;
counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>;

template <template <class...> class InIterType1,
template <class...>
Expand Down Expand Up @@ -306,7 +305,7 @@ constexpr bool testComplexityBasic() {
std::array<int, 5> r2{2, 4, 6, 8, 10};
std::array<int, 0> expected{};

const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1);
const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1;

// std::set_intersection
{
Expand All @@ -321,7 +320,7 @@ constexpr bool testComplexityBasic() {
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp);

assert(std::ranges::equal(out, expected));
assert(numberOfComp <= maxOperation);
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
}

// ranges::set_intersection iterator overload
Expand Down Expand Up @@ -349,9 +348,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2);

assert(std::ranges::equal(out, expected));
assert(numberOfComp <= maxOperation);
assert(numberOfProj1 <= maxOperation);
assert(numberOfProj2 <= maxOperation);
ASSERT_COMPLEXITY(numberOfComp <= maxOperation);
ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation);
ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation);
}

// ranges::set_intersection range overload
Expand Down Expand Up @@ -379,9 +378,9 @@ constexpr bool testComplexityBasic() {
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2);

assert(std::ranges::equal(out, expected));
assert(numberOfComp < maxOperation);
assert(numberOfProj1 < maxOperation);
assert(numberOfProj2 < maxOperation);
ASSERT_COMPLEXITY(numberOfComp < maxOperation);
ASSERT_COMPLEXITY(numberOfProj1 < maxOperation);
ASSERT_COMPLEXITY(numberOfProj2 < maxOperation);
}
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,44 +40,45 @@ constexpr bool test_all() {
constexpr auto operator<=>(const A&) const = default;
};

std::array in = {1, 2, 3};
std::array in2 = {A{4}, A{5}, A{6}};
const std::array in = {1, 2, 3};
const std::array in2 = {A{4}, A{5}, A{6}};

std::array output = {7, 8, 9, 10, 11, 12};
auto out = output.begin();
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}};
auto out2 = output2.begin();

std::ranges::equal_to eq;
std::ranges::less less;
auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
auto proj1 = [](int x) { return x * -1; };
auto proj2 = [](A a) { return a.x * -1; };
const std::ranges::equal_to eq;
const std::ranges::less less;
const std::ranges::greater greater;
const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; };
const auto proj1 = [](int x) { return x * -1; };
const auto proj2 = [](A a) { return a.x * -1; };

#if TEST_STD_VER >= 23
test(std::ranges::ends_with, in, in2, eq, proj1, proj2);
#endif
test(std::ranges::equal, in, in2, eq, proj1, proj2);
test(std::ranges::lexicographical_compare, in, in2, eq, proj1, proj2);
test(std::ranges::is_permutation, in, in2, eq, proj1, proj2);
test(std::ranges::includes, in, in2, less, proj1, proj2);
test(std::ranges::includes, in, in2, greater, proj1, proj2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?

test(std::ranges::find_first_of, in, in2, eq, proj1, proj2);
test(std::ranges::mismatch, in, in2, eq, proj1, proj2);
test(std::ranges::search, in, in2, eq, proj1, proj2);
test(std::ranges::find_end, in, in2, eq, proj1, proj2);
test(std::ranges::transform, in, in2, out, sum, proj1, proj2);
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2);
test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2);
test(std::ranges::merge, in, in2, out, less, proj1, proj2);
test(std::ranges::merge, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_difference, in, in2, out, less, proj1, proj2);
test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2);
test(std::ranges::set_union, in, in2, out, less, proj1, proj2);
test(std::ranges::set_union, in, in2, out2, less, proj1, proj2);
test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2);
test(std::ranges::merge, in, in2, out, greater, proj1, proj2);
test(std::ranges::merge, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2);
test(std::ranges::set_union, in, in2, out, greater, proj1, proj2);
test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2);
#if TEST_STD_VER > 20
test(std::ranges::starts_with, in, in2, eq, proj1, proj2);
#endif
Expand Down
Loading