Commit 9aa7434

[mlir][SCF] Further simplify affine maps during for-loop-canonicalization
* Implement `FlatAffineConstraints::getConstantBound(EQ)`.
* Inject a simpler constraint for loops that have at most one iteration.
* Take constant EQ bounds of FlatAffineConstraints dims/symbols into account when canonicalizing the resulting affine map in `canonicalizeMinMaxOp`.

Differential Revision: https://reviews.llvm.org/D119153
1 parent 0220110 commit 9aa7434

4 files changed, +61 -13 lines

mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h

Lines changed: 0 additions & 1 deletion
@@ -374,7 +374,6 @@ class IntegerPolyhedron {
 
   /// Returns the constant bound for the pos^th identifier if there is one;
   /// None otherwise.
-  // TODO: Support EQ bounds.
   Optional<int64_t> getConstantBound(BoundType type, unsigned pos) const;
 
   /// Removes constraints that are independent of (i.e., do not have a

mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp

Lines changed: 14 additions & 4 deletions
@@ -1518,11 +1518,21 @@ IntegerPolyhedron::computeConstantLowerOrUpperBound(unsigned pos) {
 
 Optional<int64_t> IntegerPolyhedron::getConstantBound(BoundType type,
                                                       unsigned pos) const {
-  assert(type != BoundType::EQ && "EQ not implemented");
-  IntegerPolyhedron tmpCst(*this);
   if (type == BoundType::LB)
-    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
-  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+    return IntegerPolyhedron(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  if (type == BoundType::UB)
+    return IntegerPolyhedron(*this)
+        .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+
+  assert(type == BoundType::EQ && "expected EQ");
+  Optional<int64_t> lb =
+      IntegerPolyhedron(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
+  Optional<int64_t> ub =
+      IntegerPolyhedron(*this)
+          .computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
+  return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
 }
 
 // A simple (naive and conservative) check for hyper-rectangularity.
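
The EQ case above reuses the existing LB/UB machinery: a constant equality bound exists exactly when the constant lower and upper bounds both exist and coincide. A standalone restatement of that reduction, using std::optional in place of llvm::Optional and a hypothetical helper name, purely to illustrate the logic:

#include <cstdint>
#include <optional>

// Mirrors the new EQ branch of getConstantBound: equal constant lower and
// upper bounds imply a constant equality bound; anything else yields none.
std::optional<int64_t> constantEqBound(std::optional<int64_t> lb,
                                       std::optional<int64_t> ub) {
  if (lb && ub && *lb == *ub)
    return *ub;
  return std::nullopt;
}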

mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp

Lines changed: 28 additions & 8 deletions
@@ -183,6 +183,16 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
   AffineMap newMap = alignedBoundMap;
   SmallVector<Value> newOperands;
   unpackOptionalValues(constraints.getMaybeDimAndSymbolValues(), newOperands);
+  // If dims/symbols have known constant values, use those in order to simplify
+  // the affine map further.
+  for (int64_t i = 0, e = constraints.getNumIds(); i < e; ++i) {
+    // Skip unused operands and operands that are already constants.
+    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
+      continue;
+    if (auto bound = constraints.getConstantBound(FlatAffineConstraints::EQ, i))
+      newOperands[i] =
+          rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
+  }
   mlir::canonicalizeMapAndOperands(&newMap, &newOperands);
   rewriter.setInsertionPoint(op);
   rewriter.replaceOpWithNewOp<AffineApplyOp>(op, newMap, newOperands);
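
To see why injecting equality-constant operands pays off, consider the map from the new test at the end of this commit, (d0, d1) -> min(4, d1 - d0): once d0 is pinned to 0 by the one-trip loop constraint and d1 is the constant 4, the whole map folds to the constant 4. A toy sketch of that folding step, with a hypothetical helper name and plain C++ rather than the MLIR canonicalization itself:

#include <algorithm>
#include <cstdint>
#include <optional>

// Stand-in for one result of the test's affine map, min(4, d1 - d0).
// It can only fold to a constant once both operands are known constants,
// which is exactly what the EQ-bound substitution above makes possible.
std::optional<int64_t> tryFoldMin(std::optional<int64_t> d0,
                                  std::optional<int64_t> d1) {
  if (!d0 || !d1)
    return std::nullopt; // Unknown operand: the affine.min stays in the IR.
  return std::min<int64_t>(4, *d1 - *d0);
}

// With d0 == 0 and d1 == 4, tryFoldMin yields 4, matching the CHECK lines
// of the new test below.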
@@ -211,19 +221,29 @@ addLoopRangeConstraints(FlatAffineValueConstraints &constraints, Value iv,
   if (ubInt)
     constraints.addBound(FlatAffineConstraints::EQ, dimUb, *ubInt);
 
-  // iv >= lb (equiv.: iv - lb >= 0)
+  // Lower bound: iv >= lb (equiv.: iv - lb >= 0)
   SmallVector<int64_t> ineqLb(constraints.getNumCols(), 0);
   ineqLb[dimIv] = 1;
   ineqLb[dimLb] = -1;
   constraints.addInequality(ineqLb);
 
-  // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
-  AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
-                            : rewriter.getAffineDimExpr(dimLb);
-  AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
-                            : rewriter.getAffineDimExpr(dimUb);
-  AffineExpr ivUb =
-      exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  // Upper bound
+  AffineExpr ivUb;
+  if (lbInt && ubInt && (*lbInt + *stepInt >= *ubInt)) {
+    // The loop has at most one iteration.
+    // iv < lb + 1
+    // TODO: Try to derive this constraint by simplifying the expression in
+    // the else-branch.
+    ivUb = rewriter.getAffineDimExpr(dimLb) + 1;
+  } else {
+    // The loop may have more than one iteration.
+    // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
+    AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr(*lbInt)
+                              : rewriter.getAffineDimExpr(dimLb);
+    AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr(*ubInt)
+                              : rewriter.getAffineDimExpr(dimUb);
+    ivUb = exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1).floorDiv(*stepInt)));
+  }
   auto map = AffineMap::get(
       /*dimCount=*/constraints.getNumDimIds(),
       /*symbolCount=*/constraints.getNumSymbolIds(), /*result=*/ivUb);
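
The new branch only fires when the lower bound, upper bound, and step are all known constants and lb + step >= ub, i.e. the loop runs at most once; together with iv >= lb, the injected iv < lb + 1 then pins the induction variable to lb. A minimal, standalone check of that condition (illustrative helper, not part of the patch):

#include <cassert>
#include <cstdint>

// The guard for the simpler "iv < lb + 1" constraint above.
bool hasAtMostOneIteration(int64_t lb, int64_t ub, int64_t step) {
  return lb + step >= ub;
}

int main() {
  // The loop in the new test: lb = 0, ub = 4, step = 4 -> one iteration,
  // so the only induction variable value is lb itself.
  assert(hasAtMostOneIteration(0, 4, 4));
  // lb = 0, ub = 4, step = 2 runs twice, so the general bound is kept.
  assert(!hasAtMostOneIteration(0, 4, 2));
  return 0;
}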

mlir/test/Dialect/SCF/for-loop-canonicalization.mlir

Lines changed: 19 additions & 0 deletions
@@ -349,3 +349,22 @@ func @tensor_dim_of_loop_result_no_canonicalize(%t : tensor<?x?xf32>,
   %dim = tensor.dim %1, %c0 : tensor<?x?xf32>
   return %dim : index
 }
+
+// -----
+
+// CHECK-LABEL: func @one_trip_scf_for_canonicalize_min
+// CHECK: %[[C4:.*]] = arith.constant 4 : i64
+// CHECK: scf.for
+// CHECK:   memref.store %[[C4]], %{{.*}}[] : memref<i64>
+func @one_trip_scf_for_canonicalize_min(%A : memref<i64>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  scf.for %i = %c0 to %c4 step %c4 {
+    %1 = affine.min affine_map<(d0, d1)[] -> (4, d1 - d0)> (%i, %c4)
+    %2 = arith.index_cast %1: index to i64
+    memref.store %2, %A[]: memref<i64>
+  }
+  return
+}
