Skip to content

Commit 49212d1

Browse files
authored
[Flang] Fix for replacing loop uses in LoopVersioning pass (#77899)
The added test case has a loop that is versioned, which has a use of the loop in an if block after the loop. The current code replaces all uses of the loop with the new version If, but only if the parent blocks match. As far as I can see it should be safe to replace all the uses, then construct the result for the If with op.op.
1 parent 46a9135 commit 49212d1

File tree

2 files changed

+81
-12
lines changed

2 files changed

+81
-12
lines changed

flang/lib/Optimizer/Transforms/LoopVersioning.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,6 @@ struct ArgsUsageInLoop {
144144
};
145145
} // namespace
146146

147-
/// @c replaceOuterUses - replace uses outside of @c op with result of @c
148-
/// outerOp
149-
static void replaceOuterUses(mlir::Operation *op, mlir::Operation *outerOp) {
150-
const mlir::Operation *outerParent = outerOp->getParentOp();
151-
op->replaceUsesWithIf(outerOp, [&](mlir::OpOperand &operand) {
152-
mlir::Operation *owner = operand.getOwner();
153-
return outerParent == owner->getParentOp();
154-
});
155-
}
156-
157147
static fir::SequenceType getAsSequenceType(mlir::Value *v) {
158148
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
159149
return argTy.dyn_cast<fir::SequenceType>();
@@ -538,7 +528,7 @@ void LoopVersioningPass::runOnOperation() {
538528

539529
// Add the original loop in the else-side of the if operation.
540530
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
541-
replaceOuterUses(op.op, ifOp);
531+
op.op->replaceAllUsesWith(ifOp);
542532
op.op->remove();
543533
builder.insert(op.op);
544534
// Rely on "cloned loop has results, so original loop also has results".

flang/test/Transforms/loop-versioning.fir

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func.func @sum1dfixed(%arg0: !fir.ref<!fir.array<?xf64>> {fir.bindc_name = "a"},
263263
// CHECK-SAME: %[[N1:.*]]: !fir.ref<i32> {{.*}},
264264
// CHECK-SAME: %[[M1:.*]]: !fir.ref<i32> {{.*}}) {
265265
// CHECK: fir.do_loop
266-
// CHECL: %[[FOUR:.*]] = arith.constant 4 : index
266+
// CHECK: %[[FOUR:.*]] = arith.constant 4 : index
267267
// CHECK: %[[COMP:.*]] = arith.cmpi {{.*}}, %[[FOUR]]
268268
// CHECK: fir.if %[[COMP]] -> {{.*}} {
269269
// CHECK: %[[CONV:.*]] = fir.convert %[[B]] :
@@ -1478,4 +1478,83 @@ func.func @sum1drebox(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"},
14781478
// CHECK-NOT: fir.if
14791479

14801480

1481+
// Check for a use in a different block (%12 = do_loop is used inside the if %14 block)
1482+
func.func @minloc(%arg0: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "x"}, %arg1: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "mask"}) -> f32 {
1483+
%c2147483647_i32 = arith.constant 2147483647 : i32
1484+
%c1_i32 = arith.constant 1 : i32
1485+
%c0 = arith.constant 0 : index
1486+
%c0_i32 = arith.constant 0 : i32
1487+
%c5_i32 = arith.constant 5 : i32
1488+
%c5 = arith.constant 5 : index
1489+
%c1 = arith.constant 1 : index
1490+
%0 = fir.alloca i32
1491+
%1 = fir.alloca !fir.array<1xi32>
1492+
%2 = fir.declare %arg1 {uniq_name = "_QFtestEmask"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
1493+
%3 = fir.rebox %2 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
1494+
%4 = fir.alloca f32 {bindc_name = "test", uniq_name = "_QFtestEtest"}
1495+
%5 = fir.declare %4 {uniq_name = "_QFtestEtest"} : (!fir.ref<f32>) -> !fir.ref<f32>
1496+
%6 = fir.declare %arg0 {uniq_name = "_QFtestEx"} : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
1497+
%7 = fir.rebox %6 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
1498+
%8 = fir.shape %c1 : (index) -> !fir.shape<1>
1499+
%9 = fir.array_coor %1(%8) %c1 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
1500+
fir.store %c0_i32 to %9 : !fir.ref<i32>
1501+
fir.store %c0_i32 to %0 : !fir.ref<i32>
1502+
%10:3 = fir.box_dims %7, %c0 : (!fir.box<!fir.array<?xi32>>, index) -> (index, index, index)
1503+
%11 = arith.subi %10#1, %c1 : index
1504+
%12 = fir.do_loop %arg2 = %c0 to %11 step %c1 iter_args(%arg3 = %c2147483647_i32) -> (i32) {
1505+
%18 = arith.addi %arg2, %c1 : index
1506+
%19 = fir.array_coor %3 %18 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
1507+
%20 = fir.load %19 : !fir.ref<i32>
1508+
%21 = arith.cmpi sge, %20, %c5_i32 : i32
1509+
%22 = fir.if %21 -> (i32) {
1510+
fir.store %c1_i32 to %0 : !fir.ref<i32>
1511+
%23 = arith.subi %10#0, %c1 : index
1512+
%24 = arith.addi %18, %23 : index
1513+
%25 = fir.array_coor %7 %24 : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
1514+
%26 = fir.load %25 : !fir.ref<i32>
1515+
%27 = arith.cmpi slt, %26, %arg3 : i32
1516+
%28 = fir.if %27 -> (i32) {
1517+
%29 = fir.convert %18 : (index) -> i32
1518+
fir.store %29 to %9 : !fir.ref<i32>
1519+
fir.result %26 : i32
1520+
} else {
1521+
fir.result %arg3 : i32
1522+
}
1523+
fir.result %28 : i32
1524+
} else {
1525+
fir.result %arg3 : i32
1526+
}
1527+
fir.result %22 : i32
1528+
}
1529+
%13 = fir.load %0 : !fir.ref<i32>
1530+
%14 = arith.cmpi eq, %13, %c1_i32 : i32
1531+
fir.if %14 {
1532+
%18 = arith.cmpi eq, %12, %c2147483647_i32 : i32
1533+
fir.if %18 {
1534+
%19 = fir.array_coor %1(%8) %c0 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
1535+
fir.store %c1_i32 to %19 : !fir.ref<i32>
1536+
}
1537+
}
1538+
%15 = fir.slice %c5, %c5, %c1 : (index, index, index) -> !fir.slice<1>
1539+
%16 = fir.rebox %7 [%15] : (!fir.box<!fir.array<?xi32>>, !fir.slice<1>) -> !fir.box<!fir.array<1xi32>>
1540+
fir.do_loop %arg2 = %c1 to %c1 step %c1 unordered {
1541+
%18 = fir.array_coor %1(%8) %arg2 : (!fir.ref<!fir.array<1xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
1542+
%19 = fir.load %18 : !fir.ref<i32>
1543+
%20 = fir.array_coor %16 %arg2 : (!fir.box<!fir.array<1xi32>>, index) -> !fir.ref<i32>
1544+
fir.store %19 to %20 : !fir.ref<i32>
1545+
}
1546+
%17 = fir.load %5 : !fir.ref<f32>
1547+
return %17 : f32
1548+
}
1549+
// CHECK-LABEL: func @minloc
1550+
// CHECK: %[[V17:.*]] = fir.if %{{.*}} -> (i32) {
1551+
// CHECK: %[[V27:.*]] = fir.do_loop
1552+
// CHECK: fir.result %[[V27]] : i32
1553+
// CHECK: } else {
1554+
// CHECK: %[[V23:.*]] = fir.do_loop
1555+
// CHECK: fir.result %[[V23]] : i32
1556+
// CHECK: fir.if %{{.*}} {
1557+
// CHECK: {{.*}} = arith.cmpi eq, %[[V17]], %c2147483647_i32
1558+
1559+
14811560
} // End module

0 commit comments

Comments
 (0)