@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
1495
1495
if (!callInst->getType ()->isVoidTy ())
1496
1496
types.push_back (convertType (callInst->getType ()));
1497
1497
1498
- if (!callInst->getCalledFunction ()) {
1499
- if (!allowInlineAsm ||
1500
- !isa<llvm::InlineAsm>(callInst->getCalledOperand ())) {
1501
- FailureOr<Value> called = convertValue (callInst->getCalledOperand ());
1502
- if (failed (called))
1503
- return failure ();
1504
- operands.push_back (*called);
1505
- }
1498
+ bool isInlineAsm = callInst->isInlineAsm ();
1499
+ if (isInlineAsm && !allowInlineAsm)
1500
+ return failure ();
1501
+
1502
+ // Cannot use isIndirectCall() here because we need to handle Constant callees
1503
+ // that are not considered indirect calls by LLVM. However, in MLIR, they are
1504
+ // treated as indirect calls to constant operands that need to be converted.
1505
+ // Skip the callee operand if it's inline assembly, as it's handled separately
1506
+ // in InlineAsmOp.
1507
+ if (!isa<llvm::Function>(callInst->getCalledOperand ()) && !isInlineAsm) {
1508
+ FailureOr<Value> called = convertValue (callInst->getCalledOperand ());
1509
+ if (failed (called))
1510
+ return failure ();
1511
+ operands.push_back (*called);
1506
1512
}
1513
+
1507
1514
SmallVector<llvm::Value *> args (callInst->args ());
1508
1515
FailureOr<SmallVector<Value>> arguments = convertValues (args);
1509
1516
if (failed (arguments))
@@ -1593,23 +1600,21 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1593
1600
return success ();
1594
1601
}
1595
1602
if (inst->getOpcode () == llvm::Instruction::Call) {
1596
- auto *callInst = cast<llvm::CallInst>(inst);
1603
+ auto callInst = cast<llvm::CallInst>(inst);
1604
+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1597
1605
1598
1606
SmallVector<Type> types;
1599
1607
SmallVector<Value> operands;
1600
1608
if (failed (convertCallTypeAndOperands (callInst, types, operands,
1601
1609
/* allowInlineAsm=*/ true )))
1602
1610
return failure ();
1603
1611
1604
- auto funcTy =
1605
- dyn_cast<LLVMFunctionType>(convertType (callInst->getFunctionType ()));
1606
- if (!funcTy)
1607
- return failure ();
1608
-
1609
- if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand ())) {
1612
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613
+ Type resultTy = convertType (callInst->getType ());
1614
+ if (!resultTy)
1615
+ return failure ();
1610
1616
auto callOp = builder.create <InlineAsmOp>(
1611
- loc, funcTy.getReturnType (), operands,
1612
- builder.getStringAttr (asmI->getAsmString ()),
1617
+ loc, resultTy, operands, builder.getStringAttr (asmI->getAsmString ()),
1613
1618
builder.getStringAttr (asmI->getConstraintString ()),
1614
1619
/* has_side_effects=*/ true ,
1615
1620
/* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
@@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1619
1624
else
1620
1625
mapNoResultOp (inst, callOp);
1621
1626
} else {
1622
- CallOp callOp;
1627
+ auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628
+ // Retrieve the real function type. For direct calls, use the callee's
1629
+ // function type, as it may differ from the operand type in the case of
1630
+ // variadic functions. For indirect calls, use the call function type.
1631
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632
+ return convertType (callee->getFunctionType ());
1633
+ return convertType (callInst->getFunctionType ());
1634
+ }());
1635
+
1636
+ if (!funcTy)
1637
+ return failure ();
1623
1638
1624
- if (llvm::Function *callee = callInst->getCalledFunction ()) {
1625
- callOp = builder.create <CallOp>(
1626
- loc, funcTy, SymbolRefAttr::get (context, callee->getName ()),
1627
- operands);
1628
- } else {
1629
- callOp = builder.create <CallOp>(loc, funcTy, operands);
1630
- }
1639
+ auto callOp = [&]() -> CallOp {
1640
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641
+ auto name = SymbolRefAttr::get (context, callee->getName ());
1642
+ return builder.create <CallOp>(loc, funcTy, name, operands);
1643
+ }
1644
+ return builder.create <CallOp>(loc, funcTy, operands);
1645
+ }();
1646
+
1647
+ // Handle function attributes.
1631
1648
callOp.setCConv (convertCConvFromLLVM (callInst->getCallingConv ()));
1632
1649
callOp.setTailCallKind (
1633
1650
convertTailCallKindFromLLVM (callInst->getTailCallKind ()));
1634
1651
setFastmathFlagsAttr (inst, callOp);
1635
1652
1636
- // Handle function attributes.
1637
- if (callInst->hasFnAttr (llvm::Attribute::Convergent))
1638
- callOp.setConvergent (true );
1639
- if (callInst->hasFnAttr (llvm::Attribute::NoUnwind))
1640
- callOp.setNoUnwind (true );
1641
- if (callInst->hasFnAttr (llvm::Attribute::WillReturn))
1642
- callOp.setWillReturn (true );
1653
+ callOp.setConvergent (callInst->isConvergent ());
1654
+ callOp.setNoUnwind (callInst->doesNotThrow ());
1655
+ callOp.setWillReturn (callInst->hasFnAttr (llvm::Attribute::WillReturn));
1643
1656
1644
1657
llvm::MemoryEffects memEffects = callInst->getMemoryEffects ();
1645
1658
ModRefInfo othermem = convertModRefInfoFromLLVM (
0 commit comments