Skip to content

Commit 0501102

Browse files
committed
[flang][LoopVersioning] support fir.declare
When FIR comes from HLFIR, there will be a fir.declare operation between the source and the usage of each source variable (and some temporary allocations). This pass needs to be able to follow these so that it can still transform loops when HLFIR is used, otherwise it mistakenly assumes these values are not function arguments. More work is needed after this patch to fully support HLFIR, because the generated code tends to use fir.array_coor instead of fir.coordinate_of. Differential Revision: https://reviews.llvm.org/D157964
1 parent 66abe64 commit 0501102

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

flang/lib/Optimizer/Transforms/LoopVersioning.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ static fir::SequenceType getAsSequenceType(mlir::Value *v) {
9494
return argTy.dyn_cast<fir::SequenceType>();
9595
}
9696

97+
/// if a value comes from a fir.declare, follow it to the original source,
98+
/// otherwise return the value
99+
static mlir::Value unwrapFirDeclare(mlir::Value val) {
100+
// fir.declare is for source code variables. We don't have declares of
101+
// declares
102+
if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>())
103+
return declare.getMemref();
104+
return val;
105+
}
106+
97107
void LoopVersioningPass::runOnOperation() {
98108
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
99109
mlir::func::FuncOp func = getOperation();
@@ -154,9 +164,9 @@ void LoopVersioningPass::runOnOperation() {
154164
// to it later.
155165
if (op->getParentOfType<fir::DoLoopOp>() != loop)
156166
return;
157-
const mlir::Value &operand = op->getOperand(0);
167+
mlir::Value operand = op->getOperand(0);
158168
for (auto a : argsOfInterest) {
159-
if (*a.arg == operand) {
169+
if (*a.arg == unwrapFirDeclare(operand)) {
160170
// Only add if it's not already in the list.
161171
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
162172
return it.arg == a.arg;
@@ -244,7 +254,7 @@ void LoopVersioningPass::runOnOperation() {
244254
// arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x)
245255
// where stride is the distance between elements in the dimensions
246256
// 0, 1 and 2 or x, y and z.
247-
if (coop->getOperand(0) == *arg.arg &&
257+
if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg &&
248258
coop->getOperands().size() >= 2) {
249259
builder.setInsertionPoint(coop);
250260
mlir::Value totalIndex;

flang/test/Transforms/loop-versioning.fir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// end subroutine sum1d
1414
module {
1515
func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
16+
%decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
1617
%0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
1718
%1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
1819
%cst = arith.constant 0.000000e+00 : f64
@@ -30,7 +31,7 @@ module {
3031
%9 = fir.convert %8 : (i32) -> i64
3132
%c1_i64 = arith.constant 1 : i64
3233
%10 = arith.subi %9, %c1_i64 : i64
33-
%11 = fir.coordinate_of %arg0, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
34+
%11 = fir.coordinate_of %decl, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
3435
%12 = fir.load %11 : !fir.ref<f64>
3536
%13 = arith.addf %7, %12 fastmath<contract> : f64
3637
fir.store %13 to %1 : !fir.ref<f64>
@@ -47,6 +48,7 @@ module {
4748
// Note this only checks the expected transformation, not the entire generated code:
4849
// CHECK-LABEL: func.func @sum1d(
4950
// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xf64>> {{.*}})
51+
// CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
5052
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
5153
// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
5254
// CHECK: %[[SIZE:.*]] = arith.constant 8 : index
@@ -62,7 +64,7 @@ module {
6264
// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
6365
// CHECK: } else {
6466
// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
65-
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
67+
// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
6668
// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref<f64>
6769
// CHECK: fir.result %{{.*}}, %{{.*}}
6870
// CHECK: }

0 commit comments

Comments
 (0)