-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][memref] Remove runtime verification for memref.reinterpret_cast
#132547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] Remove runtime verification for memref.reinterpret_cast
#132547
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesThe runtime verification code used to verify that the result of a This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of Full diff: https://github.com/llvm/llvm-project/pull/132547.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 7cd4814bf88d0..922111e1fad1f 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -255,78 +255,6 @@ struct LoadStoreOpInterface
}
};
-/// Compute the linear index for the provided strided layout and indices.
-Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
- ArrayRef<OpFoldResult> strides,
- ArrayRef<OpFoldResult> indices) {
- auto [expr, values] = computeLinearIndex(offset, strides, indices);
- auto index =
- affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
- return getValueOrCreateConstantIndexOp(builder, loc, index);
-}
-
-/// Returns two Values representing the bounds of the provided strided layout
-/// metadata. The bounds are returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
- OpFoldResult offset,
- ArrayRef<OpFoldResult> strides,
- ArrayRef<OpFoldResult> sizes) {
- auto zeros = SmallVector<int64_t>(sizes.size(), 0);
- auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
- auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
- auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
- return {lowerBound, upperBound};
-}
-
-/// Returns two Values representing the bounds of the memref. The bounds are
-/// returned as a half open interval -- [low, high).
-std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
- TypedValue<BaseMemRefType> memref) {
- auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
- auto offset = runtimeMetadata.getConstifiedMixedOffset();
- auto strides = runtimeMetadata.getConstifiedMixedStrides();
- auto sizes = runtimeMetadata.getConstifiedMixedSizes();
- return computeLinearBounds(builder, loc, offset, strides, sizes);
-}
-
-/// Verifies that the linear bounds of a reinterpret_cast op are within the
-/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
-struct ReinterpretCastOpInterface
- : public RuntimeVerifiableOpInterface::ExternalModel<
- ReinterpretCastOpInterface, ReinterpretCastOp> {
- void generateRuntimeVerification(Operation *op, OpBuilder &builder,
- Location loc) const {
- auto reinterpretCast = cast<ReinterpretCastOp>(op);
- auto baseMemref = reinterpretCast.getSource();
- auto resultMemref =
- cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());
-
- builder.setInsertionPointAfter(op);
-
- // Compute the linear bounds of the base memref
- auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
-
- // Compute the linear bounds of the resulting memref
- auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
-
- // Check low >= baseLow
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, low, baseLow);
-
- // Check high <= baseHigh
- auto leHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sle, high, baseHigh);
-
- auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
-
- builder.create<cf::AssertOp>(
- loc, assertCond,
- RuntimeVerifiableOpInterface::generateErrorMessage(
- op,
- "result of reinterpret_cast is out-of-bounds of the base memref"));
- }
-};
-
struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -430,9 +358,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
- ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
+ // Note: There is nothing to verify for ReinterpretCastOp.
// Load additional dialects of which ops may get created.
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
deleted file mode 100644
index 601a53f4b5cd9..0000000000000
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ /dev/null
@@ -1,74 +0,0 @@
-// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -test-cf-assert \
-// RUN: -expand-strided-metadata \
-// RUN: -lower-affine \
-// RUN: -convert-to-llvm | \
-// RUN: mlir-runner -e main -entry-point-result=void \
-// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
-// RUN: FileCheck %s
-
-func.func @reinterpret_cast(%memref: memref<1xf32>, %offset: index) {
- memref.reinterpret_cast %memref to
- offset: [%offset],
- sizes: [1],
- strides: [1]
- : memref<1xf32> to memref<1xf32, strided<[1], offset: ?>>
- return
-}
-
-func.func @reinterpret_cast_fully_dynamic(%memref: memref<?xf32>, %offset: index, %size: index, %stride: index) {
- memref.reinterpret_cast %memref to
- offset: [%offset],
- sizes: [%size],
- strides: [%stride]
- : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
- return
-}
-
-func.func @main() {
- %0 = arith.constant 0 : index
- %1 = arith.constant 1 : index
- %n1 = arith.constant -1 : index
- %4 = arith.constant 4 : index
- %5 = arith.constant 5 : index
-
- %alloca_1 = memref.alloca() : memref<1xf32>
- %alloca_4 = memref.alloca() : memref<4xf32>
- %alloca_4_dyn = memref.cast %alloca_4 : memref<4xf32> to memref<?xf32>
-
- // Offset is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast(%alloca_1, %1) : (memref<1xf32>, index) -> ()
-
- // Offset is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast(%alloca_1, %n1) : (memref<1xf32>, index) -> ()
-
- // Size is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?xf32>, index, index, index) -> ()
-
- // Stride is out-of-bounds
- // CHECK: ERROR: Runtime op verification failed
- // CHECK-NEXT: "memref.reinterpret_cast"(%{{.*}})
- // CHECK-NEXT: ^ result of reinterpret_cast is out-of-bounds of the base memref
- // CHECK-NEXT: Location: loc({{.*}})
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?xf32>, index, index, index) -> ()
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @reinterpret_cast(%alloca_1, %0) : (memref<1xf32>, index) -> ()
-
- // CHECK-NOT: ERROR: Runtime op verification failed
- func.call @reinterpret_cast_fully_dynamic(%alloca_4_dyn, %0, %4, %1) : (memref<?xf32>, index, index, index) -> ()
-
- return
-}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is OK, but when we create a new memref, dont we want to verify that the strides specified dont make it such that accessing using strides goes out of bounds?
By "create a new memref" you mean the reinterpret_cast result, right? I'd say you don't want to verify this. Maybe you know that the source memref is a view into a larger allocation, so out-of-bounds access is safe. I think of it like a C++ reinterpret_cast, there you can cast to almost an arbitrary memref type. |
4603a89
to
0d3e9d8
Compare
The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors and there is no correctness expectation. This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.
b52e2fd
to
63d85da
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, given the "trust me" nature of reinterpret_cast
, removing this makes sense
…st` (llvm#132547) The runtime verification code used to verify that the result of a `memref.reinterpret_cast` is in-bounds with respect to the source memref. This is incorrect: `memref.reinterpret_cast` allows users to construct almost arbitrary memref descriptors and there is no correctness expectation. This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of `memref.reinterpret_cast` does not verify in-bounds semantics either.
…have an implementation (#145230) Previously running `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` would crash because its implementation of the `RuntimeVerifiableOpInterface` was removed in #132547 but its associated entry in `declarePromisedInterface` was never removed. This causes an error when you try and run `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` that looks like ``` LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered. ``` as reported in #144028. In this PR I also added all the ops that do have implementations of this interface in `mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the `declarePromisedInterface` for consistency. Fixes #144028
…f ops that have an implementation (#145230) Previously running `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` would crash because its implementation of the `RuntimeVerifiableOpInterface` was removed in llvm/llvm-project#132547 but its associated entry in `declarePromisedInterface` was never removed. This causes an error when you try and run `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` that looks like ``` LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered. ``` as reported in llvm/llvm-project#144028. In this PR I also added all the ops that do have implementations of this interface in `mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the `declarePromisedInterface` for consistency. Fixes llvm/llvm-project#144028
…have an implementation (llvm#145230) Previously running `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` would crash because its implementation of the `RuntimeVerifiableOpInterface` was removed in llvm#132547 but its associated entry in `declarePromisedInterface` was never removed. This causes an error when you try and run `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` that looks like ``` LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered. ``` as reported in llvm#144028. In this PR I also added all the ops that do have implementations of this interface in `mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the `declarePromisedInterface` for consistency. Fixes llvm#144028
…have an implementation (llvm#145230) Previously running `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` would crash because its implementation of the `RuntimeVerifiableOpInterface` was removed in llvm#132547 but its associated entry in `declarePromisedInterface` was never removed. This causes an error when you try and run `-generate-runtime-verification` on an IR containing `memref.reinterpret_cast` that looks like ``` LLVM ERROR: checking for an interface (`mlir::RuntimeVerifiableOpInterface`) that was promised by dialect 'memref' but never implemented. This is generally an indication that the dialect extension implementing the interface was never registered. ``` as reported in llvm#144028. In this PR I also added all the ops that do have implementations of this interface in `mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp` to the `declarePromisedInterface` for consistency. Fixes llvm#144028
The runtime verification code used to verify that the result of a
memref.reinterpret_cast
is in-bounds with respect to the source memref. This is incorrect:memref.reinterpret_cast
allows users to construct almost arbitrary memref descriptors and there is no correctness expectation.This op is supposed to be used when the user "knows what they are doing." Similarly, the static verifier of
memref.reinterpret_cast
does not verify in-bounds semantics either.Depends on #132545.