@@ -94,106 +94,111 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
94
94
bool abortOnFailedAssert = true ;
95
95
};
96
96
97
- // / The cf->LLVM lowerings for branching ops require that the blocks they jump
98
- // / to first have updated types which should be handled by a pattern operating
99
- // / on the parent op.
100
- static LogicalResult verifyMatchingValues (ConversionPatternRewriter &rewriter,
101
- ValueRange operands,
102
- ValueRange blockArgs, Location loc,
103
- llvm::StringRef messagePrefix) {
104
- for (const auto &idxAndTypes :
105
- llvm::enumerate (llvm::zip (blockArgs, operands))) {
106
- int64_t i = idxAndTypes.index ();
107
- Value argValue =
108
- rewriter.getRemappedValue (std::get<0 >(idxAndTypes.value ()));
109
- Type operandType = std::get<1 >(idxAndTypes.value ()).getType ();
110
- // In the case of an invalid jump, the block argument will have been
111
- // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112
- // there might still be a no-op conversion cast with both types being equal.
113
- // Consider both of these details to see if the jump would be invalid.
114
- if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115
- argValue.getDefiningOp ())) {
116
- if (op.getOperandTypes ().front () != operandType) {
117
- return rewriter.notifyMatchFailure (loc, [&](Diagnostic &diag) {
118
- diag << messagePrefix;
119
- diag << " mismatched types from operand # " << i << " " ;
120
- diag << operandType;
121
- diag << " not compatible with destination block argument type " ;
122
- diag << op.getOperandTypes ().front ();
123
- diag << " which should be converted with the parent op." ;
124
- });
125
- }
126
- }
127
- }
128
- return success ();
97
+ // / Helper function for converting branch ops. This function converts the
98
+ // / signature of the given block. If the new block signature is different from
99
+ // / `expectedTypes`, returns "failure".
100
+ static FailureOr<Block *> getConvertedBlock (ConversionPatternRewriter &rewriter,
101
+ const TypeConverter *converter,
102
+ Operation *branchOp, Block *block,
103
+ TypeRange expectedTypes) {
104
+ assert (converter && " expected non-null type converter" );
105
+ assert (!block->isEntryBlock () && " entry blocks have no predecessors" );
106
+
107
+ // There is nothing to do if the types already match.
108
+ if (block->getArgumentTypes () == expectedTypes)
109
+ return block;
110
+
111
+ // Compute the new block argument types and convert the block.
112
+ std::optional<TypeConverter::SignatureConversion> conversion =
113
+ converter->convertBlockSignature (block);
114
+ if (!conversion)
115
+ return rewriter.notifyMatchFailure (branchOp,
116
+ " could not compute block signature" );
117
+ if (expectedTypes != conversion->getConvertedTypes ())
118
+ return rewriter.notifyMatchFailure (
119
+ branchOp,
120
+ " mismatch between adaptor operand types and computed block signature" );
121
+ return rewriter.applySignatureConversion (block, *conversion, converter);
129
122
}
130
123
131
- // / Ensure that all block types were updated and then create an LLVM::BrOp
124
+ // / Convert the destination block signature (if necessary) and lower the branch
125
+ // / op to llvm.br.
132
126
struct BranchOpLowering : public ConvertOpToLLVMPattern <cf::BranchOp> {
133
127
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134
128
135
129
LogicalResult
136
130
matchAndRewrite (cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137
131
ConversionPatternRewriter &rewriter) const override {
138
- if ( failed ( verifyMatchingValues (rewriter, adaptor. getDestOperands (),
139
- op.getSuccessor ()-> getArguments (),
140
- op. getLoc (),
141
- /* messagePrefix= */ " " ) ))
132
+ FailureOr<Block *> convertedBlock =
133
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getSuccessor (),
134
+ TypeRange (adaptor. getOperands ()));
135
+ if ( failed (convertedBlock ))
142
136
return failure ();
143
-
144
- rewriter.replaceOpWithNewOp <LLVM::BrOp>(
145
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
137
+ rewriter.replaceOpWithNewOp <LLVM::BrOp>(op, adaptor.getOperands (),
138
+ *convertedBlock);
146
139
return success ();
147
140
}
148
141
};
149
142
150
- // / Ensure that all block types were updated and then create an LLVM::CondBrOp
143
+ // / Convert the destination block signatures (if necessary) and lower the
144
+ // / branch op to llvm.cond_br.
151
145
struct CondBranchOpLowering : public ConvertOpToLLVMPattern <cf::CondBranchOp> {
152
146
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153
147
154
148
LogicalResult
155
149
matchAndRewrite (cf::CondBranchOp op,
156
150
typename cf::CondBranchOp::Adaptor adaptor,
157
151
ConversionPatternRewriter &rewriter) const override {
158
- if (failed (verifyMatchingValues (rewriter, adaptor.getFalseDestOperands (),
159
- op.getFalseDest ()->getArguments (),
160
- op.getLoc (), " in false case branch " )))
152
+ FailureOr<Block *> convertedTrueBlock =
153
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getTrueDest (),
154
+ TypeRange (adaptor.getTrueDestOperands ()));
155
+ if (failed (convertedTrueBlock))
161
156
return failure ();
162
- if (failed (verifyMatchingValues (rewriter, adaptor.getTrueDestOperands (),
163
- op.getTrueDest ()->getArguments (),
164
- op.getLoc (), " in true case branch " )))
157
+ FailureOr<Block *> convertedFalseBlock =
158
+ getConvertedBlock (rewriter, getTypeConverter (), op, op.getFalseDest (),
159
+ TypeRange (adaptor.getFalseDestOperands ()));
160
+ if (failed (convertedFalseBlock))
165
161
return failure ();
166
-
167
162
rewriter.replaceOpWithNewOp <LLVM::CondBrOp>(
168
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
163
+ op, adaptor.getCondition (), *convertedTrueBlock,
164
+ adaptor.getTrueDestOperands (), *convertedFalseBlock,
165
+ adaptor.getFalseDestOperands ());
169
166
return success ();
170
167
}
171
168
};
172
169
173
- // / Ensure that all block types were updated and then create an LLVM::SwitchOp
170
+ // / Convert the destination block signatures (if necessary) and lower the
171
+ // / switch op to llvm.switch.
174
172
struct SwitchOpLowering : public ConvertOpToLLVMPattern <cf::SwitchOp> {
175
173
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176
174
177
175
LogicalResult
178
176
matchAndRewrite (cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179
177
ConversionPatternRewriter &rewriter) const override {
180
- if (failed (verifyMatchingValues (rewriter, adaptor.getDefaultOperands (),
181
- op.getDefaultDestination ()->getArguments (),
182
- op.getLoc (), " in switch default case " )))
178
+ // Get or convert default block.
179
+ FailureOr<Block *> convertedDefaultBlock = getConvertedBlock (
180
+ rewriter, getTypeConverter (), op, op.getDefaultDestination (),
181
+ TypeRange (adaptor.getDefaultOperands ()));
182
+ if (failed (convertedDefaultBlock))
183
183
return failure ();
184
184
185
- for (const auto &i : llvm::enumerate (
186
- llvm::zip (adaptor.getCaseOperands (), op.getCaseDestinations ()))) {
187
- if (failed (verifyMatchingValues (
188
- rewriter, std::get<0 >(i.value ()),
189
- std::get<1 >(i.value ())->getArguments (), op.getLoc (),
190
- " in switch case " + std::to_string (i.index ()) + " " ))) {
185
+ // Get or convert all case blocks.
186
+ SmallVector<Block *> caseDestinations;
187
+ SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands ();
188
+ for (auto it : llvm::enumerate (op.getCaseDestinations ())) {
189
+ Block *b = it.value ();
190
+ FailureOr<Block *> convertedBlock =
191
+ getConvertedBlock (rewriter, getTypeConverter (), op, b,
192
+ TypeRange (caseOperands[it.index ()]));
193
+ if (failed (convertedBlock))
191
194
return failure ();
192
- }
195
+ caseDestinations. push_back (*convertedBlock);
193
196
}
194
197
195
198
rewriter.replaceOpWithNewOp <LLVM::SwitchOp>(
196
- op, adaptor.getOperands (), op->getSuccessors (), op->getAttrs ());
199
+ op, adaptor.getFlag (), *convertedDefaultBlock,
200
+ adaptor.getDefaultOperands (), adaptor.getCaseValuesAttr (),
201
+ caseDestinations, caseOperands);
197
202
return success ();
198
203
}
199
204
};
@@ -230,14 +235,22 @@ struct ConvertControlFlowToLLVM
230
235
231
236
// / Run the dialect converter on the module.
232
237
void runOnOperation () override {
233
- LLVMConversionTarget target (getContext ());
234
- RewritePatternSet patterns (&getContext ());
235
-
236
- LowerToLLVMOptions options (&getContext ());
238
+ MLIRContext *ctx = &getContext ();
239
+ LLVMConversionTarget target (*ctx);
240
+ // This pass lowers only CF dialect ops, but it also modifies block
241
+ // signatures inside other ops. These ops should be treated as legal. They
242
+ // are lowered by other passes.
243
+ target.markUnknownOpDynamicallyLegal ([&](Operation *op) {
244
+ return op->getDialect () !=
245
+ ctx->getLoadedDialect <cf::ControlFlowDialect>();
246
+ });
247
+
248
+ LowerToLLVMOptions options (ctx);
237
249
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout )
238
250
options.overrideIndexBitwidth (indexBitwidth);
239
251
240
- LLVMTypeConverter converter (&getContext (), options);
252
+ LLVMTypeConverter converter (ctx, options);
253
+ RewritePatternSet patterns (ctx);
241
254
mlir::cf::populateControlFlowToLLVMConversionPatterns (converter, patterns);
242
255
243
256
if (failed (applyPartialConversion (getOperation (), target,
0 commit comments