9
9
#include " mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
10
10
11
11
#include " mlir/IR/BuiltinOps.h"
12
- #include " mlir/IR/PatternMatch.h"
13
12
#include " mlir/Pass/Pass.h"
14
- #include " mlir/Transforms/DialectConversion.h"
15
13
16
14
namespace mlir {
17
15
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -22,113 +20,87 @@ using namespace mlir;
22
20
23
21
namespace {
24
22
25
- // / Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
26
- // / the same as the input ones.
27
- // / For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
28
- // / represent a noop within the IR, and thus the initial input values can be
29
- // / propagated.
30
- // / The same does not hold for 'open' chains of casts, such as
31
- // / `A -> B -> C`. In this last case there is no cycle among the types and thus
32
- // / the conversion is incomplete. The same hold for 'closed' chains like
33
- // / `A -> B -> A`, but with the result of type `B` being used by some non-cast
34
- // / operations.
35
- // / Bifurcations (that is when a chain starts in between of another one) are
36
- // / also taken into considerations, and all the above considerations remain
37
- // / valid.
38
- // / Special corner cases such as dead casts or single casts with same input and
39
- // / output types are also covered.
40
- struct UnrealizedConversionCastPassthrough
41
- : public OpRewritePattern<UnrealizedConversionCastOp> {
42
- using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
43
-
44
- LogicalResult matchAndRewrite (UnrealizedConversionCastOp op,
45
- PatternRewriter &rewriter) const override {
46
- // The nodes that either are not used by any operation or have at least
47
- // one user that is not an unrealized cast.
48
- DenseSet<UnrealizedConversionCastOp> exitNodes;
49
-
50
- // The nodes whose users are all unrealized casts
51
- DenseSet<UnrealizedConversionCastOp> intermediateNodes;
52
-
53
- // Stack used for the depth-first traversal of the use-def DAG.
54
- SmallVector<UnrealizedConversionCastOp, 2 > visitStack;
55
- visitStack.push_back (op);
56
-
57
- while (!visitStack.empty ()) {
58
- UnrealizedConversionCastOp current = visitStack.pop_back_val ();
59
- auto users = current->getUsers ();
60
- bool isLive = false ;
61
-
62
- for (Operation *user : users) {
63
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
64
- if (other.getInputs () != current.getOutputs ())
65
- return rewriter.notifyMatchFailure (
66
- op, " mismatching values propagation" );
67
- } else {
68
- isLive = true ;
69
- }
70
-
71
- // Continue traversing the DAG of unrealized casts
72
- if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
73
- visitStack.push_back (other);
74
- }
75
-
76
- // If the cast is live, then we need to check if the results of the last
77
- // cast have the same type of the root inputs. It this is the case (e.g.
78
- // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
79
- // no-op and the inputs can be forwarded. If it's not (e.g.
80
- // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
81
-
82
- bool isCycle = current.getResultTypes () == op.getInputs ().getTypes ();
83
-
84
- if (isLive && !isCycle)
85
- return rewriter.notifyMatchFailure (op,
86
- " live unrealized conversion cast" );
87
-
88
- bool isExitNode = users.empty () || isLive;
89
-
90
- if (isExitNode) {
91
- exitNodes.insert (current);
92
- } else {
93
- intermediateNodes.insert (current);
94
- }
95
- }
96
-
97
- // Replace the sink nodes with the root input values
98
- for (UnrealizedConversionCastOp exitNode : exitNodes)
99
- rewriter.replaceOp (exitNode, op.getInputs ());
100
-
101
- // Erase all the other casts belonging to the DAG
102
- for (UnrealizedConversionCastOp castOp : intermediateNodes)
103
- rewriter.eraseOp (castOp);
104
-
105
- return success ();
106
- }
107
- };
108
-
109
23
// / Pass to simplify and eliminate unrealized conversion casts.
24
+ // /
25
+ // / This pass processes unrealized_conversion_cast ops in a worklist-driven
26
+ // / fashion. For each matched cast op, if the chain of input casts eventually
27
+ // / reaches a cast op where the input types match the output types of the
28
+ // / matched op, replace the matched op with the inputs.
29
+ // /
30
+ // / Example:
31
+ // / %1 = unrealized_conversion_cast %0 : !A to !B
32
+ // / %2 = unrealized_conversion_cast %1 : !B to !C
33
+ // / %3 = unrealized_conversion_cast %2 : !C to !A
34
+ // /
35
+ // / In the above example, %0 can be used instead of %3 and all cast ops are
36
+ // / folded away.
110
37
struct ReconcileUnrealizedCasts
111
38
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
112
39
ReconcileUnrealizedCasts () = default ;
113
40
114
41
void runOnOperation () override {
115
- RewritePatternSet patterns (&getContext ());
116
- populateReconcileUnrealizedCastsPatterns (patterns);
117
- ConversionTarget target (getContext ());
118
- target.addIllegalOp <UnrealizedConversionCastOp>();
119
- if (failed (applyPartialConversion (getOperation (), target,
120
- std::move (patterns))))
121
- signalPassFailure ();
42
+ // Gather all unrealized_conversion_cast ops.
43
+ SetVector<UnrealizedConversionCastOp> worklist;
44
+ getOperation ()->walk (
45
+ [&](UnrealizedConversionCastOp castOp) { worklist.insert (castOp); });
46
+
47
+ // Helper function that adds all operands to the worklist that are an
48
+ // unrealized_conversion_cast op result.
49
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
50
+ for (Value v : castOp.getInputs ())
51
+ if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
52
+ worklist.insert (inputCastOp);
53
+ };
54
+
55
+ // Helper function that return the unrealized_conversion_cast op that
56
+ // defines all inputs of the given op (in the same order). Return "nullptr"
57
+ // if there is no such op.
58
+ auto getInputCast =
59
+ [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
60
+ if (castOp.getInputs ().empty ())
61
+ return {};
62
+ auto inputCastOp = castOp.getInputs ()
63
+ .front ()
64
+ .getDefiningOp <UnrealizedConversionCastOp>();
65
+ if (!inputCastOp)
66
+ return {};
67
+ if (inputCastOp.getOutputs () != castOp.getInputs ())
68
+ return {};
69
+ return inputCastOp;
70
+ };
71
+
72
+ // Process ops in the worklist bottom-to-top.
73
+ while (!worklist.empty ()) {
74
+ UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
75
+ if (castOp->use_empty ()) {
76
+ // DCE: If the op has no users, erase it. Add the operands to the
77
+ // worklist to find additional DCE opportunities.
78
+ enqueueOperands (castOp);
79
+ castOp->erase ();
80
+ continue ;
81
+ }
82
+
83
+ // Traverse the chain of input cast ops to see if an op with the same
84
+ // input types can be found.
85
+ UnrealizedConversionCastOp nextCast = castOp;
86
+ while (nextCast) {
87
+ if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
88
+ // Found a cast where the input types match the output types of the
89
+ // matched op. We can directly use those inputs and the matched op can
90
+ // be removed.
91
+ enqueueOperands (castOp);
92
+ castOp.replaceAllUsesWith (nextCast.getInputs ());
93
+ castOp->erase ();
94
+ break ;
95
+ }
96
+ nextCast = getInputCast (nextCast);
97
+ }
98
+ }
122
99
}
123
100
};
124
101
125
102
} // namespace
126
103
127
- void mlir::populateReconcileUnrealizedCastsPatterns (
128
- RewritePatternSet &patterns) {
129
- patterns.add <UnrealizedConversionCastPassthrough>(patterns.getContext ());
130
- }
131
-
132
104
std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass () {
133
105
return std::make_unique<ReconcileUnrealizedCasts>();
134
106
}
0 commit comments