Skip to content

Commit 2a1f795

Browse files
authored
[MLIR] Fix import of invokes with mismatched variadic types (#124828)
This resolves the same issue addressed in #124286, but for invoke operations. The issue arose from duplicated logic for both imports. This PR also refactors the common import code for call and invoke instructions to mitigate issues in the future.
1 parent 1ac3665 commit 2a1f795

File tree

3 files changed

+153
-105
lines changed

3 files changed

+153
-105
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -316,24 +316,32 @@ class ModuleImport {
316316
LogicalResult convertBranchArgs(llvm::Instruction *branch,
317317
llvm::BasicBlock *target,
318318
SmallVectorImpl<Value> &blockArguments);
319-
/// Appends the converted result type and operands of `callInst` to the
320-
/// `types` and `operands` arrays. For indirect calls, the method additionally
321-
/// inserts the called function at the beginning of the `operands` array.
322-
/// If `allowInlineAsm` is set to false (the default), it will return failure
323-
/// if the called operand is an inline asm which isn't convertible to MLIR as
324-
/// a value.
325-
LogicalResult convertCallTypeAndOperands(llvm::CallBase *callInst,
326-
SmallVectorImpl<Type> &types,
327-
SmallVectorImpl<Value> &operands,
328-
bool allowInlineAsm = false);
329-
/// Converts the parameter attributes attached to `func` and adds them to the
330-
/// `funcOp`.
319+
/// Convert `callInst` operands. For indirect calls, the method additionally
320+
/// inserts the called function at the beginning of the returned `operands`
321+
/// array. If `allowInlineAsm` is set to false (the default), it will return
322+
/// failure if the called operand is an inline asm which isn't convertible to
323+
/// MLIR as a value.
324+
FailureOr<SmallVector<Value>>
325+
convertCallOperands(llvm::CallBase *callInst, bool allowInlineAsm = false);
326+
/// Converts the callee's function type. For direct calls, it converts the
327+
/// actual function type, which may differ from the called operand type in
328+
/// variadic functions. For indirect calls, it converts the function type
329+
/// associated with the call instruction.
330+
LLVMFunctionType convertFunctionType(llvm::CallBase *callInst);
331+
/// Returns the callee name, or an empty symbol if the call is not direct.
332+
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
333+
/// Converts the parameter attributes attached to `func` and adds them to
334+
/// the `funcOp`.
331335
void convertParameterAttributes(llvm::Function *func, LLVMFuncOp funcOp,
332336
OpBuilder &builder);
333337
/// Converts the AttributeSet of one parameter in LLVM IR to a corresponding
334338
/// DictionaryAttr for the LLVM dialect.
335339
DictionaryAttr convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
336340
OpBuilder &builder);
341+
/// Converts the attributes attached to `inst` and adds them to the `op`.
342+
LogicalResult convertCallAttributes(llvm::CallInst *inst, CallOp op);
343+
/// Converts the attributes attached to `inst` and adds them to the `op`.
344+
LogicalResult convertInvokeAttributes(llvm::InvokeInst *inst, InvokeOp op);
337345
/// Returns the builtin type equivalent to the given LLVM dialect type or
338346
/// nullptr if there is no equivalent. The returned type can be used to create
339347
/// an attribute for a GlobalOp or a ConstantOp.

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 115 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
139139
if (iface.isConvertibleInstruction(inst->getOpcode()))
140140
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
141141
moduleImport);
142-
// TODO: Implement the `convertInstruction` hooks in the
143-
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
142+
// TODO: Implement the `convertInstruction` hooks in the
143+
// `LLVMDialectLLVMIRImportInterface` and move the following include there.
144144
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
145145
return failure();
146146
}
@@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
14891489
return success();
14901490
}
14911491

1492-
LogicalResult ModuleImport::convertCallTypeAndOperands(
1493-
llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
1494-
SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
1495-
if (!callInst->getType()->isVoidTy())
1496-
types.push_back(convertType(callInst->getType()));
1497-
1492+
FailureOr<SmallVector<Value>>
1493+
ModuleImport::convertCallOperands(llvm::CallBase *callInst,
1494+
bool allowInlineAsm) {
14981495
bool isInlineAsm = callInst->isInlineAsm();
14991496
if (isInlineAsm && !allowInlineAsm)
15001497
return failure();
15011498

1499+
SmallVector<Value> operands;
1500+
15021501
// Cannot use isIndirectCall() here because we need to handle Constant callees
15031502
// that are not considered indirect calls by LLVM. However, in MLIR, they are
15041503
// treated as indirect calls to constant operands that need to be converted.
@@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
15151514
FailureOr<SmallVector<Value>> arguments = convertValues(args);
15161515
if (failed(arguments))
15171516
return failure();
1517+
15181518
llvm::append_range(operands, *arguments);
1519-
return success();
1519+
return operands;
1520+
}
1521+
1522+
LLVMFunctionType ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1523+
llvm::Value *calledOperand = callInst->getCalledOperand();
1524+
Type converted = [&] {
1525+
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526+
return convertType(callee->getFunctionType());
1527+
return convertType(callInst->getFunctionType());
1528+
}();
1529+
1530+
if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1531+
return funcTy;
1532+
return {};
1533+
}
1534+
1535+
FlatSymbolRefAttr ModuleImport::convertCalleeName(llvm::CallBase *callInst) {
1536+
llvm::Value *calledOperand = callInst->getCalledOperand();
1537+
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1538+
return SymbolRefAttr::get(context, callee->getName());
1539+
return {};
15201540
}
15211541

15221542
LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
@@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16031623
auto callInst = cast<llvm::CallInst>(inst);
16041624
llvm::Value *calledOperand = callInst->getCalledOperand();
16051625

1606-
SmallVector<Type> types;
1607-
SmallVector<Value> operands;
1608-
if (failed(convertCallTypeAndOperands(callInst, types, operands,
1609-
/*allowInlineAsm=*/true)))
1626+
FailureOr<SmallVector<Value>> operands =
1627+
convertCallOperands(callInst, /*allowInlineAsm=*/true);
1628+
if (failed(operands))
16101629
return failure();
16111630

1612-
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613-
Type resultTy = convertType(callInst->getType());
1614-
if (!resultTy)
1615-
return failure();
1616-
auto callOp = builder.create<InlineAsmOp>(
1617-
loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
1618-
builder.getStringAttr(asmI->getConstraintString()),
1619-
/*has_side_effects=*/true,
1620-
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
1621-
/*operand_attrs=*/nullptr);
1622-
if (!callInst->getType()->isVoidTy())
1623-
mapValue(inst, callOp.getResult(0));
1624-
else
1625-
mapNoResultOp(inst, callOp);
1626-
} else {
1627-
auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628-
// Retrieve the real function type. For direct calls, use the callee's
1629-
// function type, as it may differ from the operand type in the case of
1630-
// variadic functions. For indirect calls, use the call function type.
1631-
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632-
return convertType(callee->getFunctionType());
1633-
return convertType(callInst->getFunctionType());
1634-
}());
1635-
1636-
if (!funcTy)
1637-
return failure();
1631+
auto callOp = [&]() -> FailureOr<Operation *> {
1632+
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1633+
Type resultTy = convertType(callInst->getType());
1634+
if (!resultTy)
1635+
return failure();
1636+
return builder
1637+
.create<InlineAsmOp>(
1638+
loc, resultTy, *operands,
1639+
builder.getStringAttr(asmI->getAsmString()),
1640+
builder.getStringAttr(asmI->getConstraintString()),
1641+
/*has_side_effects=*/true,
1642+
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
1643+
/*operand_attrs=*/nullptr)
1644+
.getOperation();
1645+
} else {
1646+
LLVMFunctionType funcTy = convertFunctionType(callInst);
1647+
if (!funcTy)
1648+
return failure();
16381649

1639-
auto callOp = [&]() -> CallOp {
1640-
if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641-
auto name = SymbolRefAttr::get(context, callee->getName());
1642-
return builder.create<CallOp>(loc, funcTy, name, operands);
1643-
}
1644-
return builder.create<CallOp>(loc, funcTy, operands);
1645-
}();
1646-
1647-
// Handle function attributes.
1648-
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1649-
callOp.setTailCallKind(
1650-
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
1651-
setFastmathFlagsAttr(inst, callOp);
1652-
1653-
callOp.setConvergent(callInst->isConvergent());
1654-
callOp.setNoUnwind(callInst->doesNotThrow());
1655-
callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
1656-
1657-
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1658-
ModRefInfo othermem = convertModRefInfoFromLLVM(
1659-
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1660-
ModRefInfo argMem = convertModRefInfoFromLLVM(
1661-
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1662-
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1663-
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1664-
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem,
1665-
argMem, inaccessibleMem);
1666-
// Only set the attribute when it does not match the default value.
1667-
if (!memAttr.isReadWrite())
1668-
callOp.setMemoryEffectsAttr(memAttr);
1669-
1670-
if (!callInst->getType()->isVoidTy())
1671-
mapValue(inst, callOp.getResult());
1672-
else
1673-
mapNoResultOp(inst, callOp);
1674-
}
1650+
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1651+
auto callOp = builder.create<CallOp>(loc, funcTy, callee, *operands);
1652+
if (failed(convertCallAttributes(callInst, callOp)))
1653+
return failure();
1654+
return callOp.getOperation();
1655+
}
1656+
}();
1657+
1658+
if (failed(callOp))
1659+
return failure();
1660+
1661+
if (!callInst->getType()->isVoidTy())
1662+
mapValue(inst, (*callOp)->getResult(0));
1663+
else
1664+
mapNoResultOp(inst, *callOp);
16751665
return success();
16761666
}
16771667
if (inst->getOpcode() == llvm::Instruction::LandingPad) {
@@ -1695,9 +1685,11 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16951685
if (inst->getOpcode() == llvm::Instruction::Invoke) {
16961686
auto *invokeInst = cast<llvm::InvokeInst>(inst);
16971687

1698-
SmallVector<Type> types;
1699-
SmallVector<Value> operands;
1700-
if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
1688+
if (invokeInst->isInlineAsm())
1689+
return emitError(loc) << "invoke of inline assembly is not supported";
1690+
1691+
FailureOr<SmallVector<Value>> operands = convertCallOperands(invokeInst);
1692+
if (failed(operands))
17011693
return failure();
17021694

17031695
// Check whether the invoke result is an argument to the normal destination
@@ -1724,27 +1716,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
17241716
unwindArgs)))
17251717
return failure();
17261718

1727-
auto funcTy =
1728-
dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
1719+
auto funcTy = convertFunctionType(invokeInst);
17291720
if (!funcTy)
17301721
return failure();
17311722

1723+
FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
1724+
17321725
// Create the invoke operation. Normal destination block arguments will be
17331726
// added later on to handle the case in which the operation result is
17341727
// included in this list.
1735-
InvokeOp invokeOp;
1736-
if (llvm::Function *callee = invokeInst->getCalledFunction()) {
1737-
invokeOp = builder.create<InvokeOp>(
1738-
loc, funcTy,
1739-
SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
1740-
directNormalDest, ValueRange(),
1741-
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1742-
} else {
1743-
invokeOp = builder.create<InvokeOp>(
1744-
loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
1745-
ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1746-
}
1747-
invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
1728+
auto invokeOp = builder.create<InvokeOp>(
1729+
loc, funcTy, calleeName, *operands, directNormalDest, ValueRange(),
1730+
lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1731+
1732+
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
1733+
return failure();
1734+
17481735
if (!invokeInst->getType()->isVoidTy())
17491736
mapValue(inst, invokeOp.getResults().front());
17501737
else
@@ -2097,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
20972084
builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
20982085
}
20992086

2087+
template <typename Op>
2088+
static LogicalResult convertCallBaseAttributes(llvm::CallBase *inst, Op op) {
2089+
op.setCConv(convertCConvFromLLVM(inst->getCallingConv()));
2090+
return success();
2091+
}
2092+
2093+
LogicalResult ModuleImport::convertInvokeAttributes(llvm::InvokeInst *inst,
2094+
InvokeOp op) {
2095+
return convertCallBaseAttributes(inst, op);
2096+
}
2097+
2098+
LogicalResult ModuleImport::convertCallAttributes(llvm::CallInst *inst,
2099+
CallOp op) {
2100+
setFastmathFlagsAttr(inst, op.getOperation());
2101+
op.setTailCallKind(convertTailCallKindFromLLVM(inst->getTailCallKind()));
2102+
op.setConvergent(inst->isConvergent());
2103+
op.setNoUnwind(inst->doesNotThrow());
2104+
op.setWillReturn(inst->hasFnAttr(llvm::Attribute::WillReturn));
2105+
2106+
llvm::MemoryEffects memEffects = inst->getMemoryEffects();
2107+
ModRefInfo othermem = convertModRefInfoFromLLVM(
2108+
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
2109+
ModRefInfo argMem = convertModRefInfoFromLLVM(
2110+
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
2111+
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
2112+
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
2113+
auto memAttr = MemoryEffectsAttr::get(op.getContext(), othermem, argMem,
2114+
inaccessibleMem);
2115+
// Only set the attribute when it does not match the default value.
2116+
if (!memAttr.isReadWrite())
2117+
op.setMemoryEffectsAttr(memAttr);
2118+
2119+
return convertCallBaseAttributes(inst, op);
2120+
}
2121+
21002122
LogicalResult ModuleImport::processFunction(llvm::Function *func) {
21012123
clearRegionState();
21022124

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,3 +702,21 @@ define void @fence() {
702702
fence syncscope("") seq_cst
703703
ret void
704704
}
705+
706+
; // -----
707+
708+
; CHECK-LABEL: @f
709+
define void @f() personality ptr @__gxx_personality_v0 {
710+
entry:
711+
; CHECK: llvm.invoke @g() to ^bb1 unwind ^bb2 vararg(!llvm.func<void (...)>) : () -> ()
712+
invoke void @g() to label %bb1 unwind label %bb2
713+
bb1:
714+
ret void
715+
bb2:
716+
%0 = landingpad i32 cleanup
717+
unreachable
718+
}
719+
720+
declare void @g(...)
721+
722+
declare i32 @__gxx_personality_v0(...)

0 commit comments

Comments
 (0)