Skip to content

Commit 7cf452f

Browse files
[mlir][memref] Add runtime verification for memref.assume_alignment
1 parent 6883972 commit 7cf452f

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using namespace mlir;
2323
namespace mlir {
2424
namespace memref {
2525
namespace {
26-
/// Generate a runtime check for lb <= value < ub.
26+
/// Generate a runtime check for lb <= value < ub.
2727
Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
2828
Value lb, Value ub) {
2929
Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
@@ -35,6 +35,28 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
3535
return inBounds;
3636
}
3737

38+
struct AssumeAlignmentOpInterface
39+
: public RuntimeVerifiableOpInterface::ExternalModel<
40+
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
41+
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
42+
Location loc) const {
43+
auto assumeOp = cast<AssumeAlignmentOp>(op);
44+
Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>(
45+
loc, assumeOp.getMemref());
46+
Value rest = builder.create<arith::RemUIOp>(
47+
loc, ptr,
48+
builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment()));
49+
Value isAligned = builder.create<arith::CmpIOp>(
50+
loc, arith::CmpIPredicate::eq, rest,
51+
builder.create<arith::ConstantIndexOp>(loc, 0));
52+
builder.create<cf::AssertOp>(
53+
loc, isAligned,
54+
RuntimeVerifiableOpInterface::generateErrorMessage(
55+
op, "memref is not aligned to " +
56+
std::to_string(assumeOp.getAlignment())));
57+
}
58+
};
59+
3860
struct CastOpInterface
3961
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
4062
CastOp> {
@@ -354,6 +376,7 @@ struct ExpandShapeOpInterface
354376
void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
355377
DialectRegistry &registry) {
356378
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
379+
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
357380
CastOp::attachInterface<CastOpInterface>(*ctx);
358381
DimOp::attachInterface<DimOpInterface>(*ctx);
359382
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt %s -generate-runtime-verification \
2+
// RUN: -expand-strided-metadata \
3+
// RUN: -test-cf-assert \
4+
// RUN: -convert-to-llvm | \
5+
// RUN: mlir-runner -e main -entry-point-result=void \
6+
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
7+
// RUN: FileCheck %s
8+
9+
func.func @main() {
10+
// This buffer is properly aligned. There should be no error.
11+
// CHECK-NOT: ^ memref is not aligned to 8
12+
%alloc = memref.alloca() : memref<5xf64>
13+
memref.assume_alignment %alloc, 8 : memref<5xf64>
14+
15+
// Construct a memref descriptor with a pointer that is not aligned to 4.
16+
// This cannot be done with just the memref dialect. We have to resort to
17+
// the LLVM dialect.
18+
%c0 = llvm.mlir.constant(0 : index) : i64
19+
%c1 = llvm.mlir.constant(1 : index) : i64
20+
%c3 = llvm.mlir.constant(3 : index) : i64
21+
%unaligned_ptr = llvm.inttoptr %c3 : i64 to !llvm.ptr
22+
%4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
23+
%5 = llvm.insertvalue %unaligned_ptr, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
24+
%6 = llvm.insertvalue %unaligned_ptr, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
25+
%8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
26+
%9 = llvm.insertvalue %c1, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
27+
%10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
28+
%buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<1xf32>
29+
30+
// CHECK: ERROR: Runtime op verification failed
31+
// CHECK-NEXT: "memref.assume_alignment"(%{{.*}}) <{alignment = 4 : i32}> : (memref<1xf32>) -> ()
32+
// CHECK-NEXT: ^ memref is not aligned to 4
33+
// CHECK-NEXT: Location: loc({{.*}})
34+
memref.assume_alignment %buffer, 4 : memref<1xf32>
35+
36+
return
37+
}

0 commit comments

Comments
 (0)