Skip to content

Commit 0d3e9d8

Browse files
[mlir][memref] Improve runtime verification for memref.subview
This commit addresses a TODO in the runtime verification of `memref.subview`. Each dimension is now verified: the offset must be in-bounds and the slice must not run out-of-bounds.
1 parent 799e905 commit 0d3e9d8

File tree

2 files changed

+66
-49
lines changed

2 files changed

+66
-49
lines changed

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -327,47 +327,52 @@ struct ReinterpretCastOpInterface
327327
}
328328
};
329329

330-
/// Verifies that the linear bounds of a subview op are within the linear bounds
331-
/// of the base memref: low >= baseLow && high <= baseHigh
332-
/// TODO: This is not yet a full runtime verification of subview. For example,
333-
/// consider:
334-
/// %m = memref.alloc(%c10, %c10) : memref<10x10xf32>
335-
/// memref.subview %m[%c0, %c0][%c20, %c2][%c1, %c1]
336-
/// : memref<?x?xf32> to memref<?x?xf32>
337-
/// The subview is in-bounds of the entire base memref but the first dimension
338-
/// is out-of-bounds. Future work would verify the bounds on a per-dimension
339-
/// basis.
340330
struct SubViewOpInterface
341331
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
342332
SubViewOp> {
343333
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
344334
Location loc) const {
345335
auto subView = cast<SubViewOp>(op);
346-
auto baseMemref = cast<TypedValue<BaseMemRefType>>(subView.getSource());
347-
auto resultMemref = cast<TypedValue<BaseMemRefType>>(subView.getResult());
336+
MemRefType sourceType = subView.getSource().getType();
348337

349-
builder.setInsertionPointAfter(op);
350-
351-
// Compute the linear bounds of the base memref
352-
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);
353-
354-
// Compute the linear bounds of the resulting memref
355-
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);
356-
357-
// Check low >= baseLow
358-
auto geLow = builder.createOrFold<arith::CmpIOp>(
359-
loc, arith::CmpIPredicate::sge, low, baseLow);
360-
361-
// Check high <= baseHigh
362-
auto leHigh = builder.createOrFold<arith::CmpIOp>(
363-
loc, arith::CmpIPredicate::sle, high, baseHigh);
364-
365-
auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);
366-
367-
builder.create<cf::AssertOp>(
368-
loc, assertCond,
369-
RuntimeVerifiableOpInterface::generateErrorMessage(
370-
op, "subview is out-of-bounds of the base memref"));
338+
// For each dimension, assert that:
339+
// 0 <= offset < dim_size
340+
// 0 <= offset + (size - 1) * stride < dim_size
341+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
342+
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
343+
auto metadataOp =
344+
builder.create<ExtractStridedMetadataOp>(loc, subView.getSource());
345+
for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
346+
Value offset = getValueOrCreateConstantIndexOp(
347+
builder, loc, subView.getMixedOffsets()[i]);
348+
Value size = getValueOrCreateConstantIndexOp(builder, loc,
349+
subView.getMixedSizes()[i]);
350+
Value stride = getValueOrCreateConstantIndexOp(
351+
builder, loc, subView.getMixedStrides()[i]);
352+
353+
// Verify that offset is in-bounds.
354+
Value dimSize = metadataOp.getSizes()[i];
355+
Value offsetInBounds =
356+
generateInBoundsCheck(builder, loc, offset, zero, dimSize);
357+
builder.create<cf::AssertOp>(
358+
loc, offsetInBounds,
359+
RuntimeVerifiableOpInterface::generateErrorMessage(
360+
op, "offset " + std::to_string(i) + " is out-of-bounds"));
361+
362+
// Verify that slice does not run out-of-bounds.
363+
Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one);
364+
Value sizeMinusOneTimesStride =
365+
builder.create<arith::MulIOp>(loc, sizeMinusOne, stride);
366+
Value lastPos =
367+
builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride);
368+
Value lastPosInBounds =
369+
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
370+
builder.create<cf::AssertOp>(
371+
loc, lastPosInBounds,
372+
RuntimeVerifiableOpInterface::generateErrorMessage(
373+
op, "Subview runs out-of-bounds along dimension" +
374+
std::to_string(i)));
375+
}
371376
}
372377
};
373378

mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,38 +39,50 @@ func.func @main() {
3939
%alloca_4 = memref.alloca() : memref<4x4xf32>
4040
%alloca_4_dyn = memref.cast %alloca_4 : memref<4x4xf32> to memref<?x4xf32>
4141

42-
// Offset is out-of-bounds
42+
// Offset is out-of-bounds and slice runs out-of-bounds
4343
// CHECK: ERROR: Runtime op verification failed
44-
// CHECK-NEXT: "memref.subview"
45-
// CHECK-NEXT: ^ subview is out-of-bounds of the base memref
44+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?xf32, strided<[?], offset: ?>>
45+
// CHECK-NEXT: ^ offset 0 is out-of-bounds
46+
// CHECK-NEXT: Location: loc({{.*}})
47+
// CHECK: ERROR: Runtime op verification failed
48+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 1>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?xf32, strided<[?], offset: ?>>
49+
// CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
4650
// CHECK-NEXT: Location: loc({{.*}})
4751
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %5, %5, %1) : (memref<?x4xf32>, index, index, index) -> ()
4852

49-
// Offset is out-of-bounds
53+
// Offset is out-of-bounds and slice runs out-of-bounds
54+
// CHECK: ERROR: Runtime op verification failed
55+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
56+
// CHECK-NEXT: ^ offset 0 is out-of-bounds
57+
// CHECK-NEXT: Location: loc({{.*}})
5058
// CHECK: ERROR: Runtime op verification failed
51-
// CHECK-NEXT: "memref.subview"
52-
// CHECK-NEXT: ^ subview is out-of-bounds of the base memref
59+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
60+
// CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
5361
// CHECK-NEXT: Location: loc({{.*}})
5462
func.call @subview(%alloca, %1) : (memref<1xf32>, index) -> ()
5563

56-
// Offset is out-of-bounds
64+
// Offset is out-of-bounds and slice runs out-of-bounds
65+
// CHECK: ERROR: Runtime op verification failed
66+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
67+
// CHECK-NEXT: ^ offset 0 is out-of-bounds
68+
// CHECK-NEXT: Location: loc({{.*}})
5769
// CHECK: ERROR: Runtime op verification failed
58-
// CHECK-NEXT: "memref.subview"
59-
// CHECK-NEXT: ^ subview is out-of-bounds of the base memref
70+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>, static_offsets = array<i64: -9223372036854775808>, static_sizes = array<i64: 1>, static_strides = array<i64: 1>}> : (memref<1xf32>, index) -> memref<1xf32, strided<[1], offset: ?>>
71+
// CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
6072
// CHECK-NEXT: Location: loc({{.*}})
6173
func.call @subview(%alloca, %n1) : (memref<1xf32>, index) -> ()
6274

63-
// Size is out-of-bounds
75+
// Slice runs out-of-bounds due to size
6476
// CHECK: ERROR: Runtime op verification failed
65-
// CHECK-NEXT: "memref.subview"
66-
// CHECK-NEXT: ^ subview is out-of-bounds of the base memref
77+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 4>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?x4xf32, strided<[?, 1], offset: ?>>
78+
// CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
6779
// CHECK-NEXT: Location: loc({{.*}})
6880
func.call @subview_dynamic(%alloca_4_dyn, %0, %5, %1) : (memref<?x4xf32>, index, index, index) -> ()
6981

70-
// Stride is out-of-bounds
82+
// Slice runs out-of-bounds due to stride
7183
// CHECK: ERROR: Runtime op verification failed
72-
// CHECK-NEXT: "memref.subview"
73-
// CHECK-NEXT: ^ subview is out-of-bounds of the base memref
84+
// CHECK-NEXT: "memref.subview"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>, static_offsets = array<i64: -9223372036854775808, 0>, static_sizes = array<i64: -9223372036854775808, 4>, static_strides = array<i64: -9223372036854775808, 1>}> : (memref<?x4xf32>, index, index, index) -> memref<?x4xf32, strided<[?, 1], offset: ?>>
85+
// CHECK-NEXT: ^ subview runs out-of-bounds along dimension 0
7486
// CHECK-NEXT: Location: loc({{.*}})
7587
func.call @subview_dynamic(%alloca_4_dyn, %0, %4, %4) : (memref<?x4xf32>, index, index, index) -> ()
7688

0 commit comments

Comments
 (0)