Skip to content

Commit a86a580

Browse files
[mlir][CF] Split cf-to-llvm from func-to-llvm
1 parent 0a160d7 commit a86a580

File tree

22 files changed

+260
-141
lines changed

22 files changed

+260
-141
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -460,10 +460,6 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
460460
1 value is returned, packed into an LLVM IR struct type. Function calls and
461461
returns are updated accordingly. Block argument types are updated to use
462462
LLVM IR types.
463-
464-
Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
465-
this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
466-
code due to when they were all converted in the same pass.
467463
}];
468464
let dependentDialects = ["LLVM::LLVMDialect"];
469465
let options = [

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 79 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -94,106 +94,111 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
9494
bool abortOnFailedAssert = true;
9595
};
9696

97-
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
98-
/// to first have updated types which should be handled by a pattern operating
99-
/// on the parent op.
100-
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101-
ValueRange operands,
102-
ValueRange blockArgs, Location loc,
103-
llvm::StringRef messagePrefix) {
104-
for (const auto &idxAndTypes :
105-
llvm::enumerate(llvm::zip(blockArgs, operands))) {
106-
int64_t i = idxAndTypes.index();
107-
Value argValue =
108-
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109-
Type operandType = std::get<1>(idxAndTypes.value()).getType();
110-
// In the case of an invalid jump, the block argument will have been
111-
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
112-
// there might still be a no-op conversion cast with both types being equal.
113-
// Consider both of these details to see if the jump would be invalid.
114-
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115-
argValue.getDefiningOp())) {
116-
if (op.getOperandTypes().front() != operandType) {
117-
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118-
diag << messagePrefix;
119-
diag << "mismatched types from operand # " << i << " ";
120-
diag << operandType;
121-
diag << " not compatible with destination block argument type ";
122-
diag << op.getOperandTypes().front();
123-
diag << " which should be converted with the parent op.";
124-
});
125-
}
126-
}
127-
}
128-
return success();
97+
/// Helper function for converting branch ops. This function converts the
98+
/// signature of the given block. If the new block signature is different from
99+
/// `expectedTypes`, returns "failure".
100+
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
101+
const TypeConverter *converter,
102+
Operation *branchOp, Block *block,
103+
TypeRange expectedTypes) {
104+
assert(converter && "expected non-null type converter");
105+
assert(!block->isEntryBlock() && "entry blocks have no predecessors");
106+
107+
// There is nothing to do if the types already match.
108+
if (block->getArgumentTypes() == expectedTypes)
109+
return block;
110+
111+
// Compute the new block argument types and convert the block.
112+
std::optional<TypeConverter::SignatureConversion> conversion =
113+
converter->convertBlockSignature(block);
114+
if (!conversion)
115+
return rewriter.notifyMatchFailure(branchOp,
116+
"could not compute block signature");
117+
if (expectedTypes != conversion->getConvertedTypes())
118+
return rewriter.notifyMatchFailure(
119+
branchOp,
120+
"mismatch between adaptor operand types and computed block signature");
121+
return rewriter.applySignatureConversion(block, *conversion, converter);
129122
}
130123

131-
/// Ensure that all block types were updated and then create an LLVM::BrOp
124+
/// Convert the destination block signature (if necessary) and lower the branch
125+
/// op to llvm.br.
132126
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
133127
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134128

135129
LogicalResult
136130
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137131
ConversionPatternRewriter &rewriter) const override {
138-
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139-
op.getSuccessor()->getArguments(),
140-
op.getLoc(),
141-
/*messagePrefix=*/"")))
132+
FailureOr<Block *> convertedBlock =
133+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
134+
TypeRange(adaptor.getOperands()));
135+
if (failed(convertedBlock))
142136
return failure();
143-
144-
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
137+
rewriter.replaceOpWithNewOp<LLVM::BrOp>(op, adaptor.getOperands(),
138+
*convertedBlock);
146139
return success();
147140
}
148141
};
149142

150-
/// Ensure that all block types were updated and then create an LLVM::CondBrOp
143+
/// Convert the destination block signatures (if necessary) and lower the
144+
/// branch op to llvm.cond_br.
151145
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
152146
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153147

154148
LogicalResult
155149
matchAndRewrite(cf::CondBranchOp op,
156150
typename cf::CondBranchOp::Adaptor adaptor,
157151
ConversionPatternRewriter &rewriter) const override {
158-
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159-
op.getFalseDest()->getArguments(),
160-
op.getLoc(), "in false case branch ")))
152+
FailureOr<Block *> convertedTrueBlock =
153+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
154+
TypeRange(adaptor.getTrueDestOperands()));
155+
if (failed(convertedTrueBlock))
161156
return failure();
162-
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163-
op.getTrueDest()->getArguments(),
164-
op.getLoc(), "in true case branch ")))
157+
FailureOr<Block *> convertedFalseBlock =
158+
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
159+
TypeRange(adaptor.getFalseDestOperands()));
160+
if (failed(convertedFalseBlock))
165161
return failure();
166-
167162
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
163+
op, adaptor.getCondition(), *convertedTrueBlock,
164+
adaptor.getTrueDestOperands(), *convertedFalseBlock,
165+
adaptor.getFalseDestOperands());
169166
return success();
170167
}
171168
};
172169

173-
/// Ensure that all block types were updated and then create an LLVM::SwitchOp
170+
/// Convert the destination block signatures (if necessary) and lower the
171+
/// switch op to llvm.switch.
174172
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
175173
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176174

177175
LogicalResult
178176
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179177
ConversionPatternRewriter &rewriter) const override {
180-
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181-
op.getDefaultDestination()->getArguments(),
182-
op.getLoc(), "in switch default case ")))
178+
// Get or convert default block.
179+
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
180+
rewriter, getTypeConverter(), op, op.getDefaultDestination(),
181+
TypeRange(adaptor.getDefaultOperands()));
182+
if (failed(convertedDefaultBlock))
183183
return failure();
184184

185-
for (const auto &i : llvm::enumerate(
186-
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187-
if (failed(verifyMatchingValues(
188-
rewriter, std::get<0>(i.value()),
189-
std::get<1>(i.value())->getArguments(), op.getLoc(),
190-
"in switch case " + std::to_string(i.index()) + " "))) {
185+
// Get or convert all case blocks.
186+
SmallVector<Block *> caseDestinations;
187+
SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
188+
for (auto it : llvm::enumerate(op.getCaseDestinations())) {
189+
Block *b = it.value();
190+
FailureOr<Block *> convertedBlock =
191+
getConvertedBlock(rewriter, getTypeConverter(), op, b,
192+
TypeRange(caseOperands[it.index()]));
193+
if (failed(convertedBlock))
191194
return failure();
192-
}
195+
caseDestinations.push_back(*convertedBlock);
193196
}
194197

195198
rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196-
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
199+
op, adaptor.getFlag(), *convertedDefaultBlock,
200+
adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
201+
caseDestinations, caseOperands);
197202
return success();
198203
}
199204
};
@@ -230,14 +235,22 @@ struct ConvertControlFlowToLLVM
230235

231236
/// Run the dialect converter on the module.
232237
void runOnOperation() override {
233-
LLVMConversionTarget target(getContext());
234-
RewritePatternSet patterns(&getContext());
235-
236-
LowerToLLVMOptions options(&getContext());
238+
MLIRContext *ctx = &getContext();
239+
LLVMConversionTarget target(*ctx);
240+
// This pass lowers only CF dialect ops, but it also modifies block
241+
// signatures inside other ops. These ops should be treated as legal. They
242+
// are lowered by other passes.
243+
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
244+
return op->getDialect() !=
245+
ctx->getLoadedDialect<cf::ControlFlowDialect>();
246+
});
247+
248+
LowerToLLVMOptions options(ctx);
237249
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238250
options.overrideIndexBitwidth(indexBitwidth);
239251

240-
LLVMTypeConverter converter(&getContext(), options);
252+
LLVMTypeConverter converter(ctx, options);
253+
RewritePatternSet patterns(ctx);
241254
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
242255

243256
if (failed(applyPartialConversion(getOperation(), target,

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
432432

433433
rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
434434
newFuncOp.end());
435-
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
436-
&result))) {
437-
return rewriter.notifyMatchFailure(funcOp,
438-
"region types conversion failed");
439-
}
435+
// Convert just the entry block. The remaining unstructured control flow is
436+
// converted by ControlFlowToLLVM.
437+
if (!newFuncOp.getBody().empty())
438+
rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
439+
&converter);
440440

441441
// Fix the type mismatch between the materialized `llvm.ptr` and the expected
442442
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
@@ -785,10 +785,6 @@ struct ConvertFuncToLLVMPass
785785
RewritePatternSet patterns(&getContext());
786786
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
787787

788-
// TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
789-
// favor of their dedicated conversion passes.
790-
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
791-
792788
LLVMConversionTarget target(getContext());
793789
if (failed(applyPartialConversion(m, target, std::move(patterns))))
794790
signalPassFailure();
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// RUN: mlir-opt %s -convert-cf-to-llvm -split-input-file | FileCheck %s
2+
3+
// Unstructured control flow is converted, but the enclosing op is not
4+
// converted.
5+
6+
// CHECK-LABEL: func.func @cf_br(
7+
// CHECK-SAME: %[[arg0:.*]]: index) -> index {
8+
// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : index to i64
9+
// CHECK: llvm.br ^[[bb1:.*]](%[[cast0]] : i64)
10+
// CHECK: ^[[bb1]](%[[arg1:.*]]: i64):
11+
// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[arg1]] : i64 to index
12+
// CHECK: return %[[cast1]] : index
13+
// CHECK: }
14+
func.func @cf_br(%arg0: index) -> index {
15+
cf.br ^bb1(%arg0 : index)
16+
^bb1(%arg1: index):
17+
return %arg1 : index
18+
}
19+
20+
// -----
21+
22+
// func.func and func.return types match. No unrealized_conversion_cast is
23+
// needed.
24+
25+
// CHECK-LABEL: func.func @cf_br_type_match(
26+
// CHECK-SAME: %[[arg0:.*]]: i64) -> i64 {
27+
// CHECK: llvm.br ^[[bb1:.*]](%[[arg0:.*]] : i64)
28+
// CHECK: ^[[bb1]](%[[arg1:.*]]: i64):
29+
// CHECK: return %[[arg1]] : i64
30+
// CHECK: }
31+
func.func @cf_br_type_match(%arg0: i64) -> i64 {
32+
cf.br ^bb1(%arg0 : i64)
33+
^bb1(%arg1: i64):
34+
return %arg1 : i64
35+
}
36+
37+
// -----
38+
39+
// Test case for cf.cond_br.
40+
41+
// CHECK-LABEL: func.func @cf_cond_br
42+
// CHECK-COUNT-2: unrealized_conversion_cast {{.*}} : index to i64
43+
// CHECK: llvm.cond_br %{{.*}}, ^{{.*}}(%{{.*}} : i64), ^{{.*}}(%{{.*}} : i64)
44+
// CHECK: ^{{.*}}(%{{.*}}: i64):
45+
// CHECK: unrealized_conversion_cast {{.*}} : i64 to index
46+
// CHECK: ^{{.*}}(%{{.*}}: i64):
47+
// CHECK: unrealized_conversion_cast {{.*}} : i64 to index
48+
func.func @cf_cond_br(%cond: i1, %a: index, %b: index) -> index {
49+
cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index)
50+
^bb1(%arg1: index):
51+
return %arg1 : index
52+
^bb2(%arg2: index):
53+
return %arg2 : index
54+
}
55+
56+
// -----
57+
58+
// Unreachable block (and IR in general) is not converted during a dialect
59+
// conversion.
60+
61+
// CHECK-LABEL: func.func @unreachable_block()
62+
// CHECK: return
63+
// CHECK: ^[[bb1:.*]](%[[arg0:.*]]: index):
64+
// CHECK: cf.br ^[[bb1]](%[[arg0]] : index)
65+
func.func @unreachable_block() {
66+
return
67+
^bb1(%arg0: index):
68+
cf.br ^bb1(%arg0 : index)
69+
}

mlir/test/Conversion/ControlFlowToLLVM/invalid.mlir

Lines changed: 0 additions & 42 deletions
This file was deleted.

0 commit comments

Comments
 (0)