Skip to content

Revert "[mlir][linalg] Add runtime verification for linalg ops" #89780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

2 changes: 0 additions & 2 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
Expand Down Expand Up @@ -162,7 +161,6 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
"::mlir::Location":$loc)
>,
];

let extraClassDeclaration = [{
/// Generate the error message that will be printed to the user when
/// verification fails.
static std::string generateErrorMessage(Operation *op, const std::string &msg);
}];
}

#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
1 change: 0 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
RuntimeOpVerification.cpp
Specialize.cpp
Split.cpp
SplitReduction.cpp
Expand Down
135 changes: 0 additions & 135 deletions mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp

This file was deleted.

54 changes: 33 additions & 21 deletions mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,25 @@

using namespace mlir;

/// Generate an error message string for the given op and the specified error.
static std::string generateErrorMessage(Operation *op, const std::string &msg) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
// We may generate a lot of error messages and so we need to ensure the
// printing is fast.
flags.elideLargeElementsAttrs();
flags.printGenericOpForm();
flags.skipRegions();
flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
op->print(stream, flags);
stream << "\n^ " << msg;
stream << "\nLocation: ";
op->getLoc().print(stream);
return stream.str();
}

namespace mlir {
namespace memref {
namespace {
Expand All @@ -43,10 +62,8 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
Value isSameRank = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcRank, resultRank);
builder.create<cf::AssertOp>(
loc, isSameRank,
RuntimeVerifiableOpInterface::generateErrorMessage(op,
"rank mismatch"));
builder.create<cf::AssertOp>(loc, isSameRank,
generateErrorMessage(op, "rank mismatch"));
}

// Get source offset and strides. We do not have an op to get offsets and
Expand Down Expand Up @@ -84,8 +101,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
builder.create<cf::AssertOp>(
loc, isSameSz,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "size mismatch of dim " + std::to_string(it.index())));
generateErrorMessage(op, "size mismatch of dim " +
std::to_string(it.index())));
}

// Get result offset and strides.
Expand All @@ -102,10 +119,8 @@ struct CastOpInterface
builder.create<arith::ConstantIndexOp>(loc, resultOffset);
Value isSameOffset = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
builder.create<cf::AssertOp>(
loc, isSameOffset,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "offset mismatch"));
builder.create<cf::AssertOp>(loc, isSameOffset,
generateErrorMessage(op, "offset mismatch"));
}

// Check strides.
Expand All @@ -122,8 +137,8 @@ struct CastOpInterface
loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
builder.create<cf::AssertOp>(
loc, isSameStride,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "stride mismatch of dim " + std::to_string(it.index())));
generateErrorMessage(op, "stride mismatch of dim " +
std::to_string(it.index())));
}
}
};
Expand Down Expand Up @@ -163,9 +178,7 @@ struct LoadStoreOpInterface
: andOp;
}
builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "out-of-bounds access"));
loc, assertCond, generateErrorMessage(op, "out-of-bounds access"));
}
};

Expand Down Expand Up @@ -235,7 +248,7 @@ struct ReinterpretCastOpInterface

builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
generateErrorMessage(
op,
"result of reinterpret_cast is out-of-bounds of the base memref"));
}
Expand Down Expand Up @@ -280,8 +293,8 @@ struct SubViewOpInterface

builder.create<cf::AssertOp>(
loc, assertCond,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "subview is out-of-bounds of the base memref"));
generateErrorMessage(op,
"subview is out-of-bounds of the base memref"));
}
};

Expand Down Expand Up @@ -321,9 +334,8 @@ struct ExpandShapeOpInterface
builder.create<arith::ConstantIndexOp>(loc, 0));
builder.create<cf::AssertOp>(
loc, isModZero,
RuntimeVerifiableOpInterface::generateErrorMessage(
op, "static result dims in reassoc group do not "
"divide src dim evenly"));
generateErrorMessage(op, "static result dims in reassoc group do not "
"divide src dim evenly"));
}
}
};
Expand Down
22 changes: 0 additions & 22 deletions mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,6 @@
namespace mlir {
class Location;
class OpBuilder;

/// Generate an error message string for the given op and the specified error.
std::string
RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op,
const std::string &msg) {
std::string buffer;
llvm::raw_string_ostream stream(buffer);
OpPrintingFlags flags;
// We may generate a lot of error messages and so we need to ensure the
// printing is fast.
flags.elideLargeElementsAttrs();
flags.printGenericOpForm();
flags.skipRegions();
flags.useLocalScope();
stream << "ERROR: Runtime op verification failed\n";
op->print(stream, flags);
stream << "\n^ " << msg;
stream << "\nLocation: ";
op->getLoc().print(stream);
return stream.str();
}

} // namespace mlir

/// Include the definitions of the interface.
Expand Down
43 changes: 0 additions & 43 deletions mlir/test/Dialect/Linalg/runtime-verification.mlir

This file was deleted.

Loading