Skip to content

Commit 87e12b8

Browse files
committed
[fir] Support promoting fir.do_loop with results to affine.for.
1 parent 5f704f9 commit 87e12b8

File tree

2 files changed

+99
-5
lines changed

2 files changed

+99
-5
lines changed

flang/lib/Optimizer/Transforms/AffinePromotion.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ struct AffineIfAnalysis;
4949
/// second when doing rewrite.
5050
struct AffineFunctionAnalysis {
5151
explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
52-
for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
53-
loopAnalysisMap.try_emplace(op, op, *this);
52+
funcOp->walk([&](fir::DoLoopOp doloop) {
53+
loopAnalysisMap.try_emplace(doloop, doloop, *this);
54+
});
5455
}
5556

5657
AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
@@ -102,10 +103,23 @@ struct AffineLoopAnalysis {
102103
return true;
103104
}
104105

106+
bool analysisResults(fir::DoLoopOp loopOperation) {
107+
if (loopOperation.getFinalValue() &&
108+
!loopOperation.getResult(0).use_empty()) {
109+
LLVM_DEBUG(
110+
llvm::dbgs()
111+
<< "AffineLoopAnalysis: cannot promote loop final value\n";);
112+
return false;
113+
}
114+
115+
return true;
116+
}
117+
105118
bool analyzeLoop(fir::DoLoopOp loopOperation,
106119
AffineFunctionAnalysis &functionAnalysis) {
107120
LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
108121
return analyzeMemoryAccess(loopOperation) &&
122+
analysisResults(loopOperation) &&
109123
analyzeBody(loopOperation, functionAnalysis);
110124
}
111125

@@ -461,14 +475,28 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
461475
LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
462476
functionAnalysis.getChildLoopAnalysis(loop);
463477
auto &loopOps = loop.getBody()->getOperations();
478+
auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator());
479+
auto results = resultOp.getOperands();
480+
auto loopResults = loop->getResults();
464481
auto loopAndIndex = createAffineFor(loop, rewriter);
465482
auto affineFor = loopAndIndex.first;
466483
auto inductionVar = loopAndIndex.second;
467484

485+
if (loop.getFinalValue()) {
486+
results = results.drop_front();
487+
loopResults = loopResults.drop_front();
488+
}
489+
468490
rewriter.startOpModification(affineFor.getOperation());
469491
affineFor.getBody()->getOperations().splice(
470492
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
471493
std::prev(loopOps.end()));
494+
rewriter.replaceAllUsesWith(loop.getRegionIterArgs(),
495+
affineFor.getRegionIterArgs());
496+
if (!results.empty()) {
497+
rewriter.setInsertionPointToEnd(affineFor.getBody());
498+
rewriter.create<affine::AffineYieldOp>(resultOp->getLoc(), results);
499+
}
472500
rewriter.finalizeOpModification(affineFor.getOperation());
473501

474502
rewriter.startOpModification(loop.getOperation());
@@ -479,7 +507,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
479507

480508
LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
481509
affineFor.dump(););
482-
rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
510+
rewriter.replaceAllUsesWith(loopResults, affineFor->getResults());
511+
rewriter.eraseOp(loop);
483512
return success();
484513
}
485514

@@ -503,7 +532,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
503532
ValueRange(op.getUpperBound()),
504533
mlir::AffineMap::get(0, 1,
505534
1 + mlir::getAffineSymbolExpr(0, op.getContext())),
506-
step);
535+
step, op.getIterOperands());
507536
return std::make_pair(affineFor, affineFor.getInductionVar());
508537
}
509538

@@ -528,7 +557,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
528557
genericUpperBound.getResult(),
529558
mlir::AffineMap::get(0, 1,
530559
1 + mlir::getAffineSymbolExpr(0, op.getContext())),
531-
1);
560+
1, op.getIterOperands());
532561
rewriter.setInsertionPointToStart(affineFor.getBody());
533562
auto actualIndex = rewriter.create<affine::AffineApplyOp>(
534563
op.getLoc(), actualIndexMap,

flang/test/Fir/affine-promotion.fir

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,68 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) {
131131
// CHECK: }
132132
// CHECK: return
133133
// CHECK: }
134+
135+
func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
136+
%c1 = arith.constant 1 : index
137+
%cst = arith.constant 0.000000e+00 : f32
138+
%c100 = arith.constant 100 : index
139+
%0 = fir.shape %c100 : (index) -> !fir.shape<1>
140+
%1 = fir.shape %c100, %c100 : (index, index) -> !fir.shape<2>
141+
%2 = fir.alloca i32
142+
%3:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %cst) -> (index, f32) {
143+
%6 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
144+
%7 = fir.load %6 : !fir.ref<f32>
145+
%8 = arith.addf %arg3, %7 fastmath<contract> : f32
146+
%9 = arith.addi %arg2, %c1 overflow<nsw> : index
147+
fir.result %9, %8 : index, f32
148+
}
149+
%4:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %3#1) -> (index, f32) {
150+
%6 = fir.array_coor %arg1(%1) %c1, %arg2 : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
151+
%7 = fir.convert %6 : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
152+
%8 = fir.do_loop %arg4 = %c1 to %c100 step %c1 iter_args(%arg5 = %arg3) -> (f32) {
153+
%10 = fir.array_coor %7(%0) %arg4 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
154+
%11 = fir.load %10 : !fir.ref<f32>
155+
%12 = arith.addf %arg5, %11 fastmath<contract> : f32
156+
fir.result %12 : f32
157+
}
158+
%9 = arith.addi %arg2, %c1 overflow<nsw> : index
159+
fir.result %9, %8 : index, f32
160+
}
161+
%5 = fir.convert %4#0 : (index) -> i32
162+
fir.store %5 to %2 : !fir.ref<i32>
163+
return %4#1 : f32
164+
}
165+
166+
// CHECK-LABEL: func.func @loop_with_result(
167+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xf32>>,
168+
// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
169+
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
170+
// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
171+
// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
172+
// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
173+
// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
174+
// CHECK: %[[VAL_5:.*]] = fir.alloca i32
175+
// CHECK: %[[VAL_6:.*]] = fir.convert %[[ARG0]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
176+
// CHECK: %[[VAL_7:.*]] = affine.for %[[VAL_8:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_9:.*]] = %[[VAL_1]]) -> (f32) {
177+
// CHECK: %[[VAL_10:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
178+
// CHECK: %[[VAL_11:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
179+
// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] fastmath<contract> : f32
180+
// CHECK: affine.yield %[[VAL_12]] : f32
181+
// CHECK: }
182+
// CHECK: %[[VAL_13:.*]]:2 = fir.do_loop %[[VAL_14:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] iter_args(%[[VAL_15:.*]] = %[[VAL_7]]) -> (index, f32) {
183+
// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[ARG1]](%[[VAL_4]]) %[[VAL_0]], %[[VAL_14]] : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
184+
// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
185+
// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
186+
// CHECK: %[[VAL_19:.*]] = affine.for %[[VAL_20:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f32) {
187+
// CHECK: %[[VAL_22:.*]] = affine.apply #{{.*}}(%[[VAL_20]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
188+
// CHECK: %[[VAL_23:.*]] = affine.load %[[VAL_18]]{{\[}}%[[VAL_22]]] : memref<?xf32>
189+
// CHECK: %[[VAL_24:.*]] = arith.addf %[[VAL_21]], %[[VAL_23]] fastmath<contract> : f32
190+
// CHECK: affine.yield %[[VAL_24]] : f32
191+
// CHECK: }
192+
// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_0]] overflow<nsw> : index
193+
// CHECK: fir.result %[[VAL_25]], %[[VAL_19]] : index, f32
194+
// CHECK: }
195+
// CHECK: %[[VAL_26:.*]] = fir.convert %[[VAL_27:.*]]#0 : (index) -> i32
196+
// CHECK: fir.store %[[VAL_26]] to %[[VAL_5]] : !fir.ref<i32>
197+
// CHECK: return %[[VAL_27]]#1 : f32
198+
// CHECK: }

0 commit comments

Comments
 (0)