Skip to content

Commit 3247f1e

Browse files
authored
[mlir][affine] Fix dim index out of bounds crash (#73266)
This PR suggests a way to fix #70418. It now throws an error if the `index` operand for `memref.dim` is out of bounds. Catching it in the verifier was not possible because the constant value is not yet available at that point. Unfortunately, the error is not very descriptive since it was only possible to propagate boolean up.
1 parent 1b1f3c2 commit 3247f1e

File tree

4 files changed

+45
-22
lines changed

4 files changed

+45
-22
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,13 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
319319
template <typename AnyMemRefDefOp>
320320
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
321321
Region *region) {
322-
auto memRefType = memrefDefOp.getType();
322+
MemRefType memRefType = memrefDefOp.getType();
323+
324+
// Dimension index is out of bounds.
325+
if (index >= memRefType.getRank()) {
326+
return false;
327+
}
328+
323329
// Statically shaped.
324330
if (!memRefType.isDynamicDim(index))
325331
return true;
@@ -1651,19 +1657,22 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
16511657
if (!idx.getType().isIndex())
16521658
return emitOpError("src index to dma_start must have 'index' type");
16531659
if (!isValidAffineIndexOperand(idx, scope))
1654-
return emitOpError("src index must be a dimension or symbol identifier");
1660+
return emitOpError(
1661+
"src index must be a valid dimension or symbol identifier");
16551662
}
16561663
for (auto idx : getDstIndices()) {
16571664
if (!idx.getType().isIndex())
16581665
return emitOpError("dst index to dma_start must have 'index' type");
16591666
if (!isValidAffineIndexOperand(idx, scope))
1660-
return emitOpError("dst index must be a dimension or symbol identifier");
1667+
return emitOpError(
1668+
"dst index must be a valid dimension or symbol identifier");
16611669
}
16621670
for (auto idx : getTagIndices()) {
16631671
if (!idx.getType().isIndex())
16641672
return emitOpError("tag index to dma_start must have 'index' type");
16651673
if (!isValidAffineIndexOperand(idx, scope))
1666-
return emitOpError("tag index must be a dimension or symbol identifier");
1674+
return emitOpError(
1675+
"tag index must be a valid dimension or symbol identifier");
16671676
}
16681677
return success();
16691678
}
@@ -1752,7 +1761,8 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
17521761
if (!idx.getType().isIndex())
17531762
return emitOpError("index to dma_wait must have 'index' type");
17541763
if (!isValidAffineIndexOperand(idx, scope))
1755-
return emitOpError("index must be a dimension or symbol identifier");
1764+
return emitOpError(
1765+
"index must be a valid dimension or symbol identifier");
17561766
}
17571767
return success();
17581768
}
@@ -2913,8 +2923,7 @@ static void composeSetAndOperands(IntegerSet &set,
29132923
}
29142924

29152925
/// Canonicalize an affine if op's conditional (integer set + operands).
2916-
LogicalResult AffineIfOp::fold(FoldAdaptor,
2917-
SmallVectorImpl<OpFoldResult> &) {
2926+
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
29182927
auto set = getIntegerSet();
29192928
SmallVector<Value, 4> operands(getOperands());
29202929
composeSetAndOperands(set, operands);
@@ -3005,18 +3014,19 @@ static LogicalResult
30053014
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
30063015
Operation::operand_range mapOperands,
30073016
MemRefType memrefType, unsigned numIndexOperands) {
3008-
AffineMap map = mapAttr.getValue();
3009-
if (map.getNumResults() != memrefType.getRank())
3010-
return op->emitOpError("affine map num results must equal memref rank");
3011-
if (map.getNumInputs() != numIndexOperands)
3012-
return op->emitOpError("expects as many subscripts as affine map inputs");
3017+
AffineMap map = mapAttr.getValue();
3018+
if (map.getNumResults() != memrefType.getRank())
3019+
return op->emitOpError("affine map num results must equal memref rank");
3020+
if (map.getNumInputs() != numIndexOperands)
3021+
return op->emitOpError("expects as many subscripts as affine map inputs");
30133022

30143023
Region *scope = getAffineScope(op);
30153024
for (auto idx : mapOperands) {
30163025
if (!idx.getType().isIndex())
30173026
return op->emitOpError("index to load must have 'index' type");
30183027
if (!isValidAffineIndexOperand(idx, scope))
3019-
return op->emitOpError("index must be a dimension or symbol identifier");
3028+
return op->emitOpError(
3029+
"index must be a valid dimension or symbol identifier");
30203030
}
30213031

30223032
return success();
@@ -3605,7 +3615,8 @@ LogicalResult AffinePrefetchOp::verify() {
36053615
Region *scope = getAffineScope(*this);
36063616
for (auto idx : getMapOperands()) {
36073617
if (!isValidAffineIndexOperand(idx, scope))
3608-
return emitOpError("index must be a dimension or symbol identifier");
3618+
return emitOpError(
3619+
"index must be a valid dimension or symbol identifier");
36093620
}
36103621
return success();
36113622
}

mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
4949
}
5050

5151
// -----
52+
53+
func.func @dim_index_out_of_bounds() {
54+
%c6 = arith.constant 6 : index
55+
%alloc_4 = memref.alloc() : memref<4xi64>
56+
%dim = memref.dim %alloc_4, %c6 : memref<4xi64>
57+
%alloca_100 = memref.alloca() : memref<100xi64>
58+
// expected-error@+1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
59+
%70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
60+
return
61+
}
62+
63+
// -----

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
5555
"unknown"() ({
5656
^bb0(%arg: index):
5757
affine.load %M[%arg] : memref<10xi32>
58-
// expected-error@-1 {{index must be a dimension or symbol identifier}}
58+
// expected-error@-1 {{index must be a valid dimension or symbol identifier}}
5959
cf.br ^bb1
6060
^bb1:
6161
cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
521521
%idx = "unknown.test"() : () -> (index)
522522
%memref = "unknown.test"() : () -> memref<?x?xf32>
523523
%dim = memref.dim %memref, %idx : memref<?x?xf32>
524-
// expected-error @below {{op index must be a dimension or symbol identifier}}
524+
// expected-error @below {{op index must be a valid dimension or symbol identifier}}
525525
affine.load %memref[%dim, %dim] : memref<?x?xf32>
526526
"unknown.terminator"() : () -> ()
527527
}) : () -> ()

mlir/test/Dialect/Affine/load-store-invalid.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
3737
%0 = memref.alloc() : memref<10xf32>
3838
affine.for %i0 = 0 to 10 {
3939
%1 = arith.muli %i0, %arg0 : index
40-
// expected-error@+1 {{op index must be a dimension or symbol identifier}}
40+
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
4141
%v = affine.load %0[%1] : memref<10xf32>
4242
}
4343
return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
5050
%1 = arith.constant 11.0 : f32
5151
affine.for %i0 = 0 to 10 {
5252
%2 = arith.muli %i0, %arg0 : index
53-
// expected-error@+1 {{op index must be a dimension or symbol identifier}}
53+
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
5454
affine.store %1, %0[%2] : memref<10xf32>
5555
}
5656
return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
8484
%c64 = arith.constant 64 : index
8585
affine.for %i0 = 0 to 10 {
8686
%3 = arith.muli %i0, %arg0 : index
87-
// expected-error@+1 {{op src index must be a dimension or symbol identifier}}
87+
// expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
8888
affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
8989
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
9090
}
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
101101
%c64 = arith.constant 64 : index
102102
affine.for %i0 = 0 to 10 {
103103
%3 = arith.muli %i0, %arg0 : index
104-
// expected-error@+1 {{op dst index must be a dimension or symbol identifier}}
104+
// expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
105105
affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
106106
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
107107
}
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
118118
%c64 = arith.constant 64 : index
119119
affine.for %i0 = 0 to 10 {
120120
%3 = arith.muli %i0, %arg0 : index
121-
// expected-error@+1 {{op tag index must be a dimension or symbol identifier}}
121+
// expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
122122
affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
123123
: memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
124124
}
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
135135
%c64 = arith.constant 64 : index
136136
affine.for %i0 = 0 to 10 {
137137
%3 = arith.muli %i0, %arg0 : index
138-
// expected-error@+1 {{op index must be a dimension or symbol identifier}}
138+
// expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
139139
affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
140140
}
141141
return

0 commit comments

Comments
 (0)