-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Do not trigger UB during AffineExpr parsing. #96896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Currently, parsing expressions that are undefined will trigger UB during compilation (e.g. INT_MIN / -1). This change instead leaves the expressions as they were. This change is an NFC for compilations that did not previously involve UB.
@llvm/pr-subscribers-llvm-support @llvm/pr-subscribers-mlir-core Author: Johannes Reifferscheid (jreiffers) ChangesCurrently, parsing expressions that are undefined will trigger UB during compilation This change is an NFC for compilations that did not previously involve UB. Full diff: https://github.com/llvm/llvm-project/pull/96896.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 3bba999fb00e9..6de754f472635 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -435,7 +435,8 @@ inline uint64_t divideCeil(uint64_t Numerator, uint64_t Denominator) {
}
/// Returns the integer ceil(Numerator / Denominator). Signed version.
-/// Guaranteed to never overflow.
+/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
+/// is -1.
inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)
@@ -448,7 +449,8 @@ inline int64_t divideCeilSigned(int64_t Numerator, int64_t Denominator) {
}
/// Returns the integer floor(Numerator / Denominator). Signed version.
-/// Guaranteed to never overflow.
+/// Guaranteed to never overflow, unless Numerator is INT64_MIN and Denominator
+/// is -1.
inline int64_t divideFloorSigned(int64_t Numerator, int64_t Denominator) {
assert(Denominator && "Division by zero");
if (!Numerator)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 1fab33327ba76..cf8157cf7bb8c 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include <cstdint>
+#include <limits>
#include <utility>
#include "AffineExprDetail.h"
@@ -645,10 +647,14 @@ mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
- // Fold if both LHS, RHS are a constant.
- if (lhsConst && rhsConst)
- return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
- lhs.getContext());
+ // Fold if both LHS, RHS are a constant and the sum does not overflow.
+ if (lhsConst && rhsConst) {
+ int64_t sum;
+ if (llvm::AddOverflow(lhsConst.getValue(), rhsConst.getValue(), sum)) {
+ return nullptr;
+ }
+ return getAffineConstantExpr(sum, lhs.getContext());
+ }
// Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
// If only one of them is a symbolic expressions, make it the RHS.
@@ -774,9 +780,13 @@ static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
auto lhsConst = dyn_cast<AffineConstantExpr>(lhs);
auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
- if (lhsConst && rhsConst)
- return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
- lhs.getContext());
+ if (lhsConst && rhsConst) {
+ int64_t product;
+ if (llvm::MulOverflow(lhsConst.getValue(), rhsConst.getValue(), product)) {
+ return nullptr;
+ }
+ return getAffineConstantExpr(product, lhs.getContext());
+ }
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
return nullptr;
@@ -849,10 +859,16 @@ static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
- if (lhsConst)
+ if (lhsConst) {
+ // divideFloorSigned can only overflow in this case:
+ if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
+ rhsConst.getValue() == -1) {
+ return nullptr;
+ }
return getAffineConstantExpr(
divideFloorSigned(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
+ }
// Fold floordiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) floordiv 64 = i * 2.
@@ -905,10 +921,16 @@ static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
- if (lhsConst)
+ if (lhsConst) {
+ // divideCeilSigned can only overflow in this case:
+ if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() &&
+ rhsConst.getValue() == -1) {
+ return nullptr;
+ }
return getAffineConstantExpr(
divideCeilSigned(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
+ }
// Fold ceildiv of a multiply with a constant that is a multiple of the
// divisor. Eg: (i * 128) ceildiv 64 = i * 2.
@@ -950,9 +972,11 @@ static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
if (!rhsConst || rhsConst.getValue() < 1)
return nullptr;
- if (lhsConst)
+ if (lhsConst) {
+ // mod never overflows.
return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
lhs.getContext());
+ }
// Fold modulo of an expression that is known to be a multiple of a constant
// to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp
index ff154eb29807c..9740165c6b324 100644
--- a/mlir/unittests/IR/AffineExprTest.cpp
+++ b/mlir/unittests/IR/AffineExprTest.cpp
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include <cstdint>
+#include <limits>
+
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "gtest/gtest.h"
@@ -30,3 +33,46 @@ TEST(AffineExprTest, constructFromBinaryOperators) {
ASSERT_EQ(product.getKind(), AffineExprKind::Mul);
ASSERT_EQ(remainder.getKind(), AffineExprKind::Mod);
}
+
+TEST(AffineExprTest, constantFolding) {
+ MLIRContext ctx;
+ OpBuilder b(&ctx);
+ auto cn1 = b.getAffineConstantExpr(-1);
+ auto c0 = b.getAffineConstantExpr(0);
+ auto c1 = b.getAffineConstantExpr(1);
+ auto c2 = b.getAffineConstantExpr(2);
+ auto c3 = b.getAffineConstantExpr(3);
+ auto c6 = b.getAffineConstantExpr(6);
+ auto cmax = b.getAffineConstantExpr(std::numeric_limits<int64_t>::max());
+ auto cmin = b.getAffineConstantExpr(std::numeric_limits<int64_t>::min());
+
+ ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Add, c1, c2), c3);
+ ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::Mul, c2, c3), c6);
+ ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c2), c1);
+ ASSERT_EQ(getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c2), c2);
+
+ // Test division by zero:
+ auto c3ceildivc0 = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, c3, c0);
+ ASSERT_EQ(c3ceildivc0.getKind(), AffineExprKind::CeilDiv);
+
+ auto c3floordivc0 = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, c3, c0);
+ ASSERT_EQ(c3floordivc0.getKind(), AffineExprKind::FloorDiv);
+
+ auto c3modc0 = getAffineBinaryOpExpr(AffineExprKind::Mod, c3, c0);
+ ASSERT_EQ(c3modc0.getKind(), AffineExprKind::Mod);
+
+ // Test overflow:
+ auto cmaxplusc1 = getAffineBinaryOpExpr(AffineExprKind::Add, cmax, c1);
+ ASSERT_EQ(cmaxplusc1.getKind(), AffineExprKind::Add);
+
+ auto cmaxtimesc2 = getAffineBinaryOpExpr(AffineExprKind::Mul, cmax, c2);
+ ASSERT_EQ(cmaxtimesc2.getKind(), AffineExprKind::Mul);
+
+ auto cminceildivcn1 =
+ getAffineBinaryOpExpr(AffineExprKind::CeilDiv, cmin, cn1);
+ ASSERT_EQ(cminceildivcn1.getKind(), AffineExprKind::CeilDiv);
+
+ auto cminfloordivcn1 =
+ getAffineBinaryOpExpr(AffineExprKind::FloorDiv, cmin, cn1);
+ ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv);
+}
|
if (lhsConst) { | ||
// divideFloorSigned can only overflow in this case: | ||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && | ||
rhsConst.getValue() == -1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already covered above in line 859
if (lhsConst) | ||
if (lhsConst) { | ||
// divideCeilSigned can only overflow in this case: | ||
if (lhsConst.getValue() == std::numeric_limits<int64_t>::min() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
already covered above in line 921
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's fix that separately? Line 921 seems wrong, it should be checking for == 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sgtm
AFAICT the failed buildkite was due to |
Currently, parsing expressions that are undefined will trigger UB during compilation (e.g. `9223372036854775807 * 2`). This change instead leaves the expressions as they were. This change is an NFC for compilations that did not previously involve UB.
Currently, parsing expressions that are undefined will trigger UB during compilation
(e.g.
9223372036854775807 * 2
). This change instead leaves the expressions asthey were.
This change is an NFC for compilations that did not previously involve UB.