@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
35
35
return inBounds;
36
36
}
37
37
38
+ // / Generate a runtime check to see if the given indices are in-bounds with
39
+ // / respect to the given ranked memref.
40
+ Value generateIndicesInBoundsCheck (OpBuilder &builder, Location loc,
41
+ Value memref, ValueRange indices) {
42
+ auto memrefType = cast<MemRefType>(memref.getType ());
43
+ assert (memrefType.getRank () == static_cast <int64_t >(indices.size ()) &&
44
+ " rank mismatch" );
45
+ Value cond = builder.create <arith::ConstantOp>(
46
+ loc, builder.getIntegerAttr (builder.getI1Type (), 1 ));
47
+
48
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
49
+ for (auto [dim, idx] : llvm::enumerate (indices)) {
50
+ Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, dim);
51
+ Value inBounds = generateInBoundsCheck (builder, loc, idx, zero, dimOp);
52
+ cond = builder.createOrFold <arith::AndIOp>(loc, cond, inBounds);
53
+ }
54
+
55
+ return cond;
56
+ }
57
+
38
58
struct AssumeAlignmentOpInterface
39
59
: public RuntimeVerifiableOpInterface::ExternalModel<
40
60
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -230,26 +250,10 @@ struct LoadStoreOpInterface
230
250
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
231
251
Location loc) const {
232
252
auto loadStoreOp = cast<LoadStoreOp>(op);
233
-
234
- auto memref = loadStoreOp.getMemref ();
235
- auto rank = memref.getType ().getRank ();
236
- if (rank == 0 ) {
237
- return ;
238
- }
239
- auto indices = loadStoreOp.getIndices ();
240
-
241
- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
242
- Value assertCond;
243
- for (auto i : llvm::seq<int64_t >(0 , rank)) {
244
- Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
245
- Value inBounds =
246
- generateInBoundsCheck (builder, loc, indices[i], zero, dimOp);
247
- assertCond =
248
- i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, inBounds)
249
- : inBounds;
250
- }
253
+ Value cond = generateIndicesInBoundsCheck (
254
+ builder, loc, loadStoreOp.getMemref (), loadStoreOp.getIndices ());
251
255
builder.create <cf::AssertOp>(
252
- loc, assertCond ,
256
+ loc, cond ,
253
257
RuntimeVerifiableOpInterface::generateErrorMessage (
254
258
op, " out-of-bounds access" ));
255
259
}
@@ -421,10 +425,13 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
421
425
DialectRegistry ®istry) {
422
426
registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
423
427
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
428
+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
424
429
CastOp::attachInterface<CastOpInterface>(*ctx);
425
430
CopyOp::attachInterface<CopyOpInterface>(*ctx);
426
431
DimOp::attachInterface<DimOpInterface>(*ctx);
427
432
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
433
+ GenericAtomicRMWOp::attachInterface<
434
+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
428
435
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
429
436
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
430
437
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
0 commit comments