Skip to content

Commit 95d993a

Browse files
authored
[MLIR] Fix import of calls with mismatched variadic types (#124286)
Previously, an indirect call was incorrectly generated when `llvm::CallBase::getCalledFunction` returned null due to a type mismatch between the call and the function. This patch updates the code to use `llvm::CallBase::getCalledOperand` instead.
1 parent 3b30f20 commit 95d993a

File tree

2 files changed

+70
-32
lines changed

2 files changed

+70
-32
lines changed

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
14951495
if (!callInst->getType()->isVoidTy())
14961496
types.push_back(convertType(callInst->getType()));
14971497

1498-
if (!callInst->getCalledFunction()) {
1499-
if (!allowInlineAsm ||
1500-
!isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
1501-
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1502-
if (failed(called))
1503-
return failure();
1504-
operands.push_back(*called);
1505-
}
1498+
bool isInlineAsm = callInst->isInlineAsm();
1499+
if (isInlineAsm && !allowInlineAsm)
1500+
return failure();
1501+
1502+
// Cannot use isIndirectCall() here because we need to handle Constant callees
1503+
// that are not considered indirect calls by LLVM. However, in MLIR, they are
1504+
// treated as indirect calls to constant operands that need to be converted.
1505+
// Skip the callee operand if it's inline assembly, as it's handled separately
1506+
// in InlineAsmOp.
1507+
if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
1508+
FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1509+
if (failed(called))
1510+
return failure();
1511+
operands.push_back(*called);
15061512
}
1513+
15071514
SmallVector<llvm::Value *> args(callInst->args());
15081515
FailureOr<SmallVector<Value>> arguments = convertValues(args);
15091516
if (failed(arguments))
@@ -1593,23 +1600,21 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
15931600
return success();
15941601
}
15951602
if (inst->getOpcode() == llvm::Instruction::Call) {
1596-
auto *callInst = cast<llvm::CallInst>(inst);
1603+
auto callInst = cast<llvm::CallInst>(inst);
1604+
llvm::Value *calledOperand = callInst->getCalledOperand();
15971605

15981606
SmallVector<Type> types;
15991607
SmallVector<Value> operands;
16001608
if (failed(convertCallTypeAndOperands(callInst, types, operands,
16011609
/*allowInlineAsm=*/true)))
16021610
return failure();
16031611

1604-
auto funcTy =
1605-
dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
1606-
if (!funcTy)
1607-
return failure();
1608-
1609-
if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
1612+
if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
1613+
Type resultTy = convertType(callInst->getType());
1614+
if (!resultTy)
1615+
return failure();
16101616
auto callOp = builder.create<InlineAsmOp>(
1611-
loc, funcTy.getReturnType(), operands,
1612-
builder.getStringAttr(asmI->getAsmString()),
1617+
loc, resultTy, operands, builder.getStringAttr(asmI->getAsmString()),
16131618
builder.getStringAttr(asmI->getConstraintString()),
16141619
/*has_side_effects=*/true,
16151620
/*is_align_stack=*/false, /*asm_dialect=*/nullptr,
@@ -1619,27 +1624,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
16191624
else
16201625
mapNoResultOp(inst, callOp);
16211626
} else {
1622-
CallOp callOp;
1627+
auto funcTy = dyn_cast<LLVMFunctionType>([&]() -> Type {
1628+
// Retrieve the real function type. For direct calls, use the callee's
1629+
// function type, as it may differ from the operand type in the case of
1630+
// variadic functions. For indirect calls, use the call function type.
1631+
if (auto callee = dyn_cast<llvm::Function>(calledOperand))
1632+
return convertType(callee->getFunctionType());
1633+
return convertType(callInst->getFunctionType());
1634+
}());
1635+
1636+
if (!funcTy)
1637+
return failure();
16231638

1624-
if (llvm::Function *callee = callInst->getCalledFunction()) {
1625-
callOp = builder.create<CallOp>(
1626-
loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1627-
operands);
1628-
} else {
1629-
callOp = builder.create<CallOp>(loc, funcTy, operands);
1630-
}
1639+
auto callOp = [&]() -> CallOp {
1640+
if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
1641+
auto name = SymbolRefAttr::get(context, callee->getName());
1642+
return builder.create<CallOp>(loc, funcTy, name, operands);
1643+
}
1644+
return builder.create<CallOp>(loc, funcTy, operands);
1645+
}();
1646+
1647+
// Handle function attributes.
16311648
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
16321649
callOp.setTailCallKind(
16331650
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
16341651
setFastmathFlagsAttr(inst, callOp);
16351652

1636-
// Handle function attributes.
1637-
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1638-
callOp.setConvergent(true);
1639-
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1640-
callOp.setNoUnwind(true);
1641-
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1642-
callOp.setWillReturn(true);
1653+
callOp.setConvergent(callInst->isConvergent());
1654+
callOp.setNoUnwind(callInst->doesNotThrow());
1655+
callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
16431656

16441657
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
16451658
ModRefInfo othermem = convertModRefInfoFromLLVM(

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,31 @@ define void @varargs_call(i32 %0) {
570570

571571
; // -----
572572

573+
; CHECK: @varargs(...)
574+
declare void @varargs(...)
575+
576+
; CHECK-LABEL: @varargs_call
577+
; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
578+
define void @varargs_call(i32 %0) {
579+
; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
580+
call void @varargs(i32 %0)
581+
ret void
582+
}
583+
584+
; // -----
585+
586+
; CHECK: @varargs(...)
587+
declare void @varargs(...)
588+
589+
; CHECK-LABEL: @empty_varargs_call
590+
define void @empty_varargs_call() {
591+
; CHECK: llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
592+
call void @varargs()
593+
ret void
594+
}
595+
596+
; // -----
597+
573598
; CHECK: llvm.func @f()
574599
declare void @f()
575600

0 commit comments

Comments
 (0)