Skip to content

Commit 8703175

Browse files
[mlir][Transforms][NFC] Move ReconcileUnrealizedCasts implementation
Move the implementation of `ReconcileUnrealizedCasts` to `DialectConversion.cpp`, so that it can be called from there in a future commit. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion framework. The existing logic around unresolved materializations that predicts IR changes to decide if a cast op can be folded/erased will become obsolete, as `ReconcileUnrealizedCasts` will perform these kind of foldings on fully materialized IR.
1 parent aad27bf commit 8703175

File tree

3 files changed

+100
-55
lines changed

3 files changed

+100
-55
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,29 @@ struct ConversionConfig {
11261126
RewriterBase::Listener *listener = nullptr;
11271127
};
11281128

1129+
//===----------------------------------------------------------------------===//
1130+
// Reconcile Unrealized Casts
1131+
//===----------------------------------------------------------------------===//
1132+
1133+
/// Try to reconcile all given UnrealizedConversionCastOps and store the
1134+
/// left-over ops in `remainingCastOps` (if provided).
1135+
///
1136+
/// This function processes cast ops in a worklist-driven fashion. For each
1137+
/// cast op, if the chain of input casts eventually reaches a cast op where the
1138+
/// input types match the output types of the matched op, replace the matched
1139+
/// op with the inputs.
1140+
///
1141+
/// Example:
1142+
/// %1 = unrealized_conversion_cast %0 : !A to !B
1143+
/// %2 = unrealized_conversion_cast %1 : !B to !C
1144+
/// %3 = unrealized_conversion_cast %2 : !C to !A
1145+
///
1146+
/// In the above example, %0 can be used instead of %3 and all cast ops are
1147+
/// folded away.
1148+
void reconcileUnrealizedCasts(
1149+
ArrayRef<UnrealizedConversionCastOp> castOps,
1150+
SmallVector<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1151+
11291152
//===----------------------------------------------------------------------===//
11301153
// Op Conversion Entry Points
11311154
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/IR/BuiltinOps.h"
1212
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Transforms/DialectConversion.h"
1314

1415
namespace mlir {
1516
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -39,63 +40,10 @@ struct ReconcileUnrealizedCasts
3940
ReconcileUnrealizedCasts() = default;
4041

4142
void runOnOperation() override {
42-
// Gather all unrealized_conversion_cast ops.
43-
SetVector<UnrealizedConversionCastOp> worklist;
43+
SmallVector<UnrealizedConversionCastOp> ops;
4444
getOperation()->walk(
4545
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
46-
47-
// Helper function that adds all operands to the worklist that are an
48-
// unrealized_conversion_cast op result.
49-
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
50-
for (Value v : castOp.getInputs())
51-
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
52-
worklist.insert(inputCastOp);
53-
};
54-
55-
// Helper function that return the unrealized_conversion_cast op that
56-
// defines all inputs of the given op (in the same order). Return "nullptr"
57-
// if there is no such op.
58-
auto getInputCast =
59-
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
60-
if (castOp.getInputs().empty())
61-
return {};
62-
auto inputCastOp = castOp.getInputs()
63-
.front()
64-
.getDefiningOp<UnrealizedConversionCastOp>();
65-
if (!inputCastOp)
66-
return {};
67-
if (inputCastOp.getOutputs() != castOp.getInputs())
68-
return {};
69-
return inputCastOp;
70-
};
71-
72-
// Process ops in the worklist bottom-to-top.
73-
while (!worklist.empty()) {
74-
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
75-
if (castOp->use_empty()) {
76-
// DCE: If the op has no users, erase it. Add the operands to the
77-
// worklist to find additional DCE opportunities.
78-
enqueueOperands(castOp);
79-
castOp->erase();
80-
continue;
81-
}
82-
83-
// Traverse the chain of input cast ops to see if an op with the same
84-
// input types can be found.
85-
UnrealizedConversionCastOp nextCast = castOp;
86-
while (nextCast) {
87-
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
88-
// Found a cast where the input types match the output types of the
89-
// matched op. We can directly use those inputs and the matched op can
90-
// be removed.
91-
enqueueOperands(castOp);
92-
castOp.replaceAllUsesWith(nextCast.getInputs());
93-
castOp->erase();
94-
break;
95-
}
96-
nextCast = getInputCast(nextCast);
97-
}
98-
}
46+
reconcileUnrealizedCasts(ops);
9947
}
10048
};
10149

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,6 +2869,80 @@ LogicalResult OperationConverter::legalizeErasedResult(
28692869
return success();
28702870
}
28712871

2872+
//===----------------------------------------------------------------------===//
2873+
// Reconcile Unrealized Casts
2874+
//===----------------------------------------------------------------------===//
2875+
2876+
void mlir::reconcileUnrealizedCasts(
2877+
ArrayRef<UnrealizedConversionCastOp> castOps,
2878+
SmallVector<UnrealizedConversionCastOp> *remainingCastOps) {
2879+
SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(),
2880+
castOps.end());
2881+
// This set is maintained only if `remainingCastOps` is provided.
2882+
DenseSet<Operation *> erasedOps;
2883+
2884+
// Helper function that adds all operands to the worklist that are an
2885+
// unrealized_conversion_cast op result.
2886+
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
2887+
for (Value v : castOp.getInputs())
2888+
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
2889+
worklist.insert(inputCastOp);
2890+
};
2891+
2892+
// Helper function that return the unrealized_conversion_cast op that
2893+
// defines all inputs of the given op (in the same order). Return "nullptr"
2894+
// if there is no such op.
2895+
auto getInputCast =
2896+
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
2897+
if (castOp.getInputs().empty())
2898+
return {};
2899+
auto inputCastOp =
2900+
castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
2901+
if (!inputCastOp)
2902+
return {};
2903+
if (inputCastOp.getOutputs() != castOp.getInputs())
2904+
return {};
2905+
return inputCastOp;
2906+
};
2907+
2908+
// Process ops in the worklist bottom-to-top.
2909+
while (!worklist.empty()) {
2910+
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
2911+
if (castOp->use_empty()) {
2912+
// DCE: If the op has no users, erase it. Add the operands to the
2913+
// worklist to find additional DCE opportunities.
2914+
enqueueOperands(castOp);
2915+
if (remainingCastOps)
2916+
erasedOps.insert(castOp.getOperation());
2917+
castOp->erase();
2918+
continue;
2919+
}
2920+
2921+
// Traverse the chain of input cast ops to see if an op with the same
2922+
// input types can be found.
2923+
UnrealizedConversionCastOp nextCast = castOp;
2924+
while (nextCast) {
2925+
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
2926+
// Found a cast where the input types match the output types of the
2927+
// matched op. We can directly use those inputs and the matched op can
2928+
// be removed.
2929+
enqueueOperands(castOp);
2930+
castOp.replaceAllUsesWith(nextCast.getInputs());
2931+
if (remainingCastOps)
2932+
erasedOps.insert(castOp.getOperation());
2933+
castOp->erase();
2934+
break;
2935+
}
2936+
nextCast = getInputCast(nextCast);
2937+
}
2938+
}
2939+
2940+
if (remainingCastOps)
2941+
for (UnrealizedConversionCastOp op : castOps)
2942+
if (!erasedOps.contains(op.getOperation()))
2943+
remainingCastOps->push_back(op);
2944+
}
2945+
28722946
//===----------------------------------------------------------------------===//
28732947
// Type Conversion
28742948
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)