Skip to content

Commit dbf547f

Browse files
authored
[flang][runtime] Add limit check to MOD/MODULO (#80026)
When testing the arguments to see whether they are integers, check first that they are within the maximum range of a 64-bit integer; otherwise, a value of larger magnitude will set an invalid operand exception flag.
1 parent dc15524 commit dbf547f

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

flang/runtime/numeric.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,19 @@ inline RT_API_ATTRS T RealMod(
145145
} else if (std::isinf(p)) {
146146
return a;
147147
}
148-
if (auto aInt{static_cast<std::int64_t>(a)}; a == aInt) {
149-
if (auto pInt{static_cast<std::int64_t>(p)}; p == pInt) {
150-
// Fast exact case for integer operands
151-
auto mod{aInt - (aInt / pInt) * pInt};
152-
if (IS_MODULO && (aInt > 0) != (pInt > 0)) {
153-
mod += pInt;
148+
T aAbs{std::abs(a)};
149+
T pAbs{std::abs(p)};
150+
if (aAbs <= static_cast<T>(std::numeric_limits<std::int64_t>::max()) &&
151+
pAbs <= static_cast<T>(std::numeric_limits<std::int64_t>::max())) {
152+
if (auto aInt{static_cast<std::int64_t>(a)}; a == aInt) {
153+
if (auto pInt{static_cast<std::int64_t>(p)}; p == pInt) {
154+
// Fast exact case for integer operands
155+
auto mod{aInt - (aInt / pInt) * pInt};
156+
if (IS_MODULO && (aInt > 0) != (pInt > 0)) {
157+
mod += pInt;
158+
}
159+
return static_cast<T>(mod);
154160
}
155-
return static_cast<T>(mod);
156161
}
157162
}
158163
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double> ||
@@ -183,9 +188,8 @@ inline RT_API_ATTRS T RealMod(
183188
// what's left is the desired remainder. This is basically
184189
// the same algorithm as arbitrary precision binary long division,
185190
// discarding the quotient.
186-
T tmp{std::abs(a)};
187-
T pAbs{std::abs(p)};
188-
for (T adj{SetExponent(pAbs, Exponent<int>(tmp))}; tmp >= pAbs; adj /= 2) {
191+
T tmp{aAbs};
192+
for (T adj{SetExponent(pAbs, Exponent<int>(aAbs))}; tmp >= pAbs; adj /= 2) {
189193
if (tmp >= adj) {
190194
tmp -= adj;
191195
if (tmp == 0) {

0 commit comments

Comments
 (0)