50
50
#include " bad_comparator_values.h"
51
51
#include " check_assertion.h"
52
52
53
- void check_oob_sort_read () {
54
- std::map<std::size_t , std::map<std::size_t , bool >> comparison_results; // terrible for performance, but really convenient
55
- for (auto line : std::views::split (DATA, ' \n ' ) | std::views::filter ([](auto const & line) { return !line.empty (); })) {
56
- auto values = std::views::split (line, ' ' );
57
- auto it = values.begin ();
58
- std::size_t left = std::stol (std::string ((*it).data (), (*it).size ()));
59
- it = std::next (it);
60
- std::size_t right = std::stol (std::string ((*it).data (), (*it).size ()));
61
- it = std::next (it);
62
- bool result = static_cast <bool >(std::stol (std::string ((*it).data (), (*it).size ())));
63
- comparison_results[left][right] = result;
64
- }
65
- auto predicate = [&](std::size_t * left, std::size_t * right) {
53
+ class ComparisonResults {
54
+ public:
55
+ explicit ComparisonResults (std::string_view data) {
56
+ for (auto line : std::views::split (data, ' \n ' ) | std::views::filter ([](auto const & line) { return !line.empty (); })) {
57
+ auto values = std::views::split (line, ' ' );
58
+ auto it = values.begin ();
59
+ std::size_t left = std::stol (std::string ((*it).data (), (*it).size ()));
60
+ it = std::next (it);
61
+ std::size_t right = std::stol (std::string ((*it).data (), (*it).size ()));
62
+ it = std::next (it);
63
+ bool result = static_cast <bool >(std::stol (std::string ((*it).data (), (*it).size ())));
64
+ comparison_results[left][right] = result;
65
+ }
66
+ }
67
+
68
+ bool compare (size_t * left, size_t * right) const {
66
69
assert (left != nullptr && right != nullptr && " something is wrong with the test" );
67
- assert (comparison_results.contains (*left) && comparison_results[ *left] .contains (*right) && " malformed input data?" );
68
- return comparison_results[ *left][ *right] ;
69
- };
70
+ assert (comparison_results.contains (*left) && comparison_results. at ( *left) .contains (*right) && " malformed input data?" );
71
+ return comparison_results. at ( *left). at ( *right) ;
72
+ }
70
73
74
+ size_t size () const { return comparison_results.size (); }
75
+ private:
76
+ std::map<std::size_t , std::map<std::size_t , bool >> comparison_results; // terrible for performance, but really convenient
77
+ };
78
+
79
+ void check_oob_sort_read () {
80
+ ComparisonResults comparison_results (SORT_DATA);
71
81
std::vector<std::unique_ptr<std::size_t >> elements;
72
82
std::set<std::size_t *> valid_ptrs;
73
83
for (std::size_t i = 0 ; i != comparison_results.size (); ++i) {
@@ -81,7 +91,7 @@ void check_oob_sort_read() {
81
91
// because we're reading OOB.
82
92
assert (valid_ptrs.contains (left));
83
93
assert (valid_ptrs.contains (right));
84
- return predicate (left, right);
94
+ return comparison_results. compare (left, right);
85
95
};
86
96
87
97
// Check the classic sorting algorithms
@@ -117,12 +127,6 @@ void check_oob_sort_read() {
117
127
std::vector<std::size_t *> results (copy.size (), nullptr );
118
128
TEST_LIBCPP_ASSERT_FAILURE (std::partial_sort_copy (copy.begin (), copy.end (), results.begin (), results.end (), checked_predicate), " not a valid strict-weak ordering" );
119
129
}
120
- {
121
- std::vector<std::size_t *> copy;
122
- for (auto const & e : elements)
123
- copy.push_back (e.get ());
124
- std::nth_element (copy.begin (), copy.end (), copy.end (), checked_predicate); // doesn't go OOB even with invalid comparator
125
- }
126
130
127
131
// Check the Ranges sorting algorithms
128
132
{
@@ -157,11 +161,38 @@ void check_oob_sort_read() {
157
161
std::vector<std::size_t *> results (copy.size (), nullptr );
158
162
TEST_LIBCPP_ASSERT_FAILURE (std::ranges::partial_sort_copy (copy, results, checked_predicate), " not a valid strict-weak ordering" );
159
163
}
164
+ }
165
+
166
+ void check_oob_nth_element_read () {
167
+ ComparisonResults results (NTH_ELEMENT_DATA);
168
+ std::vector<std::unique_ptr<std::size_t >> elements;
169
+ std::set<std::size_t *> valid_ptrs;
170
+ for (std::size_t i = 0 ; i != results.size (); ++i) {
171
+ elements.push_back (std::make_unique<std::size_t >(i));
172
+ valid_ptrs.insert (elements.back ().get ());
173
+ }
174
+
175
+ auto checked_predicate = [&](size_t * left, size_t * right) {
176
+ // If the pointers passed to the comparator are not in the set of pointers we
177
+ // set up above, then we're being passed garbage values from the algorithm
178
+ // because we're reading OOB.
179
+ assert (valid_ptrs.contains (left));
180
+ assert (valid_ptrs.contains (right));
181
+ return results.compare (left, right);
182
+ };
183
+
160
184
{
161
185
std::vector<std::size_t *> copy;
162
186
for (auto const & e : elements)
163
187
copy.push_back (e.get ());
164
- std::ranges::nth_element (copy, copy.end (), checked_predicate); // doesn't go OOB even with invalid comparator
188
+ TEST_LIBCPP_ASSERT_FAILURE (std::nth_element (copy.begin (), copy.begin (), copy.end (), checked_predicate), " Would read out of bounds" );
189
+ }
190
+
191
+ {
192
+ std::vector<std::size_t *> copy;
193
+ for (auto const & e : elements)
194
+ copy.push_back (e.get ());
195
+ TEST_LIBCPP_ASSERT_FAILURE (std::ranges::nth_element (copy, copy.begin (), checked_predicate), " Would read out of bounds" );
165
196
}
166
197
}
167
198
@@ -214,6 +245,8 @@ int main(int, char**) {
214
245
215
246
check_oob_sort_read ();
216
247
248
+ check_oob_nth_element_read ();
249
+
217
250
check_nan_floats ();
218
251
219
252
check_irreflexive ();
0 commit comments