Skip to content

Commit b3f57e9

Browse files
committed
[libc] Restore libc/utils/MPFRWrapper from branch overmighty:libc-math-f16divf
1 parent 62baf21 commit b3f57e9

File tree

2 files changed

+81
-42
lines changed

2 files changed

+81
-42
lines changed

libc/utils/MPFRWrapper/MPFRUtils.cpp

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,12 @@ class MPFRNumber {
296296
return result;
297297
}
298298

299+
MPFRNumber div(const MPFRNumber &b) const {
300+
MPFRNumber result(*this);
301+
mpfr_div(result.value, value, b.value, mpfr_rounding);
302+
return result;
303+
}
304+
299305
MPFRNumber floor() const {
300306
MPFRNumber result(*this);
301307
mpfr_floor(result.value, value);
@@ -708,6 +714,8 @@ binary_operation_one_output(Operation op, InputType x, InputType y,
708714
switch (op) {
709715
case Operation::Atan2:
710716
return inputX.atan2(inputY);
717+
case Operation::Div:
718+
return inputX.div(inputY);
711719
case Operation::Fmod:
712720
return inputX.fmod(inputY);
713721
case Operation::Hypot:
@@ -885,42 +893,47 @@ template void explain_binary_operation_two_outputs_error<long double>(
885893
Operation, const BinaryInput<long double> &,
886894
const BinaryOutput<long double> &, double, RoundingMode);
887895

888-
template <typename T>
889-
void explain_binary_operation_one_output_error(Operation op,
890-
const BinaryInput<T> &input,
891-
T libc_result,
892-
double ulp_tolerance,
893-
RoundingMode rounding) {
894-
unsigned int precision = get_precision<T>(ulp_tolerance);
896+
template <typename InputType, typename OutputType>
897+
void explain_binary_operation_one_output_error(
898+
Operation op, const BinaryInput<InputType> &input, OutputType libc_result,
899+
double ulp_tolerance, RoundingMode rounding) {
900+
unsigned int precision = get_precision<InputType>(ulp_tolerance);
895901
MPFRNumber mpfrX(input.x, precision);
896902
MPFRNumber mpfrY(input.y, precision);
897-
FPBits<T> xbits(input.x);
898-
FPBits<T> ybits(input.y);
903+
FPBits<InputType> xbits(input.x);
904+
FPBits<InputType> ybits(input.y);
899905
MPFRNumber mpfr_result =
900906
binary_operation_one_output(op, input.x, input.y, precision, rounding);
901907
MPFRNumber mpfrMatchValue(libc_result);
902908

903909
tlog << "Input decimal: x: " << mpfrX.str() << " y: " << mpfrY.str() << '\n';
904-
tlog << "First input bits: " << str(FPBits<T>(input.x)) << '\n';
905-
tlog << "Second input bits: " << str(FPBits<T>(input.y)) << '\n';
910+
tlog << "First input bits: " << str(FPBits<InputType>(input.x)) << '\n';
911+
tlog << "Second input bits: " << str(FPBits<InputType>(input.y)) << '\n';
906912

907913
tlog << "Libc result: " << mpfrMatchValue.str() << '\n'
908914
<< "MPFR result: " << mpfr_result.str() << '\n';
909-
tlog << "Libc floating point result bits: " << str(FPBits<T>(libc_result))
910-
<< '\n';
915+
tlog << "Libc floating point result bits: "
916+
<< str(FPBits<OutputType>(libc_result)) << '\n';
911917
tlog << " MPFR rounded bits: "
912-
<< str(FPBits<T>(mpfr_result.as<T>())) << '\n';
918+
<< str(FPBits<OutputType>(mpfr_result.as<OutputType>())) << '\n';
913919
tlog << "ULP error: " << mpfr_result.ulp_as_mpfr_number(libc_result).str()
914920
<< '\n';
915921
}
916922

917-
template void explain_binary_operation_one_output_error<float>(
918-
Operation, const BinaryInput<float> &, float, double, RoundingMode);
919-
template void explain_binary_operation_one_output_error<double>(
923+
template void
924+
explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
925+
float, double, RoundingMode);
926+
template void explain_binary_operation_one_output_error(
920927
Operation, const BinaryInput<double> &, double, double, RoundingMode);
921-
template void explain_binary_operation_one_output_error<long double>(
922-
Operation, const BinaryInput<long double> &, long double, double,
923-
RoundingMode);
928+
template void
929+
explain_binary_operation_one_output_error(Operation,
930+
const BinaryInput<long double> &,
931+
long double, double, RoundingMode);
932+
#ifdef LIBC_TYPES_HAS_FLOAT16
933+
template void
934+
explain_binary_operation_one_output_error(Operation, const BinaryInput<float> &,
935+
float16, double, RoundingMode);
936+
#endif
924937

925938
template <typename InputType, typename OutputType>
926939
void explain_ternary_operation_one_output_error(
@@ -1051,26 +1064,35 @@ template bool compare_binary_operation_two_outputs<long double>(
10511064
Operation, const BinaryInput<long double> &,
10521065
const BinaryOutput<long double> &, double, RoundingMode);
10531066

1054-
template <typename T>
1067+
template <typename InputType, typename OutputType>
10551068
bool compare_binary_operation_one_output(Operation op,
1056-
const BinaryInput<T> &input,
1057-
T libc_result, double ulp_tolerance,
1069+
const BinaryInput<InputType> &input,
1070+
OutputType libc_result,
1071+
double ulp_tolerance,
10581072
RoundingMode rounding) {
1059-
unsigned int precision = get_precision<T>(ulp_tolerance);
1073+
unsigned int precision = get_precision<InputType>(ulp_tolerance);
10601074
MPFRNumber mpfr_result =
10611075
binary_operation_one_output(op, input.x, input.y, precision, rounding);
10621076
double ulp = mpfr_result.ulp(libc_result);
10631077

10641078
return (ulp <= ulp_tolerance);
10651079
}
10661080

1067-
template bool compare_binary_operation_one_output<float>(
1068-
Operation, const BinaryInput<float> &, float, double, RoundingMode);
1069-
template bool compare_binary_operation_one_output<double>(
1070-
Operation, const BinaryInput<double> &, double, double, RoundingMode);
1071-
template bool compare_binary_operation_one_output<long double>(
1072-
Operation, const BinaryInput<long double> &, long double, double,
1073-
RoundingMode);
1081+
template bool compare_binary_operation_one_output(Operation,
1082+
const BinaryInput<float> &,
1083+
float, double, RoundingMode);
1084+
template bool compare_binary_operation_one_output(Operation,
1085+
const BinaryInput<double> &,
1086+
double, double, RoundingMode);
1087+
template bool
1088+
compare_binary_operation_one_output(Operation, const BinaryInput<long double> &,
1089+
long double, double, RoundingMode);
1090+
#ifdef LIBC_TYPES_HAS_FLOAT16
1091+
template bool compare_binary_operation_one_output(Operation,
1092+
const BinaryInput<float> &,
1093+
float16, double,
1094+
RoundingMode);
1095+
#endif
10741096

10751097
template <typename InputType, typename OutputType>
10761098
bool compare_ternary_operation_one_output(Operation op,

libc/utils/MPFRWrapper/MPFRUtils.h

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ enum class Operation : int {
7171
// output.
7272
BeginBinaryOperationsSingleOutput,
7373
Atan2,
74+
Div,
7475
Fmod,
7576
Hypot,
7677
Pow,
@@ -129,6 +130,14 @@ struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<T>, BinaryOutput<T>> {
129130
static constexpr bool VALUE = cpp::is_floating_point_v<T>;
130131
};
131132

133+
template <typename T> struct IsBinaryInput {
134+
static constexpr bool VALUE = false;
135+
};
136+
137+
template <typename T> struct IsBinaryInput<BinaryInput<T>> {
138+
static constexpr bool VALUE = true;
139+
};
140+
132141
template <typename T> struct IsTernaryInput {
133142
static constexpr bool VALUE = false;
134143
};
@@ -139,6 +148,9 @@ template <typename T> struct IsTernaryInput<TernaryInput<T>> {
139148

140149
template <typename T> struct MakeScalarInput : cpp::type_identity<T> {};
141150

151+
template <typename T>
152+
struct MakeScalarInput<BinaryInput<T>> : cpp::type_identity<T> {};
153+
142154
template <typename T>
143155
struct MakeScalarInput<TernaryInput<T>> : cpp::type_identity<T> {};
144156

@@ -159,10 +171,11 @@ bool compare_binary_operation_two_outputs(Operation op,
159171
double ulp_tolerance,
160172
RoundingMode rounding);
161173

162-
template <typename T>
174+
template <typename InputType, typename OutputType>
163175
bool compare_binary_operation_one_output(Operation op,
164-
const BinaryInput<T> &input,
165-
T libc_output, double ulp_tolerance,
176+
const BinaryInput<InputType> &input,
177+
OutputType libc_output,
178+
double ulp_tolerance,
166179
RoundingMode rounding);
167180

168181
template <typename InputType, typename OutputType>
@@ -187,12 +200,10 @@ void explain_binary_operation_two_outputs_error(
187200
const BinaryOutput<T> &match_value, double ulp_tolerance,
188201
RoundingMode rounding);
189202

190-
template <typename T>
191-
void explain_binary_operation_one_output_error(Operation op,
192-
const BinaryInput<T> &input,
193-
T match_value,
194-
double ulp_tolerance,
195-
RoundingMode rounding);
203+
template <typename InputType, typename OutputType>
204+
void explain_binary_operation_one_output_error(
205+
Operation op, const BinaryInput<InputType> &input, OutputType match_value,
206+
double ulp_tolerance, RoundingMode rounding);
196207

197208
template <typename InputType, typename OutputType>
198209
void explain_ternary_operation_one_output_error(
@@ -235,7 +246,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
235246
rounding);
236247
}
237248

238-
template <typename T> bool match(const BinaryInput<T> &in, T out) {
249+
template <typename T, typename U>
250+
bool match(const BinaryInput<T> &in, U out) {
239251
return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
240252
rounding);
241253
}
@@ -268,7 +280,8 @@ class MPFRMatcher : public testing::Matcher<OutputType> {
268280
rounding);
269281
}
270282

271-
template <typename T> void explain_error(const BinaryInput<T> &in, T out) {
283+
template <typename T, typename U>
284+
void explain_error(const BinaryInput<T> &in, U out) {
272285
explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
273286
rounding);
274287
}
@@ -290,6 +303,10 @@ constexpr bool is_valid_operation() {
290303
(op == Operation::Sqrt && cpp::is_floating_point_v<InputType> &&
291304
cpp::is_floating_point_v<OutputType> &&
292305
sizeof(OutputType) <= sizeof(InputType)) ||
306+
(op == Operation::Div && internal::IsBinaryInput<InputType>::VALUE &&
307+
cpp::is_floating_point_v<
308+
typename internal::MakeScalarInput<InputType>::type> &&
309+
cpp::is_floating_point_v<OutputType>) ||
293310
(op == Operation::Fma && internal::IsTernaryInput<InputType>::VALUE &&
294311
cpp::is_floating_point_v<
295312
typename internal::MakeScalarInput<InputType>::type> &&

0 commit comments

Comments
 (0)