Skip to content

Commit d40bab3

Browse files
authored
[mlir][liveness] fix bugs in liveness analysis (#133416)
This patch fixes the following bugs: - In SparseBackwardAnalysis, the setToExitState function should propagate changes if it modifies the lattice. Previously, this issue was masked because multi-block scenarios were not tested, and the traversal order of backward data flow analysis starts from the end of the program. - The method in liveness analysis for determining whether the non-forwarded operand in branch/region branch operations is live is incorrect, which may cause originally live variables to be marked as not live.
1 parent 03a791f commit d40bab3

File tree

3 files changed

+107
-29
lines changed

3 files changed

+107
-29
lines changed

mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,12 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
413413
// Visit operands on call instructions that are not forwarded.
414414
virtual void visitCallOperand(OpOperand &operand) = 0;
415415

416-
/// Set the given lattice element(s) at control flow exit point(s).
416+
/// Set the given lattice element(s) at control flow exit point(s) and
417+
/// propagate the update if it chaned.
417418
virtual void setToExitState(AbstractSparseLattice *lattice) = 0;
418419

419-
/// Set the given lattice element(s) at control flow exit point(s).
420+
/// Set the given lattice element(s) at control flow exit point(s) and
421+
/// propagate the update if it chaned.
420422
void setAllToExitStates(ArrayRef<AbstractSparseLattice *> lattices);
421423

422424
/// Get the lattice element for a value.

mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) {
5959
/// (1.a) is an operand of an op with memory effects OR
6060
/// (1.b) is a non-forwarded branch operand and its branch op could take the
6161
/// control to a block that has an op with memory effects OR
62-
/// (1.c) is a non-forwarded call operand.
62+
/// (1.c) is a non-forwarded branch operand and its branch op could result
63+
/// in different live result OR
64+
/// (1.d) is a non-forwarded call operand.
6365
///
6466
/// A value `A` is said to be "used to compute" value `B` iff `B` cannot be
6567
/// computed in the absence of `A`. Thus, in this implementation, we say that
@@ -106,51 +108,88 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
106108
// the forwarded branch operands or the non-branch operands. Thus they need
107109
// to be handled separately. This is where we handle them.
108110

109-
// This marks values of type (1.b) liveness as "live". A non-forwarded
111+
// This marks values of type (1.b/1.c) liveness as "live". A non-forwarded
110112
// branch operand will be live if a block where its op could take the control
111-
// has an op with memory effects.
113+
// has an op with memory effects or could result in different results.
112114
// Populating such blocks in `blocks`.
115+
bool mayLive = false;
113116
SmallVector<Block *, 4> blocks;
114117
if (isa<RegionBranchOpInterface>(op)) {
115-
// When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
116-
// `scf.index_switch` op, its branch operand controls the flow into this
117-
// op's regions.
118-
for (Region &region : op->getRegions()) {
119-
for (Block &block : region)
120-
blocks.push_back(&block);
118+
if (op->getNumResults() != 0) {
119+
// This mark value of type 1.c liveness as may live, because the region
120+
// branch operation has a return value, and the non-forwarded operand can
121+
// determine the region to jump to, it can thereby control the result of
122+
// the region branch operation.
123+
// Therefore, if the result value is live, we conservatively consider the
124+
// non-forwarded operand of the region branch operation with result may
125+
// live and record all result.
126+
for (Value result : op->getResults()) {
127+
if (getLatticeElement(result)->isLive) {
128+
mayLive = true;
129+
break;
130+
}
131+
}
132+
} else {
133+
// When the op is a `RegionBranchOpInterface`, like an `scf.for` or an
134+
// `scf.index_switch` op, its branch operand controls the flow into this
135+
// op's regions.
136+
for (Region &region : op->getRegions()) {
137+
for (Block &block : region)
138+
blocks.push_back(&block);
139+
}
121140
}
122141
} else if (isa<BranchOpInterface>(op)) {
123-
// When the op is a `BranchOpInterface`, like a `cf.cond_br` or a
124-
// `cf.switch` op, its branch operand controls the flow into this op's
125-
// successors.
126-
blocks = op->getSuccessors();
142+
// We cannot track all successor blocks of the branch operation(More
143+
// specifically, it's the successor's successor). Additionally, different
144+
// blocks might also lead to the different block argument described in 1.c.
145+
// Therefore, we conservatively consider the non-forwarded operand of the
146+
// branch operation may live.
147+
mayLive = true;
127148
} else {
128-
// When the op is a `RegionBranchTerminatorOpInterface`, like an
129-
// `scf.condition` op or return-like, like an `scf.yield` op, its branch
130-
// operand controls the flow into this op's parent's (which is a
131-
// `RegionBranchOpInterface`'s) regions.
132149
Operation *parentOp = op->getParentOp();
133150
assert(isa<RegionBranchOpInterface>(parentOp) &&
134151
"expected parent op to implement `RegionBranchOpInterface`");
135-
for (Region &region : parentOp->getRegions()) {
136-
for (Block &block : region)
137-
blocks.push_back(&block);
152+
if (parentOp->getNumResults() != 0) {
153+
// This mark value of type 1.c liveness as may live, because the region
154+
// branch operation has a return value, and the non-forwarded operand can
155+
// determine the region to jump to, it can thereby control the result of
156+
// the region branch operation.
157+
// Therefore, if the result value is live, we conservatively consider the
158+
// non-forwarded operand of the region branch operation with result may
159+
// live and record all result.
160+
for (Value result : parentOp->getResults()) {
161+
if (getLatticeElement(result)->isLive) {
162+
mayLive = true;
163+
break;
164+
}
165+
}
166+
} else {
167+
// When the op is a `RegionBranchTerminatorOpInterface`, like an
168+
// `scf.condition` op or return-like, like an `scf.yield` op, its branch
169+
// operand controls the flow into this op's parent's (which is a
170+
// `RegionBranchOpInterface`'s) regions.
171+
for (Region &region : parentOp->getRegions()) {
172+
for (Block &block : region)
173+
blocks.push_back(&block);
174+
}
138175
}
139176
}
140-
bool foundMemoryEffectingOp = false;
141177
for (Block *block : blocks) {
142-
if (foundMemoryEffectingOp)
178+
if (mayLive)
143179
break;
144180
for (Operation &nestedOp : *block) {
145181
if (!isMemoryEffectFree(&nestedOp)) {
146-
Liveness *operandLiveness = getLatticeElement(operand.get());
147-
propagateIfChanged(operandLiveness, operandLiveness->markLive());
148-
foundMemoryEffectingOp = true;
182+
mayLive = true;
149183
break;
150184
}
151185
}
152186
}
153187

188+
if (mayLive) {
189+
Liveness *operandLiveness = getLatticeElement(operand.get());
190+
propagateIfChanged(operandLiveness, operandLiveness->markLive());
191+
}
192+
154193
// Now that we have checked for memory-effecting ops in the blocks of concern,
155194
// we will simply visit the op with this non-forwarded operand to potentially
156195
// mark it "live" due to type (1.a/3) liveness.
@@ -191,8 +230,12 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) {
191230
}
192231

193232
void LivenessAnalysis::setToExitState(Liveness *lattice) {
233+
if (lattice->isLive) {
234+
return;
235+
}
194236
// This marks values of type (2) liveness as "live".
195237
(void)lattice->markLive();
238+
propagateIfChanged(lattice, ChangeResult::Change);
196239
}
197240

198241
//===----------------------------------------------------------------------===//

mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,49 @@ func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref<i32>, %ar
5959

6060
// -----
6161

62+
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
63+
// op could result in different result"
64+
// CHECK-LABEL: test_tag: cond_br:
65+
// CHECK-NEXT: operand #0: live
66+
// CHECK-NEXT: operand #1: live
67+
// CHECK-NEXT: operand #2: live
68+
func.func @test_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
69+
cf.cond_br %arg2, ^bb1(%arg0 : tensor<f32>), ^bb2(%arg1 : tensor<f32>) {tag = "cond_br"}
70+
^bb1(%0 : tensor<f32>):
71+
cf.br ^bb3(%0 : tensor<f32>)
72+
^bb2(%1 : tensor<f32>):
73+
cf.br ^bb3(%1 : tensor<f32>)
74+
^bb3(%2 : tensor<f32>):
75+
return %2 : tensor<f32>
76+
}
77+
78+
// -----
79+
80+
// Positive test: Type(1.c) "is a non-forwarded branch operand and its branch
81+
// op could result in different result"
82+
// CHECK-LABEL: test_tag: region_branch:
83+
// CHECK-NEXT: operand #0: live
84+
func.func @test_region_branch_result_in_different_result_1.c(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : i1) -> tensor<f32> {
85+
%0 = scf.if %arg2 -> tensor<f32> {
86+
scf.yield %arg0 : tensor<f32>
87+
} else {
88+
scf.yield %arg1 : tensor<f32>
89+
} {tag="region_branch"}
90+
return %0 : tensor<f32>
91+
}
92+
93+
// -----
94+
6295
func.func private @private(%arg0 : i32, %arg1 : i32) {
6396
func.return
6497
}
6598

66-
// Positive test: Type (1.c) "is a non-forwarded call operand"
99+
// Positive test: Type (1.d) "is a non-forwarded call operand"
67100
// CHECK-LABEL: test_tag: call
68101
// CHECK-LABEL: operand #0: not live
69102
// CHECK-LABEL: operand #1: not live
70103
// CHECK-LABEL: operand #2: live
71-
func.func @test_4_type_1.c(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
104+
func.func @test_4_type_1.d(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
72105
test.call_on_device @private(%arg0, %arg1), %device {tag = "call"} : (i32, i32, i32) -> ()
73106
return
74107
}

0 commit comments

Comments
 (0)