-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] [DataFlow] Fix bug in int-range-analysis #126708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: donald chen (cxy-1993) ChangesThis patch fix bug in int range analysis: When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation. Full diff: https://github.com/llvm/llvm-project/pull/126708.diff 2 Files Affected:
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 9e9411e5ede12c8..722f4df18e9818c 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -152,7 +152,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
/// on a LoopLikeInterface return the lower/upper bound for that result if
/// possible.
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
- Type boundType, bool getUpper) {
+ Type boundType, Block *block, bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
if (loopBound.has_value()) {
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
@@ -160,7 +160,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return bound.getValue();
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(getProgramPointAfter(op), value);
+ getLatticeElementFor(getProgramPointBefore(block), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
@@ -180,16 +180,17 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
+ Block *block = iv->getParentBlock();
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
std::optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
/*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+ getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);
if (stepVal.isNegative()) {
std::swap(min, max);
diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
index 1ec3441b1fde817..b98e8b07db5ce2b 100644
--- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
+++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir
@@ -154,3 +154,33 @@ func.func @dont_propagate_across_infinite_loop() -> index {
return %2 : index
}
+// CHECK-LABEL: @propagate_from_block_to_iterarg
+func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = scf.if %arg1 -> (index) {
+ %1 = scf.if %arg1 -> (index) {
+ scf.yield %arg0 : index
+ } else {
+ scf.yield %arg0 : index
+ }
+ scf.yield %1 : index
+ } else {
+ scf.yield %c1 : index
+ }
+ scf.for %arg2 = %c0 to %arg0 step %c1 {
+ scf.if %arg1 {
+ %1 = arith.subi %0, %c1 : index
+ %2 = arith.muli %0, %1 : index
+ %3 = arith.addi %2, %c1 : index
+ scf.for %arg3 = %c0 to %3 step %c1 {
+ %4 = arith.cmpi uge, %arg3, %c1 : index
+ // CHECK-NOT: scf.if %false
+ scf.if %4 {
+ "test.foo"() : () -> ()
+ }
+ }
+ }
+ }
+ return
+}
|
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.
7fb01cf
to
4755da6
Compare
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.
When querying the lower bound and upper bound of loop to update the value range of a loop iteration variable, the program point to depend on should be the block corresponding to the iteration variable rather than the loop operation.