-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[libc++] Add input validation for set_intersection() in debug mode. #101508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5620dce
7f2beae
3006f13
22c9f1c
3511bed
e8469c1
9a7cceb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,18 +20,18 @@ | |
|
||
_LIBCPP_BEGIN_NAMESPACE_STD | ||
|
||
template <class _Compare, class _ForwardIterator> | ||
template <class _Compare, class _ForwardIterator, class _Sent> | ||
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 _ForwardIterator | ||
__is_sorted_until(_ForwardIterator __first, _ForwardIterator __last, _Compare __comp) { | ||
__is_sorted_until(_ForwardIterator __first, _Sent __last, _Compare&& __comp) { | ||
if (__first != __last) { | ||
_ForwardIterator __i = __first; | ||
while (++__i != __last) { | ||
if (__comp(*__i, *__first)) | ||
return __i; | ||
__first = __i; | ||
_ForwardIterator __prev = __first; | ||
while (++__first != __last) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: what is the reason to swap There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's because we might skip the loop altogether |
||
if (__comp(*__first, *__prev)) | ||
return __first; | ||
__prev = __first; | ||
} | ||
} | ||
return __last; | ||
return __first; | ||
} | ||
|
||
template <class _ForwardIterator, class _Compare> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,12 +11,15 @@ | |
|
||
#include <__algorithm/comp.h> | ||
#include <__algorithm/comp_ref_type.h> | ||
#include <__algorithm/is_sorted_until.h> | ||
#include <__algorithm/iterator_operations.h> | ||
#include <__algorithm/lower_bound.h> | ||
#include <__assert> | ||
#include <__config> | ||
#include <__functional/identity.h> | ||
#include <__iterator/iterator_traits.h> | ||
#include <__iterator/next.h> | ||
#include <__type_traits/is_constant_evaluated.h> | ||
#include <__type_traits/is_same.h> | ||
#include <__utility/exchange.h> | ||
#include <__utility/forward.h> | ||
|
@@ -96,6 +99,12 @@ __set_intersection( | |
_Compare&& __comp, | ||
std::forward_iterator_tag, | ||
std::forward_iterator_tag) { | ||
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG | ||
_LIBCPP_ASSERT_SEMANTIC_REQUIREMENT( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This patch should also add tests to trigger these assertions. You can see an example of how that's done in e.g. |
||
std::__is_sorted_until(__first1, __last1, __comp) == __last1, "set_intersection: input range 1 must be sorted"); | ||
_LIBCPP_ASSERT_SEMANTIC_REQUIREMENT( | ||
std::__is_sorted_until(__first2, __last2, __comp) == __last2, "set_intersection: input range 2 must be sorted"); | ||
#endif | ||
_LIBCPP_CONSTEXPR std::__identity __proj; | ||
bool __prev_may_be_equal = false; | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -43,33 +43,31 @@ | |||||
|
||||||
#include "test_iterators.h" | ||||||
|
||||||
namespace { | ||||||
|
||||||
// __debug_less will perform an additional comparison in an assertion | ||||||
static constexpr unsigned std_less_comparison_count_multiplier() noexcept { | ||||||
#if _LIBCPP_HARDENING_MODE == _LIBCPP_HARDENING_MODE_DEBUG | ||||||
return 2; | ||||||
// We don't check number of operations in Debug mode because they are not stable enough due to additional validations. | ||||||
#if defined(_LIBCPP_HARDENING_MODE_DEBUG) && _LIBCPP_HARDENING_MODE_DEBUG | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# define ASSERT_COMPLEXITY(expression) (void)(expression) | ||||||
#else | ||||||
return 1; | ||||||
# define ASSERT_COMPLEXITY(expression) assert(expression) | ||||||
#endif | ||||||
} | ||||||
|
||||||
namespace { | ||||||
|
||||||
struct [[nodiscard]] OperationCounts { | ||||||
std::size_t comparisons{}; | ||||||
struct PerInput { | ||||||
std::size_t proj{}; | ||||||
IteratorOpCounts iterops; | ||||||
|
||||||
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) { | ||||||
[[nodiscard]] constexpr bool isNotBetterThan(const PerInput& other) const noexcept { | ||||||
return proj >= other.proj && iterops.increments + iterops.decrements + iterops.zero_moves >= | ||||||
other.iterops.increments + other.iterops.decrements + other.iterops.zero_moves; | ||||||
} | ||||||
}; | ||||||
std::array<PerInput, 2> in; | ||||||
|
||||||
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) { | ||||||
return std_less_comparison_count_multiplier() * comparisons >= expect.comparisons && | ||||||
in[0].isNotBetterThan(expect.in[0]) && in[1].isNotBetterThan(expect.in[1]); | ||||||
[[nodiscard]] constexpr bool isNotBetterThan(const OperationCounts& expect) const noexcept { | ||||||
return comparisons >= expect.comparisons && in[0].isNotBetterThan(expect.in[0]) && | ||||||
in[1].isNotBetterThan(expect.in[1]); | ||||||
} | ||||||
}; | ||||||
|
||||||
|
@@ -80,16 +78,17 @@ struct counted_set_intersection_result { | |||||
|
||||||
constexpr counted_set_intersection_result() = default; | ||||||
|
||||||
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) : result{contents} {} | ||||||
constexpr explicit counted_set_intersection_result(std::array<int, ResultSize>&& contents) noexcept | ||||||
: result{contents} {} | ||||||
|
||||||
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) { | ||||||
constexpr void assertNotBetterThan(const counted_set_intersection_result& other) const noexcept { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just curious, why add |
||||||
assert(result == other.result); | ||||||
assert(opcounts.isNotBetterThan(other.opcounts)); | ||||||
ASSERT_COMPLEXITY(opcounts.isNotBetterThan(other.opcounts)); | ||||||
} | ||||||
}; | ||||||
|
||||||
template <std::size_t ResultSize> | ||||||
counted_set_intersection_result(std::array<int, ResultSize>) -> counted_set_intersection_result<ResultSize>; | ||||||
counted_set_intersection_result(std::array<int, ResultSize>) noexcept -> counted_set_intersection_result<ResultSize>; | ||||||
|
||||||
template <template <class...> class InIterType1, | ||||||
template <class...> | ||||||
|
@@ -306,7 +305,7 @@ constexpr bool testComplexityBasic() { | |||||
std::array<int, 5> r2{2, 4, 6, 8, 10}; | ||||||
std::array<int, 0> expected{}; | ||||||
|
||||||
const std::size_t maxOperation = std_less_comparison_count_multiplier() * (2 * (r1.size() + r2.size()) - 1); | ||||||
const std::size_t maxOperation = 2 * (r1.size() + r2.size()) - 1; | ||||||
|
||||||
// std::set_intersection | ||||||
{ | ||||||
|
@@ -321,7 +320,7 @@ constexpr bool testComplexityBasic() { | |||||
std::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp); | ||||||
|
||||||
assert(std::ranges::equal(out, expected)); | ||||||
assert(numberOfComp <= maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfComp <= maxOperation); | ||||||
} | ||||||
|
||||||
// ranges::set_intersection iterator overload | ||||||
|
@@ -349,9 +348,9 @@ constexpr bool testComplexityBasic() { | |||||
std::ranges::set_intersection(r1.begin(), r1.end(), r2.begin(), r2.end(), out.data(), comp, proj1, proj2); | ||||||
|
||||||
assert(std::ranges::equal(out, expected)); | ||||||
assert(numberOfComp <= maxOperation); | ||||||
assert(numberOfProj1 <= maxOperation); | ||||||
assert(numberOfProj2 <= maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfComp <= maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfProj1 <= maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfProj2 <= maxOperation); | ||||||
} | ||||||
|
||||||
// ranges::set_intersection range overload | ||||||
|
@@ -379,9 +378,9 @@ constexpr bool testComplexityBasic() { | |||||
std::ranges::set_intersection(r1, r2, out.data(), comp, proj1, proj2); | ||||||
|
||||||
assert(std::ranges::equal(out, expected)); | ||||||
assert(numberOfComp < maxOperation); | ||||||
assert(numberOfProj1 < maxOperation); | ||||||
assert(numberOfProj2 < maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfComp < maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfProj1 < maxOperation); | ||||||
ASSERT_COMPLEXITY(numberOfProj2 < maxOperation); | ||||||
} | ||||||
return true; | ||||||
} | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,44 +40,45 @@ constexpr bool test_all() { | |
constexpr auto operator<=>(const A&) const = default; | ||
}; | ||
|
||
std::array in = {1, 2, 3}; | ||
std::array in2 = {A{4}, A{5}, A{6}}; | ||
const std::array in = {1, 2, 3}; | ||
const std::array in2 = {A{4}, A{5}, A{6}}; | ||
|
||
std::array output = {7, 8, 9, 10, 11, 12}; | ||
auto out = output.begin(); | ||
std::array output2 = {A{7}, A{8}, A{9}, A{10}, A{11}, A{12}}; | ||
auto out2 = output2.begin(); | ||
|
||
std::ranges::equal_to eq; | ||
std::ranges::less less; | ||
auto sum = [](int lhs, A rhs) { return lhs + rhs.x; }; | ||
auto proj1 = [](int x) { return x * -1; }; | ||
auto proj2 = [](A a) { return a.x * -1; }; | ||
const std::ranges::equal_to eq; | ||
const std::ranges::less less; | ||
const std::ranges::greater greater; | ||
const auto sum = [](int lhs, A rhs) { return lhs + rhs.x; }; | ||
const auto proj1 = [](int x) { return x * -1; }; | ||
const auto proj2 = [](A a) { return a.x * -1; }; | ||
|
||
#if TEST_STD_VER >= 23 | ||
test(std::ranges::ends_with, in, in2, eq, proj1, proj2); | ||
#endif | ||
test(std::ranges::equal, in, in2, eq, proj1, proj2); | ||
test(std::ranges::lexicographical_compare, in, in2, eq, proj1, proj2); | ||
test(std::ranges::is_permutation, in, in2, eq, proj1, proj2); | ||
test(std::ranges::includes, in, in2, less, proj1, proj2); | ||
test(std::ranges::includes, in, in2, greater, proj1, proj2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this change needed? |
||
test(std::ranges::find_first_of, in, in2, eq, proj1, proj2); | ||
test(std::ranges::mismatch, in, in2, eq, proj1, proj2); | ||
test(std::ranges::search, in, in2, eq, proj1, proj2); | ||
test(std::ranges::find_end, in, in2, eq, proj1, proj2); | ||
test(std::ranges::transform, in, in2, out, sum, proj1, proj2); | ||
test(std::ranges::transform, in, in2, out2, sum, proj1, proj2); | ||
test(std::ranges::partial_sort_copy, in, in2, less, proj1, proj2); | ||
test(std::ranges::merge, in, in2, out, less, proj1, proj2); | ||
test(std::ranges::merge, in, in2, out2, less, proj1, proj2); | ||
test(std::ranges::set_intersection, in, in2, out, less, proj1, proj2); | ||
test(std::ranges::set_intersection, in, in2, out2, less, proj1, proj2); | ||
test(std::ranges::set_difference, in, in2, out, less, proj1, proj2); | ||
test(std::ranges::set_difference, in, in2, out2, less, proj1, proj2); | ||
test(std::ranges::set_symmetric_difference, in, in2, out, less, proj1, proj2); | ||
test(std::ranges::set_symmetric_difference, in, in2, out2, less, proj1, proj2); | ||
test(std::ranges::set_union, in, in2, out, less, proj1, proj2); | ||
test(std::ranges::set_union, in, in2, out2, less, proj1, proj2); | ||
test(std::ranges::partial_sort_copy, in, output, less, proj1, proj2); | ||
test(std::ranges::merge, in, in2, out, greater, proj1, proj2); | ||
test(std::ranges::merge, in, in2, out2, greater, proj1, proj2); | ||
test(std::ranges::set_intersection, in, in2, out, greater, proj1, proj2); | ||
test(std::ranges::set_intersection, in, in2, out2, greater, proj1, proj2); | ||
test(std::ranges::set_difference, in, in2, out, greater, proj1, proj2); | ||
test(std::ranges::set_difference, in, in2, out2, greater, proj1, proj2); | ||
test(std::ranges::set_symmetric_difference, in, in2, out, greater, proj1, proj2); | ||
test(std::ranges::set_symmetric_difference, in, in2, out2, greater, proj1, proj2); | ||
test(std::ranges::set_union, in, in2, out, greater, proj1, proj2); | ||
test(std::ranges::set_union, in, in2, out2, greater, proj1, proj2); | ||
#if TEST_STD_VER > 20 | ||
test(std::ranges::starts_with, in, in2, eq, proj1, proj2); | ||
#endif | ||
|
Uh oh!
There was an error while loading. Please reload this page.