Skip to content

Commit 4f1eec1

Browse files
committed
[flang] Fix crash in folding of DPROD() with non-scalar arguments
DPROD(x,y) is defined as DBLE(x)*DBLE(y) and that's exactly how the implementation of its rewriting and possible folding should be implemented, instead of the current code that only works when both arguments are scalar and crashes otherwise. Fixes llvm#63991. Differential Revision: https://reviews.llvm.org/D156754
1 parent 9c446da commit 4f1eec1

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

flang/lib/Evaluate/fold-real.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,15 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
149149
} else if (name == "dot_product") {
150150
return FoldDotProduct<T>(context, std::move(funcRef));
151151
} else if (name == "dprod") {
152-
if (auto scalars{GetScalarConstantArguments<T, T>(context, args)}) {
153-
return Fold(context,
154-
Expr<T>{Multiply<T>{
155-
Expr<T>{std::get<0>(*scalars)}, Expr<T>{std::get<1>(*scalars)}}});
152+
// Rewrite DPROD(x,y) -> DBLE(x)*DBLE(y)
153+
if (args.at(0) && args.at(1)) {
154+
const auto *xExpr{args[0]->UnwrapExpr()};
155+
const auto *yExpr{args[1]->UnwrapExpr()};
156+
if (xExpr && yExpr) {
157+
return Fold(context,
158+
ToReal<T::kind>(context, common::Clone(*xExpr)) *
159+
ToReal<T::kind>(context, common::Clone(*yExpr)));
160+
}
156161
}
157162
} else if (name == "epsilon") {
158163
return Expr<T>{Scalar<T>::EPSILON()};

flang/test/Evaluate/fold-dprod.f90

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
! RUN: %python %S/test_folding.py %s %flang_fc1
2+
! Tests folding of DPROD()
3+
module m
4+
logical, parameter :: test_kind = kind(dprod(2., 3.)) == kind(0.d0)
5+
logical, parameter :: test_ss = dprod(2., 3.) == 6.d0
6+
logical, parameter :: test_sv = all(dprod(2., [3.,4.]) == [6.d0,8.d0])
7+
logical, parameter :: test_vv = all(dprod([2.,3.], [4.,5.]) == [8.d0,15.0d0])
8+
end

0 commit comments

Comments
 (0)