Skip to content

[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig #83754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Transforms/MemoryAllocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ class MemoryAllocationOpt
return keepStackAllocation(alloca, &func.front(), options);
});

patterns.insert<AllocaOpConversion>(context, analysis.getReturns(func));
llvm::SmallVector<mlir::Operation *> returnOps = analysis.getReturns(func);
patterns.insert<AllocaOpConversion>(context, returnOps);
if (mlir::failed(
mlir::applyPartialConversion(func, target, std::move(patterns)))) {
mlir::emitError(func.getLoc(),
Expand Down
72 changes: 45 additions & 27 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace mlir {
// Forward declarations.
class Attribute;
class Block;
struct ConversionConfig;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
Expand Down Expand Up @@ -767,7 +768,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Conversion pattern rewriters must not be used outside of dialect
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
explicit ConversionPatternRewriter(MLIRContext *ctx);
explicit ConversionPatternRewriter(MLIRContext *ctx,
const ConversionConfig &config);

// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
Expand Down Expand Up @@ -1067,6 +1069,30 @@ class PDLConversionConfig final {

#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

//===----------------------------------------------------------------------===//
// ConversionConfig
//===----------------------------------------------------------------------===//

/// Dialect conversion configuration.
struct ConversionConfig {
/// An optional callback used to notify about match failure diagnostics during
/// the conversion. Diagnostics reported to this callback may only be
/// available in debug mode.
function_ref<void(Diagnostic &)> notifyCallback = nullptr;

/// Partial conversion only. All operations that are found not to be
/// legalizable are placed in this set. (Note that if there is an op
/// explicitly marked as illegal, the conversion terminates and the set will
/// not necessarily be complete.)
DenseSet<Operation *> *unlegalizedOps = nullptr;

/// Analysis conversion only. All operations that are found to be legalizable
/// are placed in this set. Note that no actual rewrites are applied to the
/// IR during an analysis conversion and only pre-existing operations are
/// added to the set.
DenseSet<Operation *> *legalizableOps = nullptr;
};

//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
Expand All @@ -1080,52 +1106,44 @@ class PDLConversionConfig final {
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there ops explicitly marked as illegal. If an
/// `unconvertedOps` set is provided, all operations that are found not to be
/// legalizable to the given `target` are placed within that set. (Note that if
/// there is an op explicitly marked as illegal, the conversion terminates and
/// the `unconvertedOps` set will not necessarily be complete.)
/// returns failure if there ops explicitly marked as illegal.
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);
ConversionConfig config = ConversionConfig());
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps = nullptr);
ConversionConfig config = ConversionConfig());

/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns);
const FrozenRewritePatternSet &patterns,
ConversionConfig config = ConversionConfig());
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns);
const FrozenRewritePatternSet &patterns,
ConversionConfig config = ConversionConfig());

/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
/// This method only returns failure if there are unreachable blocks in any of
/// the regions nested within 'ops'. There's an additional argument
/// `notifyCallback` which is used for collecting match failure diagnostics
/// generated during the conversion. Diagnostics are only reported to this
/// callback may only be available in debug mode.
LogicalResult applyAnalysisConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps,
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
LogicalResult applyAnalysisConversion(
Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps,
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
/// provided 'config.legalizableOps' set; note that no actual rewrites are
/// applied to the operations on success. This method only returns failure if
/// there are unreachable blocks in any of the regions nested within 'ops'.
LogicalResult
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config = ConversionConfig());
LogicalResult
applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config = ConversionConfig());
} // namespace mlir

#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
134 changes: 61 additions & 73 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class IRRewrite {
/// Erase the given block (unless it was already erased).
void eraseBlock(Block *block);

const ConversionConfig &getConfig() const;

const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
Expand Down Expand Up @@ -735,8 +737,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: eraseRewriter(rewriter.getContext()) {}
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
: eraseRewriter(ctx), config(config) {}

//===--------------------------------------------------------------------===//
// State Management
Expand Down Expand Up @@ -936,14 +939,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;

/// This allows the user to collect the match failure message.
function_ref<void(Diagnostic &)> notifyCallback;

/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
/// this is populated with ops found to be legalizable to the target.
/// When mode == OpConversionMode::Partial, this is populated with ops found
/// *not* to be legalizable to the target.
DenseSet<Operation *> *trackedOps = nullptr;
/// Dialect conversion configuration.
const ConversionConfig &config;

#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
Expand All @@ -966,6 +963,10 @@ void IRRewrite::eraseBlock(Block *block) {
rewriterImpl.eraseRewriter.eraseBlock(block);
}

const ConversionConfig &IRRewrite::getConfig() const {
return rewriterImpl.config;
}

void BlockTypeConversionRewrite::commit() {
// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
Expand Down Expand Up @@ -1085,8 +1086,8 @@ void ReplaceOperationRewrite::commit() {
if (Value newValue =
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
if (rewriterImpl.trackedOps)
rewriterImpl.trackedOps->erase(op);
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
op->getBlock()->getOperations().remove(op);
}
Expand Down Expand Up @@ -1514,18 +1515,19 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
if (notifyCallback)
notifyCallback(diag);
if (config.notifyCallback)
config.notifyCallback(diag);
});
}

//===----------------------------------------------------------------------===//
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//

ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this)) {
impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
setListener(impl.get());
}

Expand Down Expand Up @@ -1994,12 +1996,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.notifyCallback) {
if (rewriterImpl.config.notifyCallback) {
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
diag << "Failed to apply pattern \"" << pattern.getDebugName()
<< "\" on op:\n"
<< *op;
rewriterImpl.notifyCallback(diag);
rewriterImpl.config.notifyCallback(diag);
}
});
rewriterImpl.resetState(curState);
Expand Down Expand Up @@ -2387,14 +2389,12 @@ namespace mlir {
struct OperationConverter {
explicit OperationConverter(const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
OpConversionMode mode,
DenseSet<Operation *> *trackedOps = nullptr)
: opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
const ConversionConfig &config,
OpConversionMode mode)
: opLegalizer(target, patterns), config(config), mode(mode) {}

/// Converts the given operations to the conversion target.
LogicalResult
convertOperations(ArrayRef<Operation *> ops,
function_ref<void(Diagnostic &)> notifyCallback = nullptr);
LogicalResult convertOperations(ArrayRef<Operation *> ops);

private:
/// Converts an operation with the given rewriter.
Expand Down Expand Up @@ -2431,14 +2431,11 @@ struct OperationConverter {
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;

/// Dialect conversion configuration.
ConversionConfig config;

/// The conversion mode to use when legalizing operations.
OpConversionMode mode;

/// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
/// this is populated with ops found to be legalizable to the target.
/// When mode == OpConversionMode::Partial, this is populated with ops found
/// *not* to be legalizable to the target.
DenseSet<Operation *> *trackedOps;
};
} // namespace mlir

Expand All @@ -2452,28 +2449,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
// Partial conversions allow conversions to fail iff the operation was not
// explicitly marked as illegal. If the user provided a nonlegalizableOps
// set, non-legalizable ops are included.
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
// set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
if (trackedOps)
trackedOps->insert(op);
if (config.unlegalizedOps)
config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
trackedOps->insert(op);
if (config.legalizableOps)
config.legalizableOps->insert(op);
}
return success();
}

LogicalResult OperationConverter::convertOperations(
ArrayRef<Operation *> ops,
function_ref<void(Diagnostic &)> notifyCallback) {
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
const ConversionTarget &target = opLegalizer.getTarget();
Expand All @@ -2494,10 +2490,8 @@ LogicalResult OperationConverter::convertOperations(
}

// Convert each operation and discard rewrites on failure.
ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
rewriterImpl.notifyCallback = notifyCallback;
rewriterImpl.trackedOps = trackedOps;

for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
Expand Down Expand Up @@ -3484,57 +3478,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
//===----------------------------------------------------------------------===//
// Partial Conversion

LogicalResult
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps) {
OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
unconvertedOps);
LogicalResult mlir::applyPartialConversion(
ArrayRef<Operation *> ops, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Partial);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> *unconvertedOps) {
return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
unconvertedOps);
ConversionConfig config) {
return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
}

//===----------------------------------------------------------------------===//
// Full Conversion

LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns) {
OperationConverter opConverter(target, patterns, OpConversionMode::Full);
LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) {
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Full);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns) {
return applyFullConversion(llvm::ArrayRef(op), target, patterns);
LogicalResult mlir::applyFullConversion(Operation *op,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
ConversionConfig config) {
return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
}

//===----------------------------------------------------------------------===//
// Analysis Conversion

LogicalResult
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps,
function_ref<void(Diagnostic &)> notifyCallback) {
OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
&convertedOps);
return opConverter.convertOperations(ops, notifyCallback);
LogicalResult mlir::applyAnalysisConversion(
ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
OperationConverter opConverter(target, patterns, config,
OpConversionMode::Analysis);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
DenseSet<Operation *> &convertedOps,
function_ref<void(Diagnostic &)> notifyCallback) {
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
convertedOps, notifyCallback);
ConversionConfig config) {
return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
}
Loading