8
8
9
9
#include " src/__support/CPP/type_traits.h"
10
10
#include " src/__support/FPUtil/FPBits.h"
11
+ #include " src/__support/macros/properties/types.h"
11
12
#include " test/UnitTest/FPMatcher.h"
12
13
#include " test/UnitTest/Test.h"
13
14
#include " utils/MPFRWrapper/MPFRUtils.h"
@@ -68,6 +69,43 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
68
69
}
69
70
};
70
71
72
+ template <typename T> using BinaryOp = T(T, T);
73
+
74
+ template <typename T, mpfr::Operation Op, BinaryOp<T> Func>
75
+ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
76
+ using FloatType = T;
77
+ using FPBits = LIBC_NAMESPACE::fputil::FPBits<FloatType>;
78
+ using StorageType = typename FPBits::StorageType;
79
+
80
+ static constexpr BinaryOp<FloatType> *FUNC = Func;
81
+
82
+ // Check in a range, return the number of failures.
83
+ uint64_t check (StorageType x_start, StorageType x_stop, StorageType y_start,
84
+ StorageType y_stop, mpfr::RoundingMode rounding) {
85
+ mpfr::ForceRoundingMode r (rounding);
86
+ if (!r.success )
87
+ return (x_stop > x_start || y_stop > y_start);
88
+ StorageType xbits = x_start;
89
+ uint64_t failed = 0 ;
90
+ do {
91
+ FloatType x = FPBits (xbits).get_val ();
92
+ StorageType ybits = y_start;
93
+ do {
94
+ FloatType y = FPBits (ybits).get_val ();
95
+ mpfr::BinaryInput<FloatType> input{x, y};
96
+ bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY (Op, input, FUNC (x, y),
97
+ 0.5 , rounding);
98
+ failed += (!correct);
99
+ // Uncomment to print out failed values.
100
+ // if (!correct) {
101
+ // TEST_MPFR_MATCH(Op::Operation, x, Op::func(x, y), 0.5, rounding);
102
+ // }
103
+ } while (ybits++ < y_stop);
104
+ } while (xbits++ < x_stop);
105
+ return failed;
106
+ }
107
+ };
108
+
71
109
// Checker class needs inherit from LIBC_NAMESPACE::testing::Test and provide
72
110
// StorageType and check method.
73
111
template <typename Checker>
@@ -167,6 +205,114 @@ struct LlvmLibcExhaustiveMathTest
167
205
};
168
206
};
169
207
208
+ template <typename Checker>
209
+ struct LlvmLibcBinaryInputExhaustiveMathTest
210
+ : public virtual LIBC_NAMESPACE::testing::Test,
211
+ public Checker {
212
+ using FloatType = typename Checker::FloatType;
213
+ using FPBits = typename Checker::FPBits;
214
+ using StorageType = typename Checker::StorageType;
215
+
216
+ static constexpr StorageType Increment = (1 << 2 );
217
+
218
+ // Break [start, stop) into `nthreads` subintervals and apply *check to each
219
+ // subinterval in parallel.
220
+ void test_full_range (StorageType x_start, StorageType x_stop,
221
+ StorageType y_start, StorageType y_stop,
222
+ mpfr::RoundingMode rounding) {
223
+ int n_threads = std::thread::hardware_concurrency ();
224
+ std::vector<std::thread> thread_list;
225
+ std::mutex mx_cur_val;
226
+ int current_percent = -1 ;
227
+ StorageType current_value = x_start;
228
+ std::atomic<uint64_t > failed (0 );
229
+
230
+ for (int i = 0 ; i < n_threads; ++i) {
231
+ thread_list.emplace_back ([&, this ]() {
232
+ while (true ) {
233
+ StorageType range_begin, range_end;
234
+ int new_percent = -1 ;
235
+ {
236
+ std::lock_guard<std::mutex> lock (mx_cur_val);
237
+ if (current_value == x_stop)
238
+ return ;
239
+
240
+ range_begin = current_value;
241
+ if (x_stop >= Increment && x_stop - Increment >= current_value) {
242
+ range_end = current_value + Increment;
243
+ } else {
244
+ range_end = x_stop;
245
+ }
246
+ current_value = range_end;
247
+ int pc = 100.0 * (range_end - x_start) / (x_stop - x_start);
248
+ if (current_percent != pc) {
249
+ new_percent = pc;
250
+ current_percent = pc;
251
+ }
252
+ }
253
+ if (new_percent >= 0 ) {
254
+ std::stringstream msg;
255
+ msg << new_percent << " % is in process \r " ;
256
+ std::cout << msg.str () << std::flush;
257
+ }
258
+
259
+ uint64_t failed_in_range =
260
+ Checker::check (range_begin, range_end, y_start, y_stop, rounding);
261
+ if (failed_in_range > 0 ) {
262
+ using T = LIBC_NAMESPACE::cpp::conditional_t <
263
+ LIBC_NAMESPACE::cpp::is_same_v<FloatType, float16>, float ,
264
+ FloatType>;
265
+ std::stringstream msg;
266
+ msg << " Test failed for " << std::dec << failed_in_range
267
+ << " inputs in range: " << range_begin << " to " << range_end
268
+ << " [0x" << std::hex << range_begin << " , 0x" << range_end
269
+ << " ), [" << std::hexfloat
270
+ << static_cast <T>(FPBits (range_begin).get_val ()) << " , "
271
+ << static_cast <T>(FPBits (range_end).get_val ()) << " )\n " ;
272
+ std::cerr << msg.str () << std::flush;
273
+
274
+ failed.fetch_add (failed_in_range);
275
+ }
276
+ }
277
+ });
278
+ }
279
+
280
+ for (auto &thread : thread_list) {
281
+ if (thread.joinable ()) {
282
+ thread.join ();
283
+ }
284
+ }
285
+
286
+ std::cout << std::endl;
287
+ std::cout << " Test " << ((failed > 0 ) ? " FAILED" : " PASSED" ) << std::endl;
288
+ ASSERT_EQ (failed.load (), uint64_t (0 ));
289
+ }
290
+
291
+ void test_full_range_all_roundings (StorageType x_start, StorageType x_stop,
292
+ StorageType y_start, StorageType y_stop) {
293
+ test_full_range (x_start, x_stop, y_start, y_stop,
294
+ mpfr::RoundingMode::Nearest);
295
+
296
+ std::cout << " -- Testing for FE_UPWARD in x range [0x" << std::hex
297
+ << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
298
+ << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
299
+ test_full_range (x_start, x_stop, y_start, y_stop,
300
+ mpfr::RoundingMode::Upward);
301
+
302
+ std::cout << " -- Testing for FE_DOWNWARD in x range [0x" << std::hex
303
+ << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
304
+ << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
305
+ test_full_range (x_start, x_stop, y_start, y_stop,
306
+ mpfr::RoundingMode::Downward);
307
+
308
+ std::cout << " -- Testing for FE_TOWARDZERO in x range [0x" << std::hex
309
+ << x_start << " , 0x" << x_stop << " ) y range [0x" << std::hex
310
+ << y_start << " , 0x" << y_stop << " ) --" << std::dec << std::endl;
311
+ test_full_range (x_start, x_stop, y_start, y_stop,
312
+ mpfr::RoundingMode::TowardZero);
313
+ };
314
+ };
315
+
170
316
template <typename FloatType, mpfr::Operation Op, UnaryOp<FloatType> Func>
171
317
using LlvmLibcUnaryOpExhaustiveMathTest =
172
318
LlvmLibcExhaustiveMathTest<UnaryOpChecker<FloatType, FloatType, Op, Func>>;
@@ -175,3 +321,7 @@ template <typename OutType, typename InType, mpfr::Operation Op,
175
321
UnaryOp<OutType, InType> Func>
176
322
using LlvmLibcUnaryNarrowingOpExhaustiveMathTest =
177
323
LlvmLibcExhaustiveMathTest<UnaryOpChecker<OutType, InType, Op, Func>>;
324
+
325
+ template <typename FloatType, mpfr::Operation Op, BinaryOp<FloatType> Func>
326
+ using LlvmLibcBinaryOpExhaustiveMathTest =
327
+ LlvmLibcBinaryInputExhaustiveMathTest<BinaryOpChecker<FloatType, Op, Func>>;
0 commit comments