Skip to content

Commit 4506de1

Browse files
authored
NFC. Move out and expose affine expression simplification utility out of AffineOps lib (#69813)
Move out trivial affine expression simplification out of AffineOps library. Expose it from libIR. Users of such methods shouldn't have to rely on the AffineOps dialect. For eg., with this change, the method can be used now from lib/Analysis/ (FlatLinearConstraints) as well as AffineOps dialect canonicalization. This way those one won't need to depend on AffineOps for some simplification of affine expressions.
1 parent f7dc26c commit 4506de1

File tree

3 files changed

+108
-101
lines changed

3 files changed

+108
-101
lines changed

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,20 @@ void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
353353
e = getAffineSymbolExpr(idx++, ctx);
354354
}
355355

356+
/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
357+
/// the constant lower and upper bounds for its inputs provided in
358+
/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
359+
/// bound can't be computed. This method only handles simple sum of product
360+
/// expressions (w.r.t constant coefficients) so as to not depend on anything
361+
/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
362+
/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
363+
/// semi-affine ones will lead a none being returned.
364+
std::optional<int64_t>
365+
getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
366+
ArrayRef<std::optional<int64_t>> constLowerBounds,
367+
ArrayRef<std::optional<int64_t>> constUpperBounds,
368+
bool isUpper);
369+
356370
} // namespace mlir
357371

358372
namespace llvm {

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 14 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -700,93 +700,6 @@ static std::optional<int64_t> getUpperBound(Value iv) {
700700
return forOp.getConstantUpperBound() - 1;
701701
}
702702

703-
/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
704-
/// the constant lower and upper bounds for its inputs provided in
705-
/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
706-
/// bound can't be computed. This method only handles simple sum of product
707-
/// expressions (w.r.t constant coefficients) so as to not depend on anything
708-
/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
709-
/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
710-
/// semi-affine ones will lead std::nullopt being returned.
711-
static std::optional<int64_t>
712-
getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
713-
ArrayRef<std::optional<int64_t>> constLowerBounds,
714-
ArrayRef<std::optional<int64_t>> constUpperBounds,
715-
bool isUpper) {
716-
// Handle divs and mods.
717-
if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
718-
// If the LHS of a floor or ceil is bounded and the RHS is a constant, we
719-
// can compute an upper bound.
720-
if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
721-
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
722-
if (!rhsConst || rhsConst.getValue() < 1)
723-
return std::nullopt;
724-
auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
725-
constLowerBounds, constUpperBounds, isUpper);
726-
if (!bound)
727-
return std::nullopt;
728-
return mlir::floorDiv(*bound, rhsConst.getValue());
729-
}
730-
if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
731-
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
732-
if (rhsConst && rhsConst.getValue() >= 1) {
733-
auto bound =
734-
getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
735-
constLowerBounds, constUpperBounds, isUpper);
736-
if (!bound)
737-
return std::nullopt;
738-
return mlir::ceilDiv(*bound, rhsConst.getValue());
739-
}
740-
return std::nullopt;
741-
}
742-
if (binOpExpr.getKind() == AffineExprKind::Mod) {
743-
// lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
744-
// bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
745-
// (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
746-
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
747-
if (rhsConst && rhsConst.getValue() >= 1) {
748-
int64_t rhsConstVal = rhsConst.getValue();
749-
auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
750-
constLowerBounds, constUpperBounds,
751-
/*isUpper=*/false);
752-
auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
753-
constLowerBounds, constUpperBounds, isUpper);
754-
if (ub && lb &&
755-
floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
756-
return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
757-
return isUpper ? rhsConstVal - 1 : 0;
758-
}
759-
}
760-
}
761-
// Flatten the expression.
762-
SimpleAffineExprFlattener flattener(numDims, numSymbols);
763-
flattener.walkPostOrder(expr);
764-
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
765-
// TODO: Handle local variables. We can get hold of flattener.localExprs and
766-
// get bound on the local expr recursively.
767-
if (flattener.numLocals > 0)
768-
return std::nullopt;
769-
int64_t bound = 0;
770-
// Substitute the constant lower or upper bound for the dimensional or
771-
// symbolic input depending on `isUpper` to determine the bound.
772-
for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
773-
if (flattenedExpr[i] > 0) {
774-
auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
775-
if (!constBound)
776-
return std::nullopt;
777-
bound += *constBound * flattenedExpr[i];
778-
} else if (flattenedExpr[i] < 0) {
779-
auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
780-
if (!constBound)
781-
return std::nullopt;
782-
bound += *constBound * flattenedExpr[i];
783-
}
784-
}
785-
// Constant term.
786-
bound += flattenedExpr.back();
787-
return bound;
788-
}
789-
790703
/// Determine a constant upper bound for `expr` if one exists while exploiting
791704
/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
792705
/// is guaranteed to be less than or equal to it.
@@ -805,9 +718,9 @@ static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
805718
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
806719
return constExpr.getValue();
807720

808-
return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
809-
constUpperBounds,
810-
/*isUpper=*/true);
721+
return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
722+
constUpperBounds,
723+
/*isUpper=*/true);
811724
}
812725

813726
/// Determine a constant lower bound for `expr` if one exists while exploiting
@@ -829,9 +742,9 @@ static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
829742
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
830743
lowerBound = constExpr.getValue();
831744
} else {
832-
lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
833-
constUpperBounds,
834-
/*isUpper=*/false);
745+
lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
746+
constLowerBounds, constUpperBounds,
747+
/*isUpper=*/false);
835748
}
836749
return lowerBound;
837750
}
@@ -970,14 +883,14 @@ static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
970883
lowerBounds.push_back(constExpr.getValue());
971884
upperBounds.push_back(constExpr.getValue());
972885
} else {
973-
lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
974-
map.getNumSymbols(),
975-
constLowerBounds, constUpperBounds,
976-
/*isUpper=*/false));
977-
upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
978-
map.getNumSymbols(),
979-
constLowerBounds, constUpperBounds,
980-
/*isUpper=*/true));
886+
lowerBounds.push_back(
887+
getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
888+
constLowerBounds, constUpperBounds,
889+
/*isUpper=*/false));
890+
upperBounds.push_back(
891+
getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
892+
constLowerBounds, constUpperBounds,
893+
/*isUpper=*/true));
981894
}
982895
}
983896

mlir/lib/IR/AffineExpr.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,3 +1438,83 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
14381438
assert(flattener.operandExprStack.empty());
14391439
return simplifiedExpr;
14401440
}
1441+
1442+
std::optional<int64_t> mlir::getBoundForAffineExpr(
1443+
AffineExpr expr, unsigned numDims, unsigned numSymbols,
1444+
ArrayRef<std::optional<int64_t>> constLowerBounds,
1445+
ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1446+
// Handle divs and mods.
1447+
if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
1448+
// If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1449+
// can compute an upper bound.
1450+
if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1451+
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
1452+
if (!rhsConst || rhsConst.getValue() < 1)
1453+
return std::nullopt;
1454+
auto bound =
1455+
getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1456+
constLowerBounds, constUpperBounds, isUpper);
1457+
if (!bound)
1458+
return std::nullopt;
1459+
return mlir::floorDiv(*bound, rhsConst.getValue());
1460+
}
1461+
if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1462+
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
1463+
if (rhsConst && rhsConst.getValue() >= 1) {
1464+
auto bound =
1465+
getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1466+
constLowerBounds, constUpperBounds, isUpper);
1467+
if (!bound)
1468+
return std::nullopt;
1469+
return mlir::ceilDiv(*bound, rhsConst.getValue());
1470+
}
1471+
return std::nullopt;
1472+
}
1473+
if (binOpExpr.getKind() == AffineExprKind::Mod) {
1474+
// lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1475+
// bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1476+
// (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1477+
auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
1478+
if (rhsConst && rhsConst.getValue() >= 1) {
1479+
int64_t rhsConstVal = rhsConst.getValue();
1480+
auto lb = getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1481+
constLowerBounds, constUpperBounds,
1482+
/*isUpper=*/false);
1483+
auto ub =
1484+
getBoundForAffineExpr(binOpExpr.getLHS(), numDims, numSymbols,
1485+
constLowerBounds, constUpperBounds, isUpper);
1486+
if (ub && lb &&
1487+
floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
1488+
return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
1489+
return isUpper ? rhsConstVal - 1 : 0;
1490+
}
1491+
}
1492+
}
1493+
// Flatten the expression.
1494+
SimpleAffineExprFlattener flattener(numDims, numSymbols);
1495+
flattener.walkPostOrder(expr);
1496+
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1497+
// TODO: Handle local variables. We can get hold of flattener.localExprs and
1498+
// get bound on the local expr recursively.
1499+
if (flattener.numLocals > 0)
1500+
return std::nullopt;
1501+
int64_t bound = 0;
1502+
// Substitute the constant lower or upper bound for the dimensional or
1503+
// symbolic input depending on `isUpper` to determine the bound.
1504+
for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1505+
if (flattenedExpr[i] > 0) {
1506+
auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1507+
if (!constBound)
1508+
return std::nullopt;
1509+
bound += *constBound * flattenedExpr[i];
1510+
} else if (flattenedExpr[i] < 0) {
1511+
auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1512+
if (!constBound)
1513+
return std::nullopt;
1514+
bound += *constBound * flattenedExpr[i];
1515+
}
1516+
}
1517+
// Constant term.
1518+
bound += flattenedExpr.back();
1519+
return bound;
1520+
}

0 commit comments

Comments
 (0)