Skip to content

Commit 5030dea

Browse files
[mlir][Transforms] Dialect conversion: No rollback during analysis conversion (#106414)
This commit changes the implementation of analysis conversions, so that no rollback is needed at the end of the analysis. Instead, the dialect conversion is run on a clone of the IR. The purpose of this commit is to reduce the number of rollbacks in the dialect conversion framework. (Long term goal: Remove rollback functionality entirely.)
1 parent 3f9caba commit 5030dea

File tree

1 file changed

+68
-8
lines changed

1 file changed

+68
-8
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2466,13 +2466,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
24662466
// legalized.
24672467
finalize(rewriter);
24682468

2469-
// After a successful conversion, apply rewrites if this is not an analysis
2470-
// conversion.
2471-
if (mode == OpConversionMode::Analysis) {
2472-
rewriterImpl.undoRewrites();
2473-
} else {
2474-
rewriterImpl.applyRewrites();
2475-
}
2469+
// After a successful conversion, apply rewrites.
2470+
rewriterImpl.applyRewrites();
24762471

24772472
// Gather all unresolved materializations.
24782473
SmallVector<UnrealizedConversionCastOp> allCastOps;
@@ -3215,13 +3210,78 @@ LogicalResult mlir::applyFullConversion(Operation *op,
32153210
//===----------------------------------------------------------------------===//
32163211
// Analysis Conversion
32173212

3213+
/// Find a common IsolatedFromAbove ancestor of the given ops. If at least one
3214+
/// op is a top-level module op (which is expected to be isolated from above),
3215+
/// return that op.
3216+
static Operation *findCommonAncestor(ArrayRef<Operation *> ops) {
3217+
// Check if there is a top-level operation within `ops`. If so, return that
3218+
// op.
3219+
for (Operation *op : ops) {
3220+
if (!op->getParentOp()) {
3221+
#ifndef NDEBUG
3222+
assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
3223+
"expected top-level op to be isolated from above");
3224+
for (Operation *other : ops)
3225+
assert(op->isAncestor(other) &&
3226+
"expected ops to have a common ancestor");
3227+
#endif // NDEBUG
3228+
return op;
3229+
}
3230+
}
3231+
3232+
// No top-level op. Find a common ancestor.
3233+
Operation *commonAncestor =
3234+
ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3235+
for (Operation *op : ops.drop_front()) {
3236+
while (!commonAncestor->isProperAncestor(op)) {
3237+
commonAncestor =
3238+
commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
3239+
assert(commonAncestor &&
3240+
"expected to find a common isolated from above ancestor");
3241+
}
3242+
}
3243+
3244+
return commonAncestor;
3245+
}
3246+
32183247
LogicalResult mlir::applyAnalysisConversion(
32193248
ArrayRef<Operation *> ops, ConversionTarget &target,
32203249
const FrozenRewritePatternSet &patterns, ConversionConfig config) {
3250+
#ifndef NDEBUG
3251+
if (config.legalizableOps)
3252+
assert(config.legalizableOps->empty() && "expected empty set");
3253+
#endif // NDEBUG
3254+
3255+
// Clone closted common ancestor that is isolated from above.
3256+
Operation *commonAncestor = findCommonAncestor(ops);
3257+
IRMapping mapping;
3258+
Operation *clonedAncestor = commonAncestor->clone(mapping);
3259+
// Compute inverse IR mapping.
3260+
DenseMap<Operation *, Operation *> inverseOperationMap;
3261+
for (auto &it : mapping.getOperationMap())
3262+
inverseOperationMap[it.second] = it.first;
3263+
3264+
// Convert the cloned operations. The original IR will remain unchanged.
3265+
SmallVector<Operation *> opsToConvert = llvm::map_to_vector(
3266+
ops, [&](Operation *op) { return mapping.lookup(op); });
32213267
OperationConverter opConverter(target, patterns, config,
32223268
OpConversionMode::Analysis);
3223-
return opConverter.convertOperations(ops);
3269+
LogicalResult status = opConverter.convertOperations(opsToConvert);
3270+
3271+
// Remap `legalizableOps`, so that they point to the original ops and not the
3272+
// cloned ops.
3273+
if (config.legalizableOps) {
3274+
DenseSet<Operation *> originalLegalizableOps;
3275+
for (Operation *op : *config.legalizableOps)
3276+
originalLegalizableOps.insert(inverseOperationMap[op]);
3277+
*config.legalizableOps = std::move(originalLegalizableOps);
3278+
}
3279+
3280+
// Erase the cloned IR.
3281+
clonedAncestor->erase();
3282+
return status;
32243283
}
3284+
32253285
LogicalResult
32263286
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
32273287
const FrozenRewritePatternSet &patterns,

0 commit comments

Comments
 (0)