Skip to content

[mlir][memref] Remove runtime verification for memref.reinterpret_cast #132547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 1 addition & 73 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,78 +255,6 @@ struct LoadStoreOpInterface
}
};

/// Compute the linear index for the provided strided layout and indices.
Value computeLinearIndex(OpBuilder &builder, Location loc, OpFoldResult offset,
ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> indices) {
auto [expr, values] = computeLinearIndex(offset, strides, indices);
auto index =
affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
return getValueOrCreateConstantIndexOp(builder, loc, index);
}

/// Returns two Values representing the bounds of the provided strided layout
/// metadata. The bounds are returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
OpFoldResult offset,
ArrayRef<OpFoldResult> strides,
ArrayRef<OpFoldResult> sizes) {
auto zeros = SmallVector<int64_t>(sizes.size(), 0);
auto indices = getAsIndexOpFoldResult(builder.getContext(), zeros);
auto lowerBound = computeLinearIndex(builder, loc, offset, strides, indices);
auto upperBound = computeLinearIndex(builder, loc, offset, strides, sizes);
return {lowerBound, upperBound};
}

/// Returns two Values representing the bounds of the memref. The bounds are
/// returned as a half open interval -- [low, high).
std::pair<Value, Value> computeLinearBounds(OpBuilder &builder, Location loc,
TypedValue<BaseMemRefType> memref) {
auto runtimeMetadata = builder.create<ExtractStridedMetadataOp>(loc, memref);
auto offset = runtimeMetadata.getConstifiedMixedOffset();
auto strides = runtimeMetadata.getConstifiedMixedStrides();
auto sizes = runtimeMetadata.getConstifiedMixedSizes();
return computeLinearBounds(builder, loc, offset, strides, sizes);
}

/// Verifies that the linear bounds of a reinterpret_cast op are within the
/// linear bounds of the base memref: low >= baseLow && high <= baseHigh
struct ReinterpretCastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<
ReinterpretCastOpInterface, ReinterpretCastOp> {
void generateRuntimeVerification(Operation *op, OpBuilder &builder,
Location loc) const {
auto reinterpretCast = cast<ReinterpretCastOp>(op);
auto baseMemref = reinterpretCast.getSource();
auto resultMemref =
cast<TypedValue<BaseMemRefType>>(reinterpretCast.getResult());

builder.setInsertionPointAfter(op);

// Compute the linear bounds of the base memref
auto [baseLow, baseHigh] = computeLinearBounds(builder, loc, baseMemref);

// Compute the linear bounds of the resulting memref
auto [low, high] = computeLinearBounds(builder, loc, resultMemref);

// Check low >= baseLow
auto geLow = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, low, baseLow);

// Check high <= baseHigh
auto leHigh = builder.createOrFold<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, high, baseHigh);

auto assertCond = builder.createOrFold<arith::AndIOp>(loc, geLow, leHigh);

builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op,
"result of reinterpret_cast is out-of-bounds of the base memref"));
}
};

struct SubViewOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
Expand Down Expand Up @@ -431,9 +359,9 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx);
SubViewOp::attachInterface<SubViewOpInterface>(*ctx);
// Note: There is nothing to verify for ReinterpretCastOp.

// Load additional dialects of which ops may get created.
ctx->loadDialect<affine::AffineDialect, arith::ArithDialect,
Expand Down

This file was deleted.