-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Do not trigger UB during AffineExpr parsing. #96896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <cstdint> | ||
#include <limits> | ||
#include <utility> | ||
|
||
#include "AffineExprDetail.h" | ||
|
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants, | |
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { | ||
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); | ||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); | ||
// Fold if both LHS, RHS are a constant. | ||
if (lhsConst && rhsConst) | ||
return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(), | ||
lhs.getContext()); | ||
// Fold if both LHS, RHS are a constant and the sum does not overflow. | ||
if (lhsConst && rhsConst) { | ||
int64_t sum; | ||
if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) { | ||
return nullptr; | ||
} | ||
return getAffineConstantExpr(sum, lhs.getContext()); | ||
} | ||
|
||
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4). | ||
// If only one of them is a symbolic expressions, make it the RHS. | ||
|
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) { | |
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs); | ||
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs); | ||
|
||
if (lhsConst && rhsConst) | ||
return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(), | ||
lhs.getContext()); | ||
if (lhsConst && rhsConst) { | ||
int64_t product; | ||
if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) { | ||
return nullptr; | ||
} | ||
return getAffineConstantExpr(product, lhs.getContext()); | ||
} | ||
|
||
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) | ||
return nullptr; | ||
|
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) { | |
if (!rhsConst || rhsConst.getValue() < 1) | ||
return nullptr; | ||
|
||
if (lhsConst) | ||
if (lhsConst) { | ||
// divideFloorSigned can only overflow in this case: | ||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && | ||
rhsConst.getValue() == -1) { | ||
return nullptr; | ||
} | ||
return getAffineConstantExpr( | ||
divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()), | ||
lhs.getContext()); | ||
} | ||
|
||
// Fold floordiv of a multiply with a constant that is a multiple of the | ||
// divisor. Eg: (i * 128) floordiv 64 = i * 2. | ||
|
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) { | |
if (!rhsConst || rhsConst.getValue() < 1) | ||
return nullptr; | ||
|
||
if (lhsConst) | ||
if (lhsConst) { | ||
// divideCeilSigned can only overflow in this case: | ||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. already covered above in line 921 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's fix that separately? Line 921 seems wrong, it should be checking for == 0. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sgtm |
||
rhsConst.getValue() == -1) { | ||
return nullptr; | ||
} | ||
return getAffineConstantExpr( | ||
divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()), | ||
lhs.getContext()); | ||
} | ||
|
||
// Fold ceildiv of a multiply with a constant that is a multiple of the | ||
// divisor. Eg: (i * 128) ceildiv 64 = i * 2. | ||
|
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) { | |
if (!rhsConst || rhsConst.getValue() < 1) | ||
return nullptr; | ||
|
||
if (lhsConst) | ||
if (lhsConst) { | ||
// mod never overflows. | ||
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()), | ||
lhs.getContext()); | ||
} | ||
|
||
// Fold modulo of an expression that is known to be a multiple of a constant | ||
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already covered above in line 859