-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][affine] Fix dim index out of bounds crash #73266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir-affine Author: Rik Huijzer (rikhuijzer) ChangesThis PR suggests a way to fix #70418. It now throws an error if the Full diff: https://github.com/llvm/llvm-project/pull/73266.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..d6e640ddd8f25d5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -317,9 +317,16 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
-static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
+static bool isMemRefSizeValidSymbol(ShapedDimOpInterface dimOp,
+ AnyMemRefDefOp memrefDefOp, unsigned index,
Region *region) {
- auto memRefType = memrefDefOp.getType();
+ MemRefType memRefType = memrefDefOp.getType();
+
+ // Dimension index is out of bounds.
+ if (index >= memRefType.getRank()) {
+ return false;
+ }
+
// Statically shaped.
if (!memRefType.isDynamicDim(index))
return true;
@@ -351,7 +358,9 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
int64_t i = index.value();
return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
.Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
- [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
+ [&](auto memRefDefOp) {
+ return isMemRefSizeValidSymbol(dimOp, memRefDefOp, i, region);
+ })
.Default([](Operation *) { return false; });
}
@@ -1651,19 +1660,19 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("src index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("src index must be a dimension or symbol identifier");
+ return emitOpError("src index must be a valid dimension or symbol identifier");
}
for (auto idx : getDstIndices()) {
if (!idx.getType().isIndex())
return emitOpError("dst index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("dst index must be a dimension or symbol identifier");
+ return emitOpError("dst index must be a valid dimension or symbol identifier");
}
for (auto idx : getTagIndices()) {
if (!idx.getType().isIndex())
return emitOpError("tag index to dma_start must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("tag index must be a dimension or symbol identifier");
+ return emitOpError("tag index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -1752,7 +1761,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
if (!idx.getType().isIndex())
return emitOpError("index to dma_wait must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
@@ -2913,8 +2922,7 @@ static void composeSetAndOperands(IntegerSet &set,
}
/// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
- SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
auto set = getIntegerSet();
SmallVector<Value, 4> operands(getOperands());
composeSetAndOperands(set, operands);
@@ -3005,18 +3013,18 @@ static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
Operation::operand_range mapOperands,
MemRefType memrefType, unsigned numIndexOperands) {
- AffineMap map = mapAttr.getValue();
- if (map.getNumResults() != memrefType.getRank())
- return op->emitOpError("affine map num results must equal memref rank");
- if (map.getNumInputs() != numIndexOperands)
- return op->emitOpError("expects as many subscripts as affine map inputs");
+ AffineMap map = mapAttr.getValue();
+ if (map.getNumResults() != memrefType.getRank())
+ return op->emitOpError("affine map num results must equal memref rank");
+ if (map.getNumInputs() != numIndexOperands)
+ return op->emitOpError("expects as many subscripts as affine map inputs");
Region *scope = getAffineScope(op);
for (auto idx : mapOperands) {
if (!idx.getType().isIndex())
return op->emitOpError("index to load must have 'index' type");
if (!isValidAffineIndexOperand(idx, scope))
- return op->emitOpError("index must be a dimension or symbol identifier");
+ return op->emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
@@ -3605,7 +3613,7 @@ LogicalResult AffinePrefetchOp::verify() {
Region *scope = getAffineScope(*this);
for (auto idx : getMapOperands()) {
if (!isValidAffineIndexOperand(idx, scope))
- return emitOpError("index must be a dimension or symbol identifier");
+ return emitOpError("index must be a valid dimension or symbol identifier");
}
return success();
}
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index 759ab2d6c358c8a..b94d271fc197014 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
}
// -----
+
+func.func @dim_out_of_bounds() {
+ %c6 = arith.constant 6 : index
+ %alloc_4 = memref.alloc() : memref<4xi64>
+ %dim = memref.dim %alloc_4, %c6 : memref<4xi64> // Out of bounds; UB.
+ %alloca_100 = memref.alloca() : memref<100xi64>
+ // expected-error@+1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
+ %70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
+ return
+}
+
+// -----
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 72864516b459a51..60f13102f551569 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
"unknown"() ({
^bb0(%arg: index):
affine.load %M[%arg] : memref<10xi32>
- // expected-error@-1 {{index must be a dimension or symbol identifier}}
+ // expected-error@-1 {{index must be a valid dimension or symbol identifier}}
cf.br ^bb1
^bb1:
cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
%idx = "unknown.test"() : () -> (index)
%memref = "unknown.test"() : () -> memref<?x?xf32>
%dim = memref.dim %memref, %idx : memref<?x?xf32>
- // expected-error @below {{op index must be a dimension or symbol identifier}}
+ // expected-error @below {{op index must be a valid dimension or symbol identifier}}
affine.load %memref[%dim, %dim] : memref<?x?xf32>
"unknown.terminator"() : () -> ()
}) : () -> ()
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 482d2f35e094923..01d6b25dee695bb 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
%0 = memref.alloc() : memref<10xf32>
affine.for %i0 = 0 to 10 {
%1 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
%v = affine.load %0[%1] : memref<10xf32>
}
return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
%1 = arith.constant 11.0 : f32
affine.for %i0 = 0 to 10 {
%2 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
affine.store %1, %0[%2] : memref<10xf32>
}
return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op src index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op dst index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op tag index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
}
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
%c64 = arith.constant 64 : index
affine.for %i0 = 0 to 10 {
%3 = arith.muli %i0, %arg0 : index
- // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+ // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
}
return
|
This comment was marked as outdated.
This comment was marked as outdated.
Also related is #73027. That PR allows us to simply the code and we could then also remove the "or symbol type" from the error message since that is tested as part of |
This PR suggests a way to fix #70418. It now throws an error if the
index
operand formemref.dim
is out of bounds. Catching it in the verifier was not possible because the constant value is not yet available at that point. Unfortunately, the error is not very descriptive since it was only possible to propagate boolean up.