@@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
35
35
return inBounds;
36
36
}
37
37
38
+ // / Generate a runtime check to see if the given indices are in-bounds with
39
+ // / respect to the given ranked memref.
40
+ Value generateIndicesInBoundsCheck (OpBuilder &builder, Location loc,
41
+ Value memref, ValueRange indices) {
42
+ auto memrefType = cast<MemRefType>(memref.getType ());
43
+ assert (memrefType.getRank () == static_cast <int64_t >(indices.size ()) &&
44
+ " rank mismatch" );
45
+ Value cond = builder.create <arith::ConstantOp>(
46
+ loc, builder.getIntegerAttr (builder.getI1Type (), 1 ));
47
+
48
+ auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
49
+ for (auto [dim, idx] : llvm::enumerate (indices)) {
50
+ Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, dim);
51
+ Value inBounds = generateInBoundsCheck (builder, loc, idx, zero, dimOp);
52
+ cond = builder.createOrFold <arith::AndIOp>(loc, cond, inBounds);
53
+ }
54
+
55
+ return cond;
56
+ }
57
+
38
58
struct AssumeAlignmentOpInterface
39
59
: public RuntimeVerifiableOpInterface::ExternalModel<
40
60
AssumeAlignmentOpInterface, AssumeAlignmentOp> {
@@ -186,26 +206,10 @@ struct LoadStoreOpInterface
186
206
void generateRuntimeVerification (Operation *op, OpBuilder &builder,
187
207
Location loc) const {
188
208
auto loadStoreOp = cast<LoadStoreOp>(op);
189
-
190
- auto memref = loadStoreOp.getMemref ();
191
- auto rank = memref.getType ().getRank ();
192
- if (rank == 0 ) {
193
- return ;
194
- }
195
- auto indices = loadStoreOp.getIndices ();
196
-
197
- auto zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
198
- Value assertCond;
199
- for (auto i : llvm::seq<int64_t >(0 , rank)) {
200
- Value dimOp = builder.createOrFold <memref::DimOp>(loc, memref, i);
201
- Value inBounds =
202
- generateInBoundsCheck (builder, loc, indices[i], zero, dimOp);
203
- assertCond =
204
- i > 0 ? builder.createOrFold <arith::AndIOp>(loc, assertCond, inBounds)
205
- : inBounds;
206
- }
209
+ Value cond = generateIndicesInBoundsCheck (
210
+ builder, loc, loadStoreOp.getMemref (), loadStoreOp.getIndices ());
207
211
builder.create <cf::AssertOp>(
208
- loc, assertCond ,
212
+ loc, cond ,
209
213
RuntimeVerifiableOpInterface::generateErrorMessage (
210
214
op, " out-of-bounds access" ));
211
215
}
@@ -377,9 +381,12 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
377
381
DialectRegistry ®istry) {
378
382
registry.addExtension (+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
379
383
AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx);
384
+ AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx);
380
385
CastOp::attachInterface<CastOpInterface>(*ctx);
381
386
DimOp::attachInterface<DimOpInterface>(*ctx);
382
387
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
388
+ GenericAtomicRMWOp::attachInterface<
389
+ LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx);
383
390
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
384
391
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
385
392
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
0 commit comments