@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
139
139
if (iface.isConvertibleInstruction (inst->getOpcode ()))
140
140
return iface.convertInstruction (odsBuilder, inst, llvmOperands,
141
141
moduleImport);
142
- // TODO: Implement the `convertInstruction` hooks in the
143
- // `LLVMDialectLLVMIRImportInterface` and move the following include there.
142
+ // TODO: Implement the `convertInstruction` hooks in the
143
+ // `LLVMDialectLLVMIRImportInterface` and move the following include there.
144
144
#include " mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
145
145
return failure ();
146
146
}
@@ -1489,16 +1489,15 @@ ModuleImport::convertBranchArgs(llvm::Instruction *branch,
1489
1489
return success ();
1490
1490
}
1491
1491
1492
- LogicalResult ModuleImport::convertCallTypeAndOperands (
1493
- llvm::CallBase *callInst, SmallVectorImpl<Type> &types,
1494
- SmallVectorImpl<Value> &operands, bool allowInlineAsm) {
1495
- if (!callInst->getType ()->isVoidTy ())
1496
- types.push_back (convertType (callInst->getType ()));
1497
-
1492
+ FailureOr<SmallVector<Value>>
1493
+ ModuleImport::convertCallOperands (llvm::CallBase *callInst,
1494
+ bool allowInlineAsm) {
1498
1495
bool isInlineAsm = callInst->isInlineAsm ();
1499
1496
if (isInlineAsm && !allowInlineAsm)
1500
1497
return failure ();
1501
1498
1499
+ SmallVector<Value> operands;
1500
+
1502
1501
// Cannot use isIndirectCall() here because we need to handle Constant callees
1503
1502
// that are not considered indirect calls by LLVM. However, in MLIR, they are
1504
1503
// treated as indirect calls to constant operands that need to be converted.
@@ -1515,8 +1514,29 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
1515
1514
FailureOr<SmallVector<Value>> arguments = convertValues (args);
1516
1515
if (failed (arguments))
1517
1516
return failure ();
1517
+
1518
1518
llvm::append_range (operands, *arguments);
1519
- return success ();
1519
+ return operands;
1520
+ }
1521
+
1522
+ LLVMFunctionType ModuleImport::convertFunctionType (llvm::CallBase *callInst) {
1523
+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1524
+ Type converted = [&] {
1525
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1526
+ return convertType (callee->getFunctionType ());
1527
+ return convertType (callInst->getFunctionType ());
1528
+ }();
1529
+
1530
+ if (auto funcTy = dyn_cast_or_null<LLVMFunctionType>(converted))
1531
+ return funcTy;
1532
+ return {};
1533
+ }
1534
+
1535
+ FlatSymbolRefAttr ModuleImport::convertCalleeName (llvm::CallBase *callInst) {
1536
+ llvm::Value *calledOperand = callInst->getCalledOperand ();
1537
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1538
+ return SymbolRefAttr::get (context, callee->getName ());
1539
+ return {};
1520
1540
}
1521
1541
1522
1542
LogicalResult ModuleImport::convertIntrinsic (llvm::CallInst *inst) {
@@ -1603,75 +1623,45 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1603
1623
auto callInst = cast<llvm::CallInst>(inst);
1604
1624
llvm::Value *calledOperand = callInst->getCalledOperand ();
1605
1625
1606
- SmallVector<Type> types;
1607
- SmallVector<Value> operands;
1608
- if (failed (convertCallTypeAndOperands (callInst, types, operands,
1609
- /* allowInlineAsm=*/ true )))
1626
+ FailureOr<SmallVector<Value>> operands =
1627
+ convertCallOperands (callInst, /* allowInlineAsm=*/ true );
1628
+ if (failed (operands))
1610
1629
return failure ();
1611
1630
1612
- if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613
- Type resultTy = convertType (callInst->getType ());
1614
- if (!resultTy)
1615
- return failure ();
1616
- auto callOp = builder.create <InlineAsmOp>(
1617
- loc, resultTy, operands, builder.getStringAttr (asmI->getAsmString ()),
1618
- builder.getStringAttr (asmI->getConstraintString ()),
1619
- /* has_side_effects=*/ true ,
1620
- /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
1621
- /* operand_attrs=*/ nullptr );
1622
- if (!callInst->getType ()->isVoidTy ())
1623
- mapValue (inst, callOp.getResult (0 ));
1624
- else
1625
- mapNoResultOp (inst, callOp);
1626
- } else {
1627
- auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628
- // Retrieve the real function type. For direct calls, use the callee's
1629
- // function type, as it may differ from the operand type in the case of
1630
- // variadic functions. For indirect calls, use the call function type.
1631
- if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632
- return convertType (callee->getFunctionType ());
1633
- return convertType (callInst->getFunctionType ());
1634
- }());
1635
-
1636
- if (!funcTy)
1637
- return failure ();
1631
+ auto callOp = [&]() -> FailureOr<Operation *> {
1632
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1633
+ Type resultTy = convertType (callInst->getType ());
1634
+ if (!resultTy)
1635
+ return failure ();
1636
+ return builder
1637
+ .create <InlineAsmOp>(
1638
+ loc, resultTy, *operands,
1639
+ builder.getStringAttr (asmI->getAsmString ()),
1640
+ builder.getStringAttr (asmI->getConstraintString ()),
1641
+ /* has_side_effects=*/ true ,
1642
+ /* is_align_stack=*/ false , /* asm_dialect=*/ nullptr ,
1643
+ /* operand_attrs=*/ nullptr )
1644
+ .getOperation ();
1645
+ } else {
1646
+ LLVMFunctionType funcTy = convertFunctionType (callInst);
1647
+ if (!funcTy)
1648
+ return failure ();
1638
1649
1639
- auto callOp = [&]() -> CallOp {
1640
- if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641
- auto name = SymbolRefAttr::get (context, callee->getName ());
1642
- return builder.create <CallOp>(loc, funcTy, name, operands);
1643
- }
1644
- return builder.create <CallOp>(loc, funcTy, operands);
1645
- }();
1646
-
1647
- // Handle function attributes.
1648
- callOp.setCConv (convertCConvFromLLVM (callInst->getCallingConv ()));
1649
- callOp.setTailCallKind (
1650
- convertTailCallKindFromLLVM (callInst->getTailCallKind ()));
1651
- setFastmathFlagsAttr (inst, callOp);
1652
-
1653
- callOp.setConvergent (callInst->isConvergent ());
1654
- callOp.setNoUnwind (callInst->doesNotThrow ());
1655
- callOp.setWillReturn (callInst->hasFnAttr (llvm::Attribute::WillReturn));
1656
-
1657
- llvm::MemoryEffects memEffects = callInst->getMemoryEffects ();
1658
- ModRefInfo othermem = convertModRefInfoFromLLVM (
1659
- memEffects.getModRef (llvm::MemoryEffects::Location::Other));
1660
- ModRefInfo argMem = convertModRefInfoFromLLVM (
1661
- memEffects.getModRef (llvm::MemoryEffects::Location::ArgMem));
1662
- ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM (
1663
- memEffects.getModRef (llvm::MemoryEffects::Location::InaccessibleMem));
1664
- auto memAttr = MemoryEffectsAttr::get (callOp.getContext (), othermem,
1665
- argMem, inaccessibleMem);
1666
- // Only set the attribute when it does not match the default value.
1667
- if (!memAttr.isReadWrite ())
1668
- callOp.setMemoryEffectsAttr (memAttr);
1669
-
1670
- if (!callInst->getType ()->isVoidTy ())
1671
- mapValue (inst, callOp.getResult ());
1672
- else
1673
- mapNoResultOp (inst, callOp);
1674
- }
1650
+ FlatSymbolRefAttr callee = convertCalleeName (callInst);
1651
+ auto callOp = builder.create <CallOp>(loc, funcTy, callee, *operands);
1652
+ if (failed (convertCallAttributes (callInst, callOp)))
1653
+ return failure ();
1654
+ return callOp.getOperation ();
1655
+ }
1656
+ }();
1657
+
1658
+ if (failed (callOp))
1659
+ return failure ();
1660
+
1661
+ if (!callInst->getType ()->isVoidTy ())
1662
+ mapValue (inst, (*callOp)->getResult (0 ));
1663
+ else
1664
+ mapNoResultOp (inst, *callOp);
1675
1665
return success ();
1676
1666
}
1677
1667
if (inst->getOpcode () == llvm::Instruction::LandingPad) {
@@ -1695,9 +1685,11 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1695
1685
if (inst->getOpcode () == llvm::Instruction::Invoke) {
1696
1686
auto *invokeInst = cast<llvm::InvokeInst>(inst);
1697
1687
1698
- SmallVector<Type> types;
1699
- SmallVector<Value> operands;
1700
- if (failed (convertCallTypeAndOperands (invokeInst, types, operands)))
1688
+ if (invokeInst->isInlineAsm ())
1689
+ return emitError (loc) << " invoke of inline assembly is not supported" ;
1690
+
1691
+ FailureOr<SmallVector<Value>> operands = convertCallOperands (invokeInst);
1692
+ if (failed (operands))
1701
1693
return failure ();
1702
1694
1703
1695
// Check whether the invoke result is an argument to the normal destination
@@ -1724,27 +1716,22 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1724
1716
unwindArgs)))
1725
1717
return failure ();
1726
1718
1727
- auto funcTy =
1728
- dyn_cast<LLVMFunctionType>(convertType (invokeInst->getFunctionType ()));
1719
+ auto funcTy = convertFunctionType (invokeInst);
1729
1720
if (!funcTy)
1730
1721
return failure ();
1731
1722
1723
+ FlatSymbolRefAttr calleeName = convertCalleeName (invokeInst);
1724
+
1732
1725
// Create the invoke operation. Normal destination block arguments will be
1733
1726
// added later on to handle the case in which the operation result is
1734
1727
// included in this list.
1735
- InvokeOp invokeOp;
1736
- if (llvm::Function *callee = invokeInst->getCalledFunction ()) {
1737
- invokeOp = builder.create <InvokeOp>(
1738
- loc, funcTy,
1739
- SymbolRefAttr::get (builder.getContext (), callee->getName ()), operands,
1740
- directNormalDest, ValueRange (),
1741
- lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1742
- } else {
1743
- invokeOp = builder.create <InvokeOp>(
1744
- loc, funcTy, /* callee=*/ nullptr , operands, directNormalDest,
1745
- ValueRange (), lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1746
- }
1747
- invokeOp.setCConv (convertCConvFromLLVM (invokeInst->getCallingConv ()));
1728
+ auto invokeOp = builder.create <InvokeOp>(
1729
+ loc, funcTy, calleeName, *operands, directNormalDest, ValueRange (),
1730
+ lookupBlock (invokeInst->getUnwindDest ()), unwindArgs);
1731
+
1732
+ if (failed (convertInvokeAttributes (invokeInst, invokeOp)))
1733
+ return failure ();
1734
+
1748
1735
if (!invokeInst->getType ()->isVoidTy ())
1749
1736
mapValue (inst, invokeOp.getResults ().front ());
1750
1737
else
@@ -2097,6 +2084,41 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
2097
2084
builder.getArrayAttr (convertParameterAttribute (llvmResAttr, builder)));
2098
2085
}
2099
2086
2087
+ template <typename Op>
2088
+ static LogicalResult convertCallBaseAttributes (llvm::CallBase *inst, Op op) {
2089
+ op.setCConv (convertCConvFromLLVM (inst->getCallingConv ()));
2090
+ return success ();
2091
+ }
2092
+
2093
+ LogicalResult ModuleImport::convertInvokeAttributes (llvm::InvokeInst *inst,
2094
+ InvokeOp op) {
2095
+ return convertCallBaseAttributes (inst, op);
2096
+ }
2097
+
2098
+ LogicalResult ModuleImport::convertCallAttributes (llvm::CallInst *inst,
2099
+ CallOp op) {
2100
+ setFastmathFlagsAttr (inst, op.getOperation ());
2101
+ op.setTailCallKind (convertTailCallKindFromLLVM (inst->getTailCallKind ()));
2102
+ op.setConvergent (inst->isConvergent ());
2103
+ op.setNoUnwind (inst->doesNotThrow ());
2104
+ op.setWillReturn (inst->hasFnAttr (llvm::Attribute::WillReturn));
2105
+
2106
+ llvm::MemoryEffects memEffects = inst->getMemoryEffects ();
2107
+ ModRefInfo othermem = convertModRefInfoFromLLVM (
2108
+ memEffects.getModRef (llvm::MemoryEffects::Location::Other));
2109
+ ModRefInfo argMem = convertModRefInfoFromLLVM (
2110
+ memEffects.getModRef (llvm::MemoryEffects::Location::ArgMem));
2111
+ ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM (
2112
+ memEffects.getModRef (llvm::MemoryEffects::Location::InaccessibleMem));
2113
+ auto memAttr = MemoryEffectsAttr::get (op.getContext (), othermem, argMem,
2114
+ inaccessibleMem);
2115
+ // Only set the attribute when it does not match the default value.
2116
+ if (!memAttr.isReadWrite ())
2117
+ op.setMemoryEffectsAttr (memAttr);
2118
+
2119
+ return convertCallBaseAttributes (inst, op);
2120
+ }
2121
+
2100
2122
LogicalResult ModuleImport::processFunction (llvm::Function *func) {
2101
2123
clearRegionState ();
2102
2124
0 commit comments