Skip to content

[mlir][memref] Remove runtime verification for memref.reinterpret_cast #132547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 22, 2025

The runtime verification code used to verify that the result of a memref.reinterpret_cast is in-bounds with respect to the source memref. This is incorrect: memref.reinterpret_cast allows users to construct almost arbitrary memref descriptors and there is no correctness expectation.

This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of memref.reinterpret_cast does not verify in-bounds semantics either.

Depends on #132545.

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Matthias Springer (matthias-springer)

Changes

The runtime verification code used to verify that the result of a memref.reinterpret_cast is in-bounds with respect to the source memref. This is incorrect: memref.reinterpret_cast allows users to construct almost arbitrary memref descriptors and there is no correctness expectation.

This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of memref.reinterpret_cast does not verify in-bounds semantics either.


Full diff: https://github.com/llvm/llvm-project/pull/132547.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+1-73)
  • (removed) mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir (-74)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 7cd4814bf88d0..922111e1fad1f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -255,78 +255,6 @@ struct LoadStoreOpInterface
   }
 };
 
-/// Compute the linear index for the provided strided layout and indices.
-Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
-                         ArrayRef<OpFoldResult> strides,
-                         ArrayRef<OpFoldResult> indices) {
-  auto [expr, values] = computeLinearIndex(offset, strides, indices);
-  auto index =
-      affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
-  return getValueOrCreateConstantIndexOp(builder, loc, index);
-}
-
-/// Returns two Values representing the bounds of the provided strided layout
-/// metadata. The bounds are returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
-                                            OpFoldResult offset,
-                                            ArrayRef<OpFoldResult> strides,
-                                            ArrayRef<OpFoldResult> sizes) {
-  auto zeros = SmallVector<int64_t>(sizes.size(), 0);
-  auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
-  auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
-  auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
-  return {lowerBound, upperBound};
-}
-
-/// Returns two Values representing the bounds of the memref. The bounds are
-/// returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
-                                            TypedValue<BaseMemRefType> memref) {
-  auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
-  auto offset = runtimeMetadata.getConstifiedMixedOffset();
-  auto strides = runtimeMetadata.getConstifiedMixedStrides();
-  auto sizes = runtimeMetadata.getConstifiedMixedSizes();
-  return computeLinearBounds(builder, loc, offset, strides, sizes);
-}
-
-/// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
-struct ReinterpretCastOpInterface
-    : public RuntimeVerifiableOpInterface::ExternalModel<
-          ReinterpretCastOpInterface, ReinterpretCastOp> {
-  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
-                                   Location loc) const {
-    auto reinterpretCast = cast<ReinterpretCastOp>(op);
-    auto baseMemref = reinterpretCast.getSource();
-    auto resultMemref =
-        cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
-
-    builder.setInsertionPointAfter(op);
-
-    // Compute the linear bounds of the base memref
-    auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
-    // Compute the linear bounds of the resulting memref
-    auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
-    // Check low >= baseLow
-    auto geLow = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, low, baseLow);
-
-    // Check high <= baseHigh
-    auto leHigh = builder.createOrFold<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sle, high, baseHigh);
-
-    auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
-    builder.create<cf::AssertOp>(
-        loc, assertCond,
-        RuntimeVerifiableOpInterface::generateErrorMessage(
-            op,
-            "result of reinterpret_cast is out-of-bounds of the base memref"));
-  }
-};
-
 struct SubViewOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
                                                          SubViewOp> {
@@ -430,9 +358,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DimOp::attachInterface<DimOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
     LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
-    ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
     StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
     SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
+    // Note: There is nothing to verify for ReinterpretCastOp.
 
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
deleted file mode 100644
index 601a53f4b5cd9..0000000000000
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -test-cf-assert \
-// RUN:     -expand-strided-metadata \
-// RUN:     -lower-affine \
-// RUN:     -convert-to-llvm | \
-// RUN: mlir-runner -e main -entry-point-result=void \
-// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
-// RUN: FileCheck %s
-
-func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
-    memref.reinterpret_cast %memref to
-                    offset: [%offset],
-                    sizes: [1],
-                    strides: [1]
-                  : memref<1xf32> to  memref<1xf32, strided<[1], offset: ?>>
-    return
-}
-
-func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index)  {
-    memref.reinterpret_cast %memref to
-                    offset: [%offset],
-                    sizes: [%size],
-                    strides: [%stride]
-                  : memref<?xf32> to  memref<?xf32, strided<[?], offset: ?>>
-    return
-}
-
-func.func @main() {
-  %0 = arith.constant 0 : index
-  %1 = arith.constant 1 : index
-  %n1 = arith.constant -1 : index
-  %4 = arith.constant 4 : index
-  %5 = arith.constant 5 : index
-
-  %alloca_1 = memref.alloca() : memref<1xf32>
-  %alloca_4 = memref.alloca() : memref<4xf32>
-  %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
-
-  // Offset is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
-
-  // Offset is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
-
-  // Size is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
-
-  // Stride is out-of-bounds
-  //      CHECK: ERROR: Runtime op verification failed
-  // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
-  // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
-  // CHECK-NEXT: Location: loc({{.*}})
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
-
-  //  CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
-
-  //  CHECK-NOT: ERROR: Runtime op verification failed
-  func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
-
-  return
-}

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK, but when we create a new memref, dont we want to verify that the strides specified dont make it such that accessing using strides goes out of bounds?

@matthias-springer
Copy link
Member Author

I think this is OK, but when we create a new memref, dont we want to verify that the strides specified dont make it such that accessing using strides goes out of bounds?

By "create a new memref" you mean the reinterpret_cast result, right? I'd say you don't want to verify this. Maybe you know that the source memref is a view into a larger allocation, so out-of-bounds access is safe. I think of it like a C++ reinterpret_cast, there you can cast to almost an arbitrary memref type.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/memref_subview_veri branch from 4603a89 to 0d3e9d8 Compare March 31, 2025 16:57
Base automatically changed from users/matthias-springer/memref_subview_veri to main March 31, 2025 17:24
The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors and there is no correctness expectation. This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/remove_reinterpret_cast_veri branch from b52e2fd to 63d85da Compare April 27, 2025 07:20
@krzysz00 krzysz00 self-requested a review May 5, 2025 19:46
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, given the "trust me" nature of reinterpret_cast, removing this makes sense

@matthias-springer matthias-springer merged commit fd161cf into main May 6, 2025
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/remove_reinterpret_cast_veri branch May 6, 2025 07:40
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…st` (llvm#132547)

The runtime verification code used to verify that the result of a
`memref.reinterpret_cast` is in-bounds with respect to the source
memref. This is incorrect: `memref.reinterpret_cast` allows users to
construct almost arbitrary memref descriptors and there is no
correctness expectation.

This op is supposed to be used when the user "knows what they are
doing." Similarly, the static verifier of `memref.reinterpret_cast` does
not verify in-bounds semantics either.
CoTinker pushed a commit that referenced this pull request Jun 23, 2025
…have an implementation (#145230)

Previously running `-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` would crash because its implementation of the
`RuntimeVerifiableOpInterface` was removed in
#132547 but its associated
entry in `declarePromisedInterface` was never removed.

This causes an error when you try and run
`-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` that looks like

```
LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered.
```
as reported in #144028.

In this PR I also added all the ops that do have implementations of this
interface in
`mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the
`declarePromisedInterface` for consistency.

Fixes #144028
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 23, 2025
…f ops that have an implementation (#145230)

Previously running `-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` would crash because its implementation of the
`RuntimeVerifiableOpInterface` was removed in
llvm/llvm-project#132547 but its associated
entry in `declarePromisedInterface` was never removed.

This causes an error when you try and run
`-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` that looks like

```
LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered.
```
as reported in llvm/llvm-project#144028.

In this PR I also added all the ops that do have implementations of this
interface in
`mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the
`declarePromisedInterface` for consistency.

Fixes llvm/llvm-project#144028
miguelcsx pushed a commit to miguelcsx/llvm-project that referenced this pull request Jun 23, 2025
…have an implementation (llvm#145230)

Previously running `-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` would crash because its implementation of the
`RuntimeVerifiableOpInterface` was removed in
llvm#132547 but its associated
entry in `declarePromisedInterface` was never removed.

This causes an error when you try and run
`-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` that looks like

```
LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered.
```
as reported in llvm#144028.

In this PR I also added all the ops that do have implementations of this
interface in
`mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the
`declarePromisedInterface` for consistency.

Fixes llvm#144028
Jaddyen pushed a commit to Jaddyen/llvm-project that referenced this pull request Jun 23, 2025
…have an implementation (llvm#145230)

Previously running `-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` would crash because its implementation of the
`RuntimeVerifiableOpInterface` was removed in
llvm#132547 but its associated
entry in `declarePromisedInterface` was never removed.

This causes an error when you try and run
`-generate-runtime-verification` on an IR containing
`memref.reinterpret_cast` that looks like

```
LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered.
```
as reported in llvm#144028.

In this PR I also added all the ops that do have implementations of this
interface in
`mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the
`declarePromisedInterface` for consistency.

Fixes llvm#144028
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants