Skip to content

[mlir][affine] Fix dim index out of bounds crash #73266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 28, 2023
Merged

[mlir][affine] Fix dim index out of bounds crash #73266

merged 1 commit into from
Nov 28, 2023

Conversation

rikhuijzer
Copy link
Member

This PR suggests a way to fix #70418. It now throws an error if the index operand for memref.dim is out of bounds. Catching it in the verifier was not possible because the constant value is not yet available at that point. Unfortunately, the error is not very descriptive since it was only possible to propagate boolean up.

@llvmbot
Copy link
Member

llvmbot commented Nov 23, 2023

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Rik Huijzer (rikhuijzer)

Changes

This PR suggests a way to fix #70418. It now throws an error if the index operand for memref.dim is out of bounds. Catching it in the verifier was not possible because the constant value is not yet available at that point. Unfortunately, the error is not very descriptive since it was only possible to propagate boolean up.


Full diff: https://github.com/llvm/llvm-project/pull/73266.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+24-16)
  • (modified) mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir (+12)
  • (modified) mlir/test/Dialect/Affine/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Affine/load-store-invalid.mlir (+6-6)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..d6e640ddd8f25d5 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -317,9 +317,16 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
 /// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
 /// for `region`.
 template <typename AnyMemRefDefOp>
-static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
+static bool isMemRefSizeValidSymbol(ShapedDimOpInterface dimOp,
+                                    AnyMemRefDefOp memrefDefOp, unsigned index,
                                     Region *region) {
-  auto memRefType = memrefDefOp.getType();
+  MemRefType memRefType = memrefDefOp.getType();
+
+  // Dimension index is out of bounds.
+  if (index >= memRefType.getRank()) {
+    return false;
+  }
+
   // Statically shaped.
   if (!memRefType.isDynamicDim(index))
     return true;
@@ -351,7 +358,9 @@ static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
   int64_t i = index.value();
   return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
-          [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
+          [&](auto memRefDefOp) {
+            return isMemRefSizeValidSymbol(dimOp, memRefDefOp, i, region);
+          })
       .Default([](Operation *) { return false; });
 }
 
@@ -1651,19 +1660,19 @@ LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
     if (!idx.getType().isIndex())
       return emitOpError("src index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("src index must be a dimension or symbol identifier");
+      return emitOpError("src index must be a valid dimension or symbol identifier");
   }
   for (auto idx : getDstIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("dst index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("dst index must be a dimension or symbol identifier");
+      return emitOpError("dst index must be a valid dimension or symbol identifier");
   }
   for (auto idx : getTagIndices()) {
     if (!idx.getType().isIndex())
       return emitOpError("tag index to dma_start must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("tag index must be a dimension or symbol identifier");
+      return emitOpError("tag index must be a valid dimension or symbol identifier");
   }
   return success();
 }
@@ -1752,7 +1761,7 @@ LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
     if (!idx.getType().isIndex())
       return emitOpError("index to dma_wait must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("index must be a dimension or symbol identifier");
+      return emitOpError("index must be a valid dimension or symbol identifier");
   }
   return success();
 }
@@ -2913,8 +2922,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -3005,18 +3013,18 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {
     if (!idx.getType().isIndex())
       return op->emitOpError("index to load must have 'index' type");
     if (!isValidAffineIndexOperand(idx, scope))
-      return op->emitOpError("index must be a dimension or symbol identifier");
+      return op->emitOpError("index must be a valid dimension or symbol identifier");
   }
 
   return success();
@@ -3605,7 +3613,7 @@ LogicalResult AffinePrefetchOp::verify() {
   Region *scope = getAffineScope(*this);
   for (auto idx : getMapOperands()) {
     if (!isValidAffineIndexOperand(idx, scope))
-      return emitOpError("index must be a dimension or symbol identifier");
+      return emitOpError("index must be a valid dimension or symbol identifier");
   }
   return success();
 }
diff --git a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
index 759ab2d6c358c8a..b94d271fc197014 100644
--- a/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/func-ops-to-spirv.mlir
@@ -49,3 +49,15 @@ func.func @call_functions(%arg0: index) -> index {
 }
 
 // -----
+
+func.func @dim_out_of_bounds() {
+  %c6 = arith.constant 6 : index
+  %alloc_4 = memref.alloc() : memref<4xi64>
+  %dim = memref.dim %alloc_4, %c6 : memref<4xi64> // Out of bounds; UB.
+  %alloca_100 = memref.alloca() : memref<100xi64>
+  // expected-error@+1 {{'affine.vector_load' op index must be a valid dimension or symbol identifier}}
+  %70 = affine.vector_load %alloca_100[%dim] : memref<100xi64>, vector<31xi64>
+  return
+}
+
+// -----
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 72864516b459a51..60f13102f551569 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -55,7 +55,7 @@ func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
   "unknown"() ({
   ^bb0(%arg: index):
     affine.load %M[%arg] : memref<10xi32>
-    // expected-error@-1 {{index must be a dimension or symbol identifier}}
+    // expected-error@-1 {{index must be a valid dimension or symbol identifier}}
     cf.br ^bb1
   ^bb1:
     cf.br ^bb1
@@ -521,7 +521,7 @@ func.func @dynamic_dimension_index() {
     %idx = "unknown.test"() : () -> (index)
     %memref = "unknown.test"() : () -> memref<?x?xf32>
     %dim = memref.dim %memref, %idx : memref<?x?xf32>
-    // expected-error @below {{op index must be a dimension or symbol identifier}}
+    // expected-error @below {{op index must be a valid dimension or symbol identifier}}
     affine.load %memref[%dim, %dim] : memref<?x?xf32>
     "unknown.terminator"() : () -> ()
   }) : () -> ()
diff --git a/mlir/test/Dialect/Affine/load-store-invalid.mlir b/mlir/test/Dialect/Affine/load-store-invalid.mlir
index 482d2f35e094923..01d6b25dee695bb 100644
--- a/mlir/test/Dialect/Affine/load-store-invalid.mlir
+++ b/mlir/test/Dialect/Affine/load-store-invalid.mlir
@@ -37,7 +37,7 @@ func.func @load_non_affine_index(%arg0 : index) {
   %0 = memref.alloc() : memref<10xf32>
   affine.for %i0 = 0 to 10 {
     %1 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     %v = affine.load %0[%1] : memref<10xf32>
   }
   return
@@ -50,7 +50,7 @@ func.func @store_non_affine_index(%arg0 : index) {
   %1 = arith.constant 11.0 : f32
   affine.for %i0 = 0 to 10 {
     %2 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     affine.store %1, %0[%2] : memref<10xf32>
   }
   return
@@ -84,7 +84,7 @@ func.func @dma_start_non_affine_src_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op src index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op src index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%3], %1[%i0], %2[%c0], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -101,7 +101,7 @@ func.func @dma_start_non_affine_dst_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op dst index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op dst index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%i0], %1[%3], %2[%c0], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -118,7 +118,7 @@ func.func @dma_start_non_affine_tag_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op tag index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op tag index must be a valid dimension or symbol identifier}}
     affine.dma_start %0[%i0], %1[%arg0], %2[%3], %c64
         : memref<100xf32>, memref<100xf32, 2>, memref<1xi32, 4>
   }
@@ -135,7 +135,7 @@ func.func @dma_wait_non_affine_tag_index(%arg0 : index) {
   %c64 = arith.constant 64 : index
   affine.for %i0 = 0 to 10 {
     %3 = arith.muli %i0, %arg0 : index
-    // expected-error@+1 {{op index must be a dimension or symbol identifier}}
+    // expected-error@+1 {{op index must be a valid dimension or symbol identifier}}
     affine.dma_wait %2[%3], %c64 : memref<1xi32, 4>
   }
   return

This comment was marked as outdated.

@rikhuijzer
Copy link
Member Author

Also related is #73027. That PR allows us to simply the code and we could then also remove the "or symbol type" from the error message since that is tested as part of isValidDim.

@rikhuijzer rikhuijzer merged commit 3247f1e into llvm:main Nov 28, 2023
@rikhuijzer rikhuijzer deleted the rh/dim-index-out-of-bounds branch November 28, 2023 06:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] --convert-func-to-spirv crashed with assertion failure "invalid index for shaped type"
3 participants