-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Fold ceil/floordiv with negative RHS. #97031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Currently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the `mod` simplifier.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Johannes Reifferscheid (jreiffers) ChangesCurrently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the Full diff: https://github.com/llvm/llvm-project/pull/97031.diff 2 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cf8157cf7bb8c..798398464da8d 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -855,8 +855,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
- // mlir floordiv by zero or negative numbers is undefined and preserved as is.
- if (!rhsConst || rhsConst.getValue() < 1)
+ if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;
if (lhsConst) {
@@ -875,12 +874,12 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (rhsConst == 1)
return lhs;
- // Simplify (expr * const) floordiv divConst when expr is known to be a
- // multiple of divConst.
+ // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
+ // multiple of `rhsConst`.
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
- // rhsConst is known to be a positive constant.
+ // `rhsConst` is known to be a nonzero constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
@@ -891,7 +890,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (lBin && lBin.getKind() == AffineExprKind::Add) {
int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
- // rhsConst is known to be a positive constant.
+ // rhsConst is known to be a nonzero constant.
if (llhsDiv % rhsConst.getValue() == 0 ||
lrhsDiv % rhsConst.getValue() == 0)
return lBin.getLHS().floorDiv(rhsConst.getValue()) +
@@ -918,7 +917,7 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
- if (!rhsConst || rhsConst.getValue() < 1)
+ if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;
if (lhsConst) {
@@ -937,12 +936,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
if (rhsConst.getValue() == 1)
return lhs;
- // Simplify (expr * const) ceildiv divConst when const is known to be a
- // multiple of divConst.
+ // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
+ // multiple of `rhsConst`.
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
- // rhsConst is known to be a positive constant.
+ // `rhsConst` is known to be a nonzero constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index 9740165c6b324..a0affc4341b0b 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -76,3 +76,25 @@ TEST(AffineExprTest, constantFolding) {
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
}
+
+TEST(AffineExprTest, divisionSimplification) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+ auto cn6 = b.getAffineConstantExpr(-6);
+ auto c6 = b.getAffineConstantExpr(6);
+ auto d0 = b.getAffineDimExpr(0);
+ auto d1 = b.getAffineDimExpr(1);
+
+ ASSERT_EQ(c6.floorDiv(-1), cn6);
+ ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3);
+ ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
+ ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3);
+ ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2));
+ ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2));
+ ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
+
+ ASSERT_EQ(c6.ceilDiv(-1), cn6);
+ ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3);
+ ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv);
+ ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3);
+}
|
Currently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the `mod` simplifier.
Currently, we only fold if the RHS is a positive constant. There doesn't seem to be a good reason to do that. The comment claims that division by negative values is undefined, but I suspect that was just copied over from the
mod
simplifier.