Skip to content

[mlir] Fold ceil/floordiv with negative RHS. #97031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions mlir/lib/IR/AffineExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);

// mlir floordiv by zero or negative numbers is undefined and preserved as is.
if (!rhsConst || rhsConst.getValue() < 1)
if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;

if (lhsConst) {
Expand All @@ -875,12 +874,12 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (rhsConst == 1)
return lhs;

// Simplify (expr * const) floordiv divConst when expr is known to be a
// multiple of divConst.
// Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
// multiple of `rhsConst`.
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
// rhsConst is known to be a positive constant.
// `rhsConst` is known to be a nonzero constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
Expand All @@ -891,7 +890,7 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (lBin && lBin.getKind() == AffineExprKind::Add) {
int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
// rhsConst is known to be a positive constant.
// rhsConst is known to be a nonzero constant.
if (llhsDiv % rhsConst.getValue() == 0 ||
lrhsDiv % rhsConst.getValue() == 0)
return lBin.getLHS().floorDiv(rhsConst.getValue()) +
Expand All @@ -918,7 +917,7 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);

if (!rhsConst || rhsConst.getValue() < 1)
if (!rhsConst || rhsConst.getValue() == 0)
return nullptr;

if (lhsConst) {
Expand All @@ -937,12 +936,12 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
if (rhsConst.getValue() == 1)
return lhs;

// Simplify (expr * const) ceildiv divConst when const is known to be a
// multiple of divConst.
// Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
// multiple of `rhsConst`.
auto lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
if (lBin && lBin.getKind() == AffineExprKind::Mul) {
if (auto lrhs = dyn_cast<AffineConstantExpr>(lBin.getRHS())) {
// rhsConst is known to be a positive constant.
// `rhsConst` is known to be a nonzero constant.
if (lrhs.getValue() % rhsConst.getValue() == 0)
return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
}
Expand Down
22 changes: 22 additions & 0 deletions mlir/unittests/IR/AffineExprTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,25 @@ TEST(AffineExprTest, constantFolding) {
getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
}

TEST(AffineExprTest, divisionSimplification) {
MLIRContext ctx;
OpBuilder b(&ctx);
auto cn6 = b.getAffineConstantExpr(-6);
auto c6 = b.getAffineConstantExpr(6);
auto d0 = b.getAffineDimExpr(0);
auto d1 = b.getAffineDimExpr(1);

ASSERT_EQ(c6.floorDiv(-1), cn6);
ASSERT_EQ((d0 * 6).floorDiv(2), d0 * 3);
ASSERT_EQ((d0 * 6).floorDiv(4).getKind(), AffineExprKind::FloorDiv);
ASSERT_EQ((d0 * 6).floorDiv(-2), d0 * -3);
ASSERT_EQ((d0 * 6 + d1).floorDiv(2), d0 * 3 + d1.floorDiv(2));
ASSERT_EQ((d0 * 6 + d1).floorDiv(-2), d0 * -3 + d1.floorDiv(-2));
ASSERT_EQ((d0 * 6 + d1).floorDiv(4).getKind(), AffineExprKind::FloorDiv);

ASSERT_EQ(c6.ceilDiv(-1), cn6);
ASSERT_EQ((d0 * 6).ceilDiv(2), d0 * 3);
ASSERT_EQ((d0 * 6).ceilDiv(4).getKind(), AffineExprKind::CeilDiv);
ASSERT_EQ((d0 * 6).ceilDiv(-2), d0 * -3);
}
Loading