Skip to content

Commit 18f917b

Browse files
[mlir][memref] Add runtime verification for memref.atomic_rmw
1 parent 53a395f commit 18f917b

File tree

3 files changed

+116
-19
lines changed

3 files changed

+116
-19
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3535
return inBounds;
3636
}
3737

38+
/// Generate a runtime check to see if the given indices are in-bounds with
39+
/// respect to the given ranked memref.
40+
Value generateIndicesInBoundsCheck(OpBuilder &builder, Location loc,
41+
Value memref, ValueRange indices) {
42+
auto memrefType = cast<MemRefType>(memref.getType());
43+
assert(memrefType.getRank() == static_cast<int64_t>(indices.size()) &&
44+
"rank mismatch");
45+
Value cond = builder.create<arith::ConstantOp>(
46+
loc, builder.getIntegerAttr(builder.getI1Type(), 1));
47+
48+
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
49+
for (auto [dim, idx] : llvm::enumerate(indices)) {
50+
Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, dim);
51+
Value inBounds = generateInBoundsCheck(builder, loc, idx, zero, dimOp);
52+
cond = builder.createOrFold<arith::AndIOp>(loc, cond, inBounds);
53+
}
54+
55+
return cond;
56+
}
57+
3858
struct AssumeAlignmentOpInterface
3959
: public RuntimeVerifiableOpInterface::ExternalModel<
4060
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -230,26 +250,10 @@ struct LoadStoreOpInterface
230250
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
231251
Location loc) const {
232252
auto loadStoreOp = cast<LoadStoreOp>(op);
233-
234-
auto memref = loadStoreOp.getMemref();
235-
auto rank = memref.getType().getRank();
236-
if (rank == 0) {
237-
return;
238-
}
239-
auto indices = loadStoreOp.getIndices();
240-
241-
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
242-
Value assertCond;
243-
for (auto i : llvm::seq<int64_t>(0, rank)) {
244-
Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
245-
Value inBounds =
246-
generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
247-
assertCond =
248-
i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
249-
: inBounds;
250-
}
253+
Value cond = generateIndicesInBoundsCheck(
254+
builder, loc, loadStoreOp.getMemref(), loadStoreOp.getIndices());
251255
builder.create<cf::AssertOp>(
252-
loc, assertCond,
256+
loc, cond,
253257
RuntimeVerifiableOpInterface::generateErrorMessage(
254258
op, "out-of-bounds access"));
255259
}
@@ -421,10 +425,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
421425
DialectRegistry &registry) {
422426
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
423427
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
428+
AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
424429
CastOp::attachInterface<CastOpInterface>(*ctx);
425430
CopyOp::attachInterface<CopyOpInterface>(*ctx);
426431
DimOp::attachInterface<DimOpInterface>(*ctx);
427432
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
433+
GenericAtomicRMWOp::attachInterface<
434+
LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
428435
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
429436
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
430437
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -test-cf-assert \
3+
// RUN: -expand-strided-metadata \
4+
// RUN: -lower-affine \
5+
// RUN: -convert-to-llvm | \
6+
// RUN: mlir-runner -e main -entry-point-result=void \
7+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
8+
// RUN: FileCheck %s
9+
10+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
11+
%cst = arith.constant 1.0 : f32
12+
memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
13+
return
14+
}
15+
16+
func.func @main() {
17+
// Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
18+
// necessary because "-test-cf-assert" does not abort the program and we do
19+
// not want to segfault when running the test case.
20+
%alloc = memref.alloca() : memref<10xf32>
21+
%ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
22+
%ptr_i64 = arith.index_cast %ptr : index to i64
23+
%ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
24+
%c0 = llvm.mlir.constant(0 : index) : i64
25+
%c1 = llvm.mlir.constant(1 : index) : i64
26+
%c5 = llvm.mlir.constant(5 : index) : i64
27+
%4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
28+
%5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
29+
%6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
30+
%8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
31+
%9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
32+
%10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
33+
%buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
34+
%cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
35+
36+
// CHECK: ERROR: Runtime op verification failed
37+
// CHECK-NEXT: "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}) <{kind = 0 : i64}> : (f32, memref<?xf32>, index) -> f32
38+
// CHECK-NEXT: ^ out-of-bounds access
39+
// CHECK-NEXT: Location: loc({{.*}})
40+
%c9 = arith.constant 9 : index
41+
func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
42+
43+
return
44+
}
45+
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -test-cf-assert \
3+
// RUN: -expand-strided-metadata \
4+
// RUN: -lower-affine \
5+
// RUN: -convert-to-llvm | \
6+
// RUN: mlir-runner -e main -entry-point-result=void \
7+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
8+
// RUN: FileCheck %s
9+
10+
func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
11+
%cst = arith.constant 1.0 : f32
12+
memref.store %cst, %memref[%index] : memref<?xf32>
13+
return
14+
}
15+
16+
func.func @main() {
17+
// Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is
18+
// necessary because "-test-cf-assert" does not abort the program and we do
19+
// not want to segfault when running the test case.
20+
%alloc = memref.alloca() : memref<10xf32>
21+
%ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index
22+
%ptr_i64 = arith.index_cast %ptr : index to i64
23+
%ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr
24+
%c0 = llvm.mlir.constant(0 : index) : i64
25+
%c1 = llvm.mlir.constant(1 : index) : i64
26+
%c5 = llvm.mlir.constant(5 : index) : i64
27+
%4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
28+
%5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
29+
%6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
30+
%8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
31+
%9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
32+
%10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
33+
%buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32>
34+
%cast = memref.cast %buffer : memref<5xf32> to memref<?xf32>
35+
36+
// CHECK: ERROR: Runtime op verification failed
37+
// CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<?xf32>, index) -> ()
38+
// CHECK-NEXT: ^ out-of-bounds access
39+
// CHECK-NEXT: Location: loc({{.*}})
40+
%c9 = arith.constant 9 : index
41+
func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> ()
42+
43+
return
44+
}
45+

0 commit comments

Comments
 (0)