Skip to content

[mlir][IR] Add listener notifications for pattern begin/end #84131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -921,20 +921,36 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
return RegionScope(*this, region);
}

/// A configuration object for customizing a `TrackingListener`.
struct TrackingListenerConfig {
using SkipHandleFn = std::function<bool(Value)>;

/// An optional function that returns "true" for handles that do not have to
/// be updated. These are typically dead or consumed handles.
SkipHandleFn skipHandleFn = nullptr;

/// If set to "true", the name of a replacement op must match the name of the
/// original op. If set to "false", the names of the payload ops tracked in a
/// handle may change as the tracking listener updates the transform state.
bool requireMatchingReplacementOpName = true;

/// If set to "true", cast ops (that implement the CastOpInterface) are
/// skipped and the replacement op search continues with the operands of the
/// cast op.
bool skipCastOps = true;
};

/// A listener that updates a TransformState based on IR modifications. This
/// listener can be used during a greedy pattern rewrite to keep the transform
/// state up-to-date.
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
/// A function that returns "true" for handles that do not have to be updated.
using SkipHandleFn = std::function<bool(Value)>;

/// Create a new TrackingListener for usage in the specified transform op.
/// Optionally, a function can be specified to identify handles that should
/// do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op,
SkipHandleFn skipHandleFn = nullptr);
TrackingListenerConfig config = TrackingListenerConfig());

protected:
/// Return a replacement payload op for the given op, which is going to be
Expand All @@ -959,7 +975,8 @@ class TrackingListener : public RewriterBase::Listener,
/// same computation; e.g., there may be tiled "linalg.generic" inside the
/// loop body that represents the original computation. Therefore, the
/// TrackingListener is conservative by default: it drops the mapping and
/// triggers the "payload replacement not found" notification.
/// triggers the "payload replacement not found" notification. This default
/// behavior can be customized in `TrackingListenerConfig`.
///
/// If no replacement op could be found according to the rules mentioned
/// above, this function tries to skip over cast-like ops that implement
Expand Down Expand Up @@ -1023,9 +1040,8 @@ class TrackingListener : public RewriterBase::Listener,
/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;

/// Handles for which this function evaluates to "true" do not have to be
/// updated. These are typically dead or consumed handles.
SkipHandleFn skipHandleFn;
/// Tracking listener configuration.
TrackingListenerConfig config;
};

/// A specialized listener that keeps track of cases in which no replacement
Expand Down
19 changes: 15 additions & 4 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,30 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects`
attributes specify the conversion target.

This transform consumes the `target` handle and modifies the payload. It
does not produce any handles.
This transform modifies the payload. By default, it consumes the `target`
handle. It does not produce any handles.

If the `preserve_handles` attribute is set, this transform does not consume
the `target` handle and instead updates handles based on notifications from
a tracking listener that is attached to the dialect conversion, similar to
`transform.apply_patterns`. Only replacements via `RewriterBase::replaceOp`
or `replaceOpWithNewOp` are considered "payload op replacements". In
contrast to `transform.apply_patterns`, we allow replacement ops even if the
op name has changed. This is because conversion patterns are expected to
lower ops to different ops (from a different dialect). More details can be
found at the documentation site of `TrackingListener`.

This transform produces a silenceable failure if the dialect conversion was
unsuccessful.
unsuccessful or the tracking listener failed to find a replacement op.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
OptionalAttr<StrArrayAttr>:$legal_ops,
OptionalAttr<StrArrayAttr>:$illegal_ops,
OptionalAttr<StrArrayAttr>:$legal_dialects,
OptionalAttr<StrArrayAttr>:$illegal_dialects,
UnitAttr:$partial_conversion);
UnitAttr:$partial_conversion,
UnitAttr:$preserve_handles);
let results = (outs);
let regions = (region
MaxSizedRegion<1>:$patterns,
Expand Down
30 changes: 25 additions & 5 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationErased(Operation *op) {}

/// Notify the listener that the pattern failed to match the given
/// operation, and provide a callback to populate a diagnostic with the
/// reason why the failure occurred. This method allows for derived
/// listeners to optionally hook into the reason why a rewrite failed, and
/// display it to users.
/// Notify the listener that the specified pattern is about to be applied
/// at the specified root operation.
virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}

/// Notify the listener that a pattern application finished with the
/// specified status. "success" indicates that the pattern was applied
/// successfully. "failure" indicates that the pattern could not be
/// applied. The pattern may have communicated the reason for the failure
/// with `notifyMatchFailure`.
virtual void notifyPatternEnd(const Pattern &pattern,
LogicalResult status) {}

/// Notify the listener that the pattern failed to match, and provide a
/// callback to populate a diagnostic with the reason why the failure
/// occurred. This method allows for derived listeners to optionally hook
/// into the reason why a rewrite failed, and display it to users.
virtual void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) {}
Expand Down Expand Up @@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationErased(op);
}
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyPatternBegin(pattern, op);
}
void notifyPatternEnd(const Pattern &pattern,
LogicalResult status) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyPatternEnd(pattern, status);
}
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
Expand Down
39 changes: 21 additions & 18 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}

// Prepare rewriter and listener.
TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
TrackingListenerConfig config;
config.skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
Expand All @@ -935,7 +936,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
return true;
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
skipHandleFn);
config);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);

Expand Down Expand Up @@ -1184,9 +1185,8 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {

transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op,
SkipHandleFn skipHandleFn)
: TransformState::Extension(state), transformOp(op),
skipHandleFn(skipHandleFn) {
TrackingListenerConfig config)
: TransformState::Extension(state), transformOp(op), config(config) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
Expand Down Expand Up @@ -1228,8 +1228,19 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
return diag;
}

// If the defining op has the same type, we take it as a replacement.
if (op->getName() == defOp->getName()) {
// Skip through ops that implement CastOpInterface.
if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
values.clear();
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
diag.attachNote(defOp->getLoc())
<< "using output of 'CastOpInterface' op";
continue;
}

// If the defining op has the same name or we do not care about the name of
// op replacements at all, we take it as a replacement.
if (!config.requireMatchingReplacementOpName ||
op->getName() == defOp->getName()) {
result = defOp;
return DiagnosedSilenceableFailure::success();
}
Expand All @@ -1251,14 +1262,6 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
"'FindPayloadReplacementOpInterface'";
continue;
}

// Skip through ops that implement CastOpInterface.
if (isa<CastOpInterface>(defOp)) {
values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
diag.attachNote(defOp->getLoc())
<< "using output of 'CastOpInterface' op";
continue;
}
} while (!values.empty());

diag.attachNote() << "ran out of suitable replacement values";
Expand Down Expand Up @@ -1318,9 +1321,9 @@ void transform::TrackingListener::notifyOperationReplaced(

// Check if there are any handles that must be updated.
Value aliveHandle;
if (skipHandleFn) {
auto it =
llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
if (config.skipHandleFn) {
auto it = llvm::find_if(opHandles,
[&](Value v) { return !config.skipHandleFn(v); });
if (it != opHandles.end())
aliveHandle = *it;
} else if (!opHandles.empty()) {
Expand Down
45 changes: 40 additions & 5 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,17 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
}
}

// Attach a tracking listener if handles should be preserved. We configure the
// listener to allow op replacements with different names, as conversion
// patterns typically replace ops with replacement ops that have a different
// name.
TrackingListenerConfig trackingConfig;
trackingConfig.requireMatchingReplacementOpName = false;
ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
ConversionConfig conversionConfig;
if (getPreserveHandles())
conversionConfig.listener = &trackingListener;

FrozenRewritePatternSet frozenPatterns(std::move(patterns));
for (Operation *target : state.getPayloadOps(getTarget())) {
// Make sure that this transform is not applied to itself. Modifying the
Expand All @@ -574,16 +585,36 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(

LogicalResult status = failure();
if (getPartialConversion()) {
status = applyPartialConversion(target, conversionTarget, frozenPatterns);
status = applyPartialConversion(target, conversionTarget, frozenPatterns,
conversionConfig);
} else {
status = applyFullConversion(target, conversionTarget, frozenPatterns);
status = applyFullConversion(target, conversionTarget, frozenPatterns,
conversionConfig);
}

// Check dialect conversion state.
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
if (failed(status)) {
auto diag = emitSilenceableError() << "dialect conversion failed";
diag = emitSilenceableError() << "dialect conversion failed";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}

// Check tracking listener error state.
DiagnosedSilenceableFailure trackingFailure =
trackingListener.checkAndResetError();
if (!trackingFailure.succeeded()) {
if (diag.succeeded()) {
// Tracking failure is the only failure.
return trackingFailure;
} else {
diag.attachNote() << "tracking listener also failed: "
<< trackingFailure.getMessage();
(void)trackingFailure.silence();
}
}

if (!diag.succeeded())
return diag;
}

return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -632,7 +663,11 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {

void transform::ApplyConversionPatternsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
if (!getPreserveHandles()) {
transform::consumesHandle(getTarget(), effects);
} else {
transform::onlyReadsHandle(getTarget(), effects);
}
transform::modifiesPayload(effects);
}

Expand Down
29 changes: 21 additions & 8 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,7 +1856,8 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;

OperationLegalizer(const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config);

/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
Expand Down Expand Up @@ -1948,12 +1949,16 @@ class OperationLegalizer {

/// The pattern applicator to use for conversions.
PatternApplicator applicator;

/// Dialect conversion configuration.
const ConversionConfig &config;
};
} // namespace

OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
: target(targetInfo), applicator(patterns) {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config)
: target(targetInfo), applicator(patterns), config(config) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
Expand Down Expand Up @@ -2098,7 +2103,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,

// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
return canApplyPattern(op, pattern, rewriter);
bool canApply = canApplyPattern(op, pattern, rewriter);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
};

// Functor that cleans up the rewriter state after a pattern failed to match.
Expand All @@ -2115,6 +2123,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
rewriterImpl.config.notifyCallback(diag);
}
});
if (config.listener)
config.listener->notifyPatternEnd(pattern, failure());
rewriterImpl.resetState(curState);
appliedPatterns.erase(&pattern);
};
Expand All @@ -2127,6 +2137,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
appliedPatterns.erase(&pattern);
if (failed(result))
rewriterImpl.resetState(curState);
if (config.listener)
config.listener->notifyPatternEnd(pattern, result);
return result;
};

Expand Down Expand Up @@ -2502,7 +2514,8 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
: opLegalizer(target, patterns), config(config), mode(mode) {}
: config(config), opLegalizer(target, patterns, this->config),
mode(mode) {}

/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
Expand Down Expand Up @@ -2539,12 +2552,12 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;

/// Dialect conversion configuration.
ConversionConfig config;

/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;

/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
};
Expand Down
Loading