Skip to content

Commit 8af0860

Browse files
authored
AffineExpr: Fix result of d0 + (d0 // -c) * c. (#107530)
Currently, this is rewritten to d0 mod -c. However, we do not support modulo with a negative RHS in our lowering passes, so this triggers undefined behavior. It would be better to not have these ad hoc simplifications at all, but I guess that ship has sailed.
1 parent b11a703 commit 8af0860

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

mlir/lib/IR/AffineExpr.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,11 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
760760

761761
llrhs = lrBinOpExpr.getLHS();
762762
rlrhs = lrBinOpExpr.getRHS();
763+
auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(rlrhs);
764+
// We don't support modulo with a negative RHS.
765+
bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
763766

764-
if (lhs == llrhs && rlrhs == -rrhs) {
767+
if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
765768
return lhs % rlrhs;
766769
}
767770
return nullptr;

mlir/unittests/IR/AffineExprTest.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@
1515

1616
using namespace mlir;
1717

18+
static std::string toString(AffineExpr expr) {
19+
std::string s;
20+
llvm::raw_string_ostream ss(s);
21+
ss << expr;
22+
return s;
23+
}
24+
1825
// Test creating AffineExprs using the overloaded binary operators.
1926
TEST(AffineExprTest, constructFromBinaryOperators) {
2027
MLIRContext ctx;
@@ -112,3 +119,13 @@ TEST(AffineExprTest, divisorOfNegativeFloorDiv) {
112119
OpBuilder b(&ctx);
113120
ASSERT_EQ(b.getAffineDimExpr(0).floorDiv(-1).getLargestKnownDivisor(), 1);
114121
}
122+
123+
TEST(AffineExprTest, d0PlusD0FloorDivNeg2) {
124+
// Regression test for a bug where this was rewritten to d0 mod -2. We do not
125+
// support a negative RHS for mod in LowerAffinePass.
126+
MLIRContext ctx;
127+
OpBuilder b(&ctx);
128+
auto d0 = b.getAffineDimExpr(0);
129+
auto sum = d0 + d0.floorDiv(-2) * 2;
130+
ASSERT_EQ(toString(sum), "d0 + (d0 floordiv -2) * 2");
131+
}

0 commit comments

Comments
 (0)