Skip to content

[mlir][CF] Split cf-to-llvm from func-to-llvm #120580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3287,10 +3287,40 @@ struct SelectCaseOpConversion : public fir::FIROpConversion<fir::SelectCaseOp> {
}
};

/// Helper function for converting select ops. This function converts the
/// signature of the given block. If the new block signature is different from
/// `expectedTypes`, returns "failure".
static llvm::FailureOr<mlir::Block *>
getConvertedBlock(mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter,
mlir::Operation *branchOp, mlir::Block *block,
mlir::TypeRange expectedTypes) {
assert(converter && "expected non-null type converter");
assert(!block->isEntryBlock() && "entry blocks have no predecessors");

// There is nothing to do if the types already match.
if (block->getArgumentTypes() == expectedTypes)
return block;

// Compute the new block argument types and convert the block.
std::optional<mlir::TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion)
return rewriter.notifyMatchFailure(branchOp,
"could not compute block signature");
if (expectedTypes != conversion->getConvertedTypes())
return rewriter.notifyMatchFailure(
branchOp,
"mismatch between adaptor operand types and computed block signature");
return rewriter.applySignatureConversion(block, *conversion, converter);
}

template <typename OP>
static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
OP select, typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) {
static llvm::LogicalResult
selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering, OP select,
typename OP::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
unsigned conds = select.getNumConditions();
auto cases = select.getCases().getValue();
mlir::Value selector = adaptor.getSelector();
Expand All @@ -3308,15 +3338,24 @@ static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
auto destOps = select.getSuccessorOperands(adaptor.getOperands(), t);
const mlir::Attribute &attr = cases[t];
if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
destinations.push_back(dest);
destinationsOperands.push_back(destOps ? *destOps : mlir::ValueRange{});
auto convertedBlock =
getConvertedBlock(rewriter, converter, select, dest,
mlir::TypeRange(destinationsOperands.back()));
if (mlir::failed(convertedBlock))
return mlir::failure();
destinations.push_back(*convertedBlock);
caseValues.push_back(intAttr.getInt());
continue;
}
assert(mlir::dyn_cast_or_null<mlir::UnitAttr>(attr));
assert((t + 1 == conds) && "unit must be last");
defaultDestination = dest;
defaultOperands = destOps ? *destOps : mlir::ValueRange{};
auto convertedBlock = getConvertedBlock(rewriter, converter, select, dest,
mlir::TypeRange(defaultOperands));
if (mlir::failed(convertedBlock))
return mlir::failure();
defaultDestination = *convertedBlock;
}

// LLVM::SwitchOp takes a i32 type for the selector.
Expand All @@ -3332,6 +3371,7 @@ static void selectMatchAndRewrite(const fir::LLVMTypeConverter &lowering,
/*caseDestinations=*/destinations,
/*caseOperands=*/destinationsOperands,
/*branchWeights=*/llvm::ArrayRef<std::int32_t>());
return mlir::success();
}

/// conversion of fir::SelectOp to an if-then-else ladder
Expand All @@ -3341,8 +3381,8 @@ struct SelectOpConversion : public fir::FIROpConversion<fir::SelectOp> {
llvm::LogicalResult
matchAndRewrite(fir::SelectOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor, rewriter);
return mlir::success();
return selectMatchAndRewrite<fir::SelectOp>(lowerTy(), op, adaptor,
rewriter, getTypeConverter());
}
};

Expand All @@ -3353,8 +3393,8 @@ struct SelectRankOpConversion : public fir::FIROpConversion<fir::SelectRankOp> {
llvm::LogicalResult
matchAndRewrite(fir::SelectRankOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
selectMatchAndRewrite<fir::SelectRankOp>(lowerTy(), op, adaptor, rewriter);
return mlir::success();
return selectMatchAndRewrite<fir::SelectRankOp>(
lowerTy(), op, adaptor, rewriter, getTypeConverter());
}
};

Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,6 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
1 value is returned, packed into an LLVM IR struct type. Function calls and
returns are updated accordingly. Block argument types are updated to use
LLVM IR types.

Note that until https://github.com/llvm/llvm-project/issues/70982 is resolved,
this pass includes patterns that lower `arith` and `cf` to LLVM. This is legacy
code due to when they were all converted in the same pass.
}];
let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Expand Down
153 changes: 86 additions & 67 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,106 +94,117 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
bool abortOnFailedAssert = true;
};

/// The cf->LLVM lowerings for branching ops require that the blocks they jump
/// to first have updated types which should be handled by a pattern operating
/// on the parent op.
static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
ValueRange operands,
ValueRange blockArgs, Location loc,
llvm::StringRef messagePrefix) {
for (const auto &idxAndTypes :
llvm::enumerate(llvm::zip(blockArgs, operands))) {
int64_t i = idxAndTypes.index();
Value argValue =
rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
Type operandType = std::get<1>(idxAndTypes.value()).getType();
// In the case of an invalid jump, the block argument will have been
// remapped to an UnrealizedConversionCast. In the case of a valid jump,
// there might still be a no-op conversion cast with both types being equal.
// Consider both of these details to see if the jump would be invalid.
if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
argValue.getDefiningOp())) {
if (op.getOperandTypes().front() != operandType) {
return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
diag << messagePrefix;
diag << "mismatched types from operand # " << i << " ";
diag << operandType;
diag << " not compatible with destination block argument type ";
diag << op.getOperandTypes().front();
diag << " which should be converted with the parent op.";
});
}
}
}
return success();
/// Helper function for converting branch ops. This function converts the
/// signature of the given block. If the new block signature is different from
/// `expectedTypes`, returns "failure".
static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
const TypeConverter *converter,
Operation *branchOp, Block *block,
TypeRange expectedTypes) {
Comment on lines +97 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(non-actionable side note): Feels like this could be part of dialect conversion in the future 🙂 Similar to remapValues but for successors

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is slightly different because replaceAllUsesWith for blocks is reflected immediately in the IR. (In contrast to value replacements, which are being kept track of in the ConversionValueMapping.)

The reason why getConvertedBlock is needed is because there could be multiple branch ops that jump to the same block. In that case, the block should be converted only once.

assert(converter && "expected non-null type converter");
assert(!block->isEntryBlock() && "entry blocks have no predecessors");

// There is nothing to do if the types already match.
if (block->getArgumentTypes() == expectedTypes)
return block;

// Compute the new block argument types and convert the block.
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion)
return rewriter.notifyMatchFailure(branchOp,
"could not compute block signature");
if (expectedTypes != conversion->getConvertedTypes())
return rewriter.notifyMatchFailure(
branchOp,
"mismatch between adaptor operand types and computed block signature");
return rewriter.applySignatureConversion(block, *conversion, converter);
}

/// Ensure that all block types were updated and then create an LLVM::BrOp
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
op.getSuccessor()->getArguments(),
op.getLoc(),
/*messagePrefix=*/"")))
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
TypeRange(adaptor.getOperands()));
if (failed(convertedBlock))
return failure();

rewriter.replaceOpWithNewOp<LLVM::BrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
op, adaptor.getOperands(), *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(op->getAttrDictionary());
return success();
}
};

/// Ensure that all block types were updated and then create an LLVM::CondBrOp
/// Convert the destination block signatures (if necessary) and lower the
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cf::CondBranchOp op,
typename cf::CondBranchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
op.getFalseDest()->getArguments(),
op.getLoc(), "in false case branch ")))
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
TypeRange(adaptor.getTrueDestOperands()));
if (failed(convertedTrueBlock))
return failure();
if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
op.getTrueDest()->getArguments(),
op.getLoc(), "in true case branch ")))
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
TypeRange(adaptor.getFalseDestOperands()));
if (failed(convertedFalseBlock))
return failure();

rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getCondition(), *convertedTrueBlock,
adaptor.getTrueDestOperands(), *convertedFalseBlock,
adaptor.getFalseDestOperands());
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(op->getAttrDictionary());
return success();
}
};

/// Ensure that all block types were updated and then create an LLVM::SwitchOp
/// Convert the destination block signatures (if necessary) and lower the
/// switch op to llvm.switch.
struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
op.getDefaultDestination()->getArguments(),
op.getLoc(), "in switch default case ")))
// Get or convert default block.
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
rewriter, getTypeConverter(), op, op.getDefaultDestination(),
TypeRange(adaptor.getDefaultOperands()));
if (failed(convertedDefaultBlock))
return failure();

for (const auto &i : llvm::enumerate(
llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
if (failed(verifyMatchingValues(
rewriter, std::get<0>(i.value()),
std::get<1>(i.value())->getArguments(), op.getLoc(),
"in switch case " + std::to_string(i.index()) + " "))) {
// Get or convert all case blocks.
SmallVector<Block *> caseDestinations;
SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
for (auto it : llvm::enumerate(op.getCaseDestinations())) {
Block *b = it.value();
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, b,
TypeRange(caseOperands[it.index()]));
if (failed(convertedBlock))
return failure();
}
caseDestinations.push_back(*convertedBlock);
}

rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
op, adaptor.getFlag(), *convertedDefaultBlock,
adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
caseDestinations, caseOperands);
return success();
}
};
Expand Down Expand Up @@ -230,14 +241,22 @@ struct ConvertControlFlowToLLVM

/// Run the dialect converter on the module.
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());

LowerToLLVMOptions options(&getContext());
MLIRContext *ctx = &getContext();
LLVMConversionTarget target(*ctx);
// This pass lowers only CF dialect ops, but it also modifies block
// signatures inside other ops. These ops should be treated as legal. They
// are lowered by other passes.
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return op->getDialect() !=
ctx->getLoadedDialect<cf::ControlFlowDialect>();
});

LowerToLLVMOptions options(ctx);
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);

LLVMTypeConverter converter(&getContext(), options);
LLVMTypeConverter converter(ctx, options);
RewritePatternSet patterns(ctx);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
Expand Down
14 changes: 5 additions & 9 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,11 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,

rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
&result))) {
return rewriter.notifyMatchFailure(funcOp,
"region types conversion failed");
}
// Convert just the entry block. The remaining unstructured control flow is
// converted by ControlFlowToLLVM.
if (!newFuncOp.getBody().empty())
rewriter.applySignatureConversion(&newFuncOp.getBody().front(), result,
&converter);

// Fix the type mismatch between the materialized `llvm.ptr` and the expected
// pointee type in the function body when converting `llvm.byval`/`llvm.byref`
Expand Down Expand Up @@ -785,10 +785,6 @@ struct ConvertFuncToLLVMPass
RewritePatternSet patterns(&getContext());
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);

// TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
// favor of their dedicated conversion passes.
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
Expand Down Expand Up @@ -91,6 +92,7 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
pm.addPass(createConvertFuncToLLVMPass());
pm.addPass(createArithToLLVMConversionPass());
pm.addPass(createConvertControlFlowToLLVMPass());

// Finalize GPU code generation.
if (gpuCodegen) {
Expand Down
Loading
Loading