Skip to content

Commit 30ed24f

Browse files
matthias-springerAnkur-0429
authored andcommitted
[mlir][memref] Add runtime verification for memref.atomic_rmw (llvm#130414)
Implement runtime verification for `memref.atomic_rmw` and `memref.generic_atomic_rmw`. Also add a missing test for `memref.store`.
1 parent dc76b45 commit 30ed24f

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
426426
DialectRegistry &registry) {
427427
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
428428
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
429+
AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
429430
CastOp::attachInterface<CastOpInterface>(*ctx);
430431
CopyOp::attachInterface<CopyOpInterface>(*ctx);
431432
DimOp::attachInterface<DimOpInterface>(*ctx);
432433
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
434+
GenericAtomicRMWOp::attachInterface<
435+
LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
433436
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
434437
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
435438
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -test-cf-assert \
3+
// RUN: -convert-to-llvm | \
4+
// RUN: mlir-runner -e main -entry-point-result=void \
5+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
6+
// RUN: FileCheck %s
7+
8+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
9+
%cst = arith.constant 1.0 : f32
10+
memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
11+
return
12+
}
13+
14+
func.func @main() {
15+
// Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
16+
// necessary because "-test-cf-assert" does not abort the program and we do
17+
// not want to segfault when running the test case.
18+
%alloc = memref.alloca() : memref<10xf32>
19+
%ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
20+
%ptr_i64 = arith.index_cast %ptr : index to i64
21+
%ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
22+
%c0 = llvm.mlir.constant(0 : index) : i64
23+
%c1 = llvm.mlir.constant(1 : index) : i64
24+
%c5 = llvm.mlir.constant(5 : index) : i64
25+
%4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
26+
%5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
27+
%6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
28+
%8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
29+
%9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
30+
%10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
31+
%buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
32+
%cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
33+
34+
// CHECK: ERROR: Runtime op verification failed
35+
// CHECK-NEXT: "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}) <{kind = 0 : i64}> : (f32, memref<?xf32>, index) -> f32
36+
// CHECK-NEXT: ^ out-of-bounds access
37+
// CHECK-NEXT: Location: loc({{.*}})
38+
%c9 = arith.constant 9 : index
39+
func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
40+
41+
return
42+
}
43+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -test-cf-assert \
3+
// RUN: -convert-to-llvm | \
4+
// RUN: mlir-runner -e main -entry-point-result=void \
5+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
6+
// RUN: FileCheck %s
7+
8+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
9+
%cst = arith.constant 1.0 : f32
10+
memref.store %cst, %memref[%index] : memref<?xf32>
11+
return
12+
}
13+
14+
func.func @main() {
15+
// Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
16+
// necessary because "-test-cf-assert" does not abort the program and we do
17+
// not want to segfault when running the test case.
18+
%alloc = memref.alloca() : memref<10xf32>
19+
%ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
20+
%ptr_i64 = arith.index_cast %ptr : index to i64
21+
%ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
22+
%c0 = llvm.mlir.constant(0 : index) : i64
23+
%c1 = llvm.mlir.constant(1 : index) : i64
24+
%c5 = llvm.mlir.constant(5 : index) : i64
25+
%4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
26+
%5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
27+
%6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
28+
%8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
29+
%9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
30+
%10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
31+
%buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
32+
%cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
33+
34+
// CHECK: ERROR: Runtime op verification failed
35+
// CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<?xf32>, index) -> ()
36+
// CHECK-NEXT: ^ out-of-bounds access
37+
// CHECK-NEXT: Location: loc({{.*}})
38+
%c9 = arith.constant 9 : index
39+
func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
40+
41+
return
42+
}
43+

0 commit comments

Comments
 (0)