Skip to content

Commit 5b00758

Browse files
matthias-springermaksleventalzero9178
authored
[mlir][Conversion] Generalize and fix crash in reconcile-unrealized-casts (llvm#95700)
This commit fixes a crash in `-reconcile-unrealized-casts` when cast ops have multiple operands: ``` DialectConversion.cpp:1583: virtual void mlir::ConversionPatternRewriter::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed. ``` This commit also generalizes the pass such that more ops are folded. In particular (letters indicate types): ``` A / \ B C | A ``` Previously, such IR was not folded at all. The `A -> B -> A` type cast cycle is now folded away. (The `A -> C` cast stays in place.) This commit also turns the pass from a dialect conversion into a simple IR walk. The pattern and its `populate` function are removed. The pattern was a (non-conversion) rewrite pattern, but used in a dialect conversion, which is generally not safe. In particular, the rewrite pattern may traverse IR that was already scheduled for erasure by the dialect conversion. Note: Some test cases changed slightly (NFC) because the new pass implementation no longer attempts to fold ops. Note for LLVM integration: If your pipeline uses the removed `populate` function, try to simply remove that function call. Chances are you may not need it at all. If it is in fact needed, run the `-reconcile-unrealized-casts` pass right after the pass that used to populate the pattern. --------- Co-authored-by: Maksim Levental <[email protected]> Co-authored-by: Markus Böck <[email protected]>
1 parent 2c1ae80 commit 5b00758

File tree

7 files changed

+180
-159
lines changed

7 files changed

+180
-159
lines changed

mlir/include/mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ class RewritePatternSet;
2121
/// Creates a pass that eliminates noop `unrealized_conversion_cast` operation
2222
/// sequences.
2323
std::unique_ptr<Pass> createReconcileUnrealizedCastsPass();
24-
25-
/// Populates `patterns` with rewrite patterns that eliminate noop
26-
/// `unrealized_conversion_cast` operation sequences.
27-
void populateReconcileUnrealizedCastsPatterns(RewritePatternSet &patterns);
2824
} // namespace mlir
2925

3026
#endif // MLIR_CONVERSION_RECONCILEUNREALIZEDCASTS_RECONCILEUNREALIZEDCASTS_H_

mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp

Lines changed: 70 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
1010

1111
#include "mlir/IR/BuiltinOps.h"
12-
#include "mlir/IR/PatternMatch.h"
1312
#include "mlir/Pass/Pass.h"
14-
#include "mlir/Transforms/DialectConversion.h"
1513

1614
namespace mlir {
1715
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
@@ -22,113 +20,87 @@ using namespace mlir;
2220

2321
namespace {
2422

25-
/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
26-
/// the same as the input ones.
27-
/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
28-
/// represent a noop within the IR, and thus the initial input values can be
29-
/// propagated.
30-
/// The same does not hold for 'open' chains of casts, such as
31-
/// `A -> B -> C`. In this last case there is no cycle among the types and thus
32-
/// the conversion is incomplete. The same hold for 'closed' chains like
33-
/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
34-
/// operations.
35-
/// Bifurcations (that is when a chain starts in between of another one) are
36-
/// also taken into considerations, and all the above considerations remain
37-
/// valid.
38-
/// Special corner cases such as dead casts or single casts with same input and
39-
/// output types are also covered.
40-
struct UnrealizedConversionCastPassthrough
41-
: public OpRewritePattern<UnrealizedConversionCastOp> {
42-
using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
43-
44-
LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
45-
PatternRewriter &rewriter) const override {
46-
// The nodes that either are not used by any operation or have at least
47-
// one user that is not an unrealized cast.
48-
DenseSet<UnrealizedConversionCastOp> exitNodes;
49-
50-
// The nodes whose users are all unrealized casts
51-
DenseSet<UnrealizedConversionCastOp> intermediateNodes;
52-
53-
// Stack used for the depth-first traversal of the use-def DAG.
54-
SmallVector<UnrealizedConversionCastOp, 2> visitStack;
55-
visitStack.push_back(op);
56-
57-
while (!visitStack.empty()) {
58-
UnrealizedConversionCastOp current = visitStack.pop_back_val();
59-
auto users = current->getUsers();
60-
bool isLive = false;
61-
62-
for (Operation *user : users) {
63-
if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
64-
if (other.getInputs() != current.getOutputs())
65-
return rewriter.notifyMatchFailure(
66-
op, "mismatching values propagation");
67-
} else {
68-
isLive = true;
69-
}
70-
71-
// Continue traversing the DAG of unrealized casts
72-
if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
73-
visitStack.push_back(other);
74-
}
75-
76-
// If the cast is live, then we need to check if the results of the last
77-
// cast have the same type of the root inputs. It this is the case (e.g.
78-
// `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
79-
// no-op and the inputs can be forwarded. If it's not (e.g.
80-
// `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
81-
82-
bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
83-
84-
if (isLive && !isCycle)
85-
return rewriter.notifyMatchFailure(op,
86-
"live unrealized conversion cast");
87-
88-
bool isExitNode = users.empty() || isLive;
89-
90-
if (isExitNode) {
91-
exitNodes.insert(current);
92-
} else {
93-
intermediateNodes.insert(current);
94-
}
95-
}
96-
97-
// Replace the sink nodes with the root input values
98-
for (UnrealizedConversionCastOp exitNode : exitNodes)
99-
rewriter.replaceOp(exitNode, op.getInputs());
100-
101-
// Erase all the other casts belonging to the DAG
102-
for (UnrealizedConversionCastOp castOp : intermediateNodes)
103-
rewriter.eraseOp(castOp);
104-
105-
return success();
106-
}
107-
};
108-
10923
/// Pass to simplify and eliminate unrealized conversion casts.
24+
///
25+
/// This pass processes unrealized_conversion_cast ops in a worklist-driven
26+
/// fashion. For each matched cast op, if the chain of input casts eventually
27+
/// reaches a cast op where the input types match the output types of the
28+
/// matched op, replace the matched op with the inputs.
29+
///
30+
/// Example:
31+
/// %1 = unrealized_conversion_cast %0 : !A to !B
32+
/// %2 = unrealized_conversion_cast %1 : !B to !C
33+
/// %3 = unrealized_conversion_cast %2 : !C to !A
34+
///
35+
/// In the above example, %0 can be used instead of %3 and all cast ops are
36+
/// folded away.
11037
struct ReconcileUnrealizedCasts
11138
: public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
11239
ReconcileUnrealizedCasts() = default;
11340

11441
void runOnOperation() override {
115-
RewritePatternSet patterns(&getContext());
116-
populateReconcileUnrealizedCastsPatterns(patterns);
117-
ConversionTarget target(getContext());
118-
target.addIllegalOp<UnrealizedConversionCastOp>();
119-
if (failed(applyPartialConversion(getOperation(), target,
120-
std::move(patterns))))
121-
signalPassFailure();
42+
// Gather all unrealized_conversion_cast ops.
43+
SetVector<UnrealizedConversionCastOp> worklist;
44+
getOperation()->walk(
45+
[&](UnrealizedConversionCastOp castOp) { worklist.insert(castOp); });
46+
47+
// Helper function that adds all operands to the worklist that are an
48+
// unrealized_conversion_cast op result.
49+
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
50+
for (Value v : castOp.getInputs())
51+
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
52+
worklist.insert(inputCastOp);
53+
};
54+
55+
// Helper function that return the unrealized_conversion_cast op that
56+
// defines all inputs of the given op (in the same order). Return "nullptr"
57+
// if there is no such op.
58+
auto getInputCast =
59+
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
60+
if (castOp.getInputs().empty())
61+
return {};
62+
auto inputCastOp = castOp.getInputs()
63+
.front()
64+
.getDefiningOp<UnrealizedConversionCastOp>();
65+
if (!inputCastOp)
66+
return {};
67+
if (inputCastOp.getOutputs() != castOp.getInputs())
68+
return {};
69+
return inputCastOp;
70+
};
71+
72+
// Process ops in the worklist bottom-to-top.
73+
while (!worklist.empty()) {
74+
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
75+
if (castOp->use_empty()) {
76+
// DCE: If the op has no users, erase it. Add the operands to the
77+
// worklist to find additional DCE opportunities.
78+
enqueueOperands(castOp);
79+
castOp->erase();
80+
continue;
81+
}
82+
83+
// Traverse the chain of input cast ops to see if an op with the same
84+
// input types can be found.
85+
UnrealizedConversionCastOp nextCast = castOp;
86+
while (nextCast) {
87+
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
88+
// Found a cast where the input types match the output types of the
89+
// matched op. We can directly use those inputs and the matched op can
90+
// be removed.
91+
enqueueOperands(castOp);
92+
castOp.replaceAllUsesWith(nextCast.getInputs());
93+
castOp->erase();
94+
break;
95+
}
96+
nextCast = getInputCast(nextCast);
97+
}
98+
}
12299
}
123100
};
124101

125102
} // namespace
126103

127-
void mlir::populateReconcileUnrealizedCastsPatterns(
128-
RewritePatternSet &patterns) {
129-
patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
130-
}
131-
132104
std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
133105
return std::make_unique<ReconcileUnrealizedCasts>();
134106
}

mlir/test/Conversion/FuncToLLVM/calling-convention.mlir

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func.func @return_var_memref_caller(%arg0: memref<4x3xf32>) {
127127
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
128128
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
129129
// CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
130-
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
130+
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
131131
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
132132
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
133133
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,14 +159,17 @@ func.func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> attributes
159159

160160
// CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
161161
// CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
162-
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
162+
// CHECK: %[[RANK_EXTR:.*]] = llvm.extractvalue %[[DESC_2]][0]
163+
// CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK_EXTR]]
163164
// CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
164165
// CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
165166
// CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
166167
// CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]])
167-
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
168+
// CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
169+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[ALLOCA_EXTRACTED]], %[[ALLOC_SIZE]]) <{isVolatile = false}>
168170
// CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
169-
// CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0]
171+
// CHECK: %[[RANK_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][0]
172+
// CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK_EXTRACTED]], %[[NEW_DESC]][0]
170173
// CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1]
171174
// CHECK: llvm.return %[[NEW_DESC_2]]
172175
return %0 : memref<*xf32>
@@ -218,13 +221,15 @@ func.func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memr
218221
// convention requires the caller to free them and the caller cannot know
219222
// whether they are the same value or not.
220223
// CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}})
221-
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
224+
// CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
225+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
222226
// CHECK: %[[RES_1:.*]] = llvm.mlir.undef
223227
// CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0]
224228
// CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1]
225229

226230
// CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}})
227-
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA]], %{{.*}}) <{isVolatile = false}>
231+
// CHECK: %[[ALLOCA_EXTRACTED:.*]] = llvm.extractvalue %[[DESC_2]][1]
232+
// CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[ALLOCA_EXTRACTED]], %{{.*}}) <{isVolatile = false}>
228233
// CHECK: %[[RES_2:.*]] = llvm.mlir.undef
229234
// CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0]
230235
// CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1]
@@ -265,7 +270,8 @@ func.func @bare_ptr_calling_conv(%arg0: memref<4x3xf32>, %arg1 : index, %arg2 :
265270
// CHECK: llvm.store %{{.*}}, %[[STOREPTR]]
266271
memref.store %arg3, %arg0[%arg1, %arg2] : memref<4x3xf32>
267272

268-
// CHECK: llvm.return %[[ARG0]]
273+
// CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
274+
// CHECK: llvm.return %[[EXTRACT_MEMREF]]
269275
return %arg0 : memref<4x3xf32>
270276
}
271277

@@ -298,9 +304,10 @@ func.func @bare_ptr_calling_conv_multiresult(%arg0: memref<4x3xf32>, %arg1 : ind
298304
// CHECK: %[[RETURN0:.*]] = llvm.load %[[LOADPTR]]
299305
%0 = memref.load %arg0[%arg1, %arg2] : memref<4x3xf32>
300306

307+
// CHECK: %[[EXTRACT_MEMREF:.*]] = llvm.extractvalue %[[INSERT_STRIDE1]][0]
301308
// CHECK: %[[RETURN_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(f32, ptr)>
302309
// CHECK: %[[INSERT_RETURN0:.*]] = llvm.insertvalue %[[RETURN0]], %[[RETURN_DESC]][0]
303-
// CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[ARG0]], %[[INSERT_RETURN0]][1]
310+
// CHECK: %[[INSERT_RETURN1:.*]] = llvm.insertvalue %[[EXTRACT_MEMREF]], %[[INSERT_RETURN0]][1]
304311
// CHECK: llvm.return %[[INSERT_RETURN1]]
305312
return %0, %arg0 : f32, memref<4x3xf32>
306313
}

mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir

Lines changed: 0 additions & 45 deletions
This file was deleted.

0 commit comments

Comments
 (0)