-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Fix import of calls with mismatched variadic types #124286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Henrich Lauko (xlauko) ChangesPreviously, an indirect call was incorrectly generated when Full diff: https://github.com/llvm/llvm-project/pull/124286.diff 2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..b8f66357d9250e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));
- if (!callInst->getCalledFunction()) {
- if (!allowInlineAsm ||
- !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
- FailureOr<Value> called = convertValue(callInst->getCalledOperand());
- if (failed(called))
- return failure();
- operands.push_back(*called);
- }
+ bool isInlineAsm = callInst->isInlineAsm();
+ if (isInlineAsm && !allowInlineAsm)
+ return failure();
+
+ // Cannot use isIndirectCall() here because we need to handle Constant callees
+ // that are not considered indirect calls by LLVM. However, in MLIR, they are
+ // treated as indirect calls to constant operands that need to be converted.
+ // Skip the callee operand if it's inline assembly, as it's handled separately
+ // in InlineAsmOp.
+ if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+ FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+ if (failed(called))
+ return failure();
+ operands.push_back(*called);
}
+
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}
if (inst->getOpcode() == llvm::Instruction::Call) {
- auto *callInst = cast<llvm::CallInst>(inst);
+ auto callInst = cast<llvm::CallInst>(inst);
+ auto calledOperand = callInst->getCalledOperand();
SmallVector<Type> types;
SmallVector<Value> operands;
@@ -1601,14 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*allowInlineAsm=*/true)))
return failure();
- auto funcTy =
- dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
- if (!funcTy)
- return failure();
-
- if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+ auto resultTy = convertType(callInst->getType());
+ if (!resultTy)
+ return failure();
auto callOp = builder.create<InlineAsmOp>(
- loc, funcTy.getReturnType(), operands,
+ loc, resultTy, operands,
builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
@@ -1619,27 +1625,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
else
mapNoResultOp(inst, callOp);
} else {
- CallOp callOp;
+ auto funcTy = dyn_cast<LLVMFunctionType>([&] () -> Type {
+ // Retrieve the real function type. For direct calls, use the callee's
+ // function type, as it may differ from the operand type in the case of
+ // variadic functions. For indirect calls, use the call function type.
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+ return convertType(callee->getFunctionType());
+ return convertType(callInst->getFunctionType());
+ }() );
+
+ if (!funcTy)
+ return failure();
- if (llvm::Function *callee = callInst->getCalledFunction()) {
- callOp = builder.create<CallOp>(
- loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
- operands);
- } else {
- callOp = builder.create<CallOp>(loc, funcTy, operands);
- }
+ auto callOp = [&]() -> CallOp {
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+ auto name = SymbolRefAttr::get(context, callee->getName());
+ return builder.create<CallOp>(loc, funcTy, name, operands);
+ }
+ return builder.create<CallOp>(loc, funcTy, operands);
+ }();
+
+ // Handle function attributes.
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);
- // Handle function attributes.
- if (callInst->hasFnAttr(llvm::Attribute::Convergent))
- callOp.setConvergent(true);
- if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
- callOp.setNoUnwind(true);
- if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
- callOp.setWillReturn(true);
+ callOp.setConvergent(callInst->isConvergent());
+ callOp.setNoUnwind(callInst->doesNotThrow());
+ callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..47f821c8d29909 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,41 @@ define void @varargs_call(i32 %0) {
; // -----
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+ ; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+ call void @varargs(i32 %0)
+ ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+ ; CHECK: llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+ call void @varargs()
+ ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @undef_call
+define void @undef_call() {
+ ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : !llvm.ptr
+ ; CHECK-NEXT: %[[CONST:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+ ; CHECK-NEXT: llvm.call %[[UNDEF]](%[[CONST]]) : !llvm.ptr, (i32) -> ()
+ call void undef(i32 0)
+ ret void
+}
+; // -----
+
; CHECK: llvm.func @f()
declare void @f()
|
@llvm/pr-subscribers-mlir-llvm Author: Henrich Lauko (xlauko) ChangesPreviously, an indirect call was incorrectly generated when Full diff: https://github.com/llvm/llvm-project/pull/124286.diff 2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..b8f66357d9250e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
if (!callInst->getType()->isVoidTy())
types.push_back(convertType(callInst->getType()));
- if (!callInst->getCalledFunction()) {
- if (!allowInlineAsm ||
- !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
- FailureOr<Value> called = convertValue(callInst->getCalledOperand());
- if (failed(called))
- return failure();
- operands.push_back(*called);
- }
+ bool isInlineAsm = callInst->isInlineAsm();
+ if (isInlineAsm && !allowInlineAsm)
+ return failure();
+
+ // Cannot use isIndirectCall() here because we need to handle Constant callees
+ // that are not considered indirect calls by LLVM. However, in MLIR, they are
+ // treated as indirect calls to constant operands that need to be converted.
+ // Skip the callee operand if it's inline assembly, as it's handled separately
+ // in InlineAsmOp.
+ if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+ FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+ if (failed(called))
+ return failure();
+ operands.push_back(*called);
}
+
SmallVector<llvm::Value *> args(callInst->args());
FailureOr<SmallVector<Value>> arguments = convertValues(args);
if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
return success();
}
if (inst->getOpcode() == llvm::Instruction::Call) {
- auto *callInst = cast<llvm::CallInst>(inst);
+ auto callInst = cast<llvm::CallInst>(inst);
+ auto calledOperand = callInst->getCalledOperand();
SmallVector<Type> types;
SmallVector<Value> operands;
@@ -1601,14 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
/*allowInlineAsm=*/true)))
return failure();
- auto funcTy =
- dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
- if (!funcTy)
- return failure();
-
- if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+ if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+ auto resultTy = convertType(callInst->getType());
+ if (!resultTy)
+ return failure();
auto callOp = builder.create<InlineAsmOp>(
- loc, funcTy.getReturnType(), operands,
+ loc, resultTy, operands,
builder.getStringAttr(asmI->getAsmString()),
builder.getStringAttr(asmI->getConstraintString()),
/*has_side_effects=*/true,
@@ -1619,27 +1625,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
else
mapNoResultOp(inst, callOp);
} else {
- CallOp callOp;
+ auto funcTy = dyn_cast<LLVMFunctionType>([&] () -> Type {
+ // Retrieve the real function type. For direct calls, use the callee's
+ // function type, as it may differ from the operand type in the case of
+ // variadic functions. For indirect calls, use the call function type.
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+ return convertType(callee->getFunctionType());
+ return convertType(callInst->getFunctionType());
+ }() );
+
+ if (!funcTy)
+ return failure();
- if (llvm::Function *callee = callInst->getCalledFunction()) {
- callOp = builder.create<CallOp>(
- loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
- operands);
- } else {
- callOp = builder.create<CallOp>(loc, funcTy, operands);
- }
+ auto callOp = [&]() -> CallOp {
+ if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+ auto name = SymbolRefAttr::get(context, callee->getName());
+ return builder.create<CallOp>(loc, funcTy, name, operands);
+ }
+ return builder.create<CallOp>(loc, funcTy, operands);
+ }();
+
+ // Handle function attributes.
callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
callOp.setTailCallKind(
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
setFastmathFlagsAttr(inst, callOp);
- // Handle function attributes.
- if (callInst->hasFnAttr(llvm::Attribute::Convergent))
- callOp.setConvergent(true);
- if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
- callOp.setNoUnwind(true);
- if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
- callOp.setWillReturn(true);
+ callOp.setConvergent(callInst->isConvergent());
+ callOp.setNoUnwind(callInst->doesNotThrow());
+ callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..47f821c8d29909 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,41 @@ define void @varargs_call(i32 %0) {
; // -----
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+ ; CHECK: llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+ call void @varargs(i32 %0)
+ ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+ ; CHECK: llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+ call void @varargs()
+ ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @undef_call
+define void @undef_call() {
+ ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : !llvm.ptr
+ ; CHECK-NEXT: %[[CONST:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+ ; CHECK-NEXT: llvm.call %[[UNDEF]](%[[CONST]]) : !llvm.ptr, (i32) -> ()
+ call void undef(i32 0)
+ ret void
+}
+; // -----
+
; CHECK: llvm.func @f()
declare void @f()
|
✅ With the latest revision this PR passed the undef deprecator. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
Previously, the translation for the following LLVM IR: define i32 @bar() {
entry:
%call = call i32 @foo()
ret i32 %call
}
declare i32 @foo(...) was generating an indirect call like this in MLIR: llvm.func @bar() -> i32 {
%0 = llvm.mlir.addressof @foo : !llvm.ptr
%1 = llvm.call %0() : !llvm.ptr, () -> i32
llvm.return %1 : i32
}
llvm.func @foo(...) -> i32 However, this translation was incorrect. The fixed translation now generates the correct MLIR: llvm.func @bar() -> i32 {
%0 = llvm.call @foo() vararg(!llvm.func<i32 (...)>) : () -> i32
llvm.return %0 : i32
}
llvm.func @foo(...) -> i32 |
adc300e
to
3fce659
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix. LGTM % minor nits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for the fix!
LGTM modulo some ultra nits.
Previously, an indirect call was incorrectly generated when `llvm::CallBase::getCalledFunction` returned null due to a type mismatch between the call and the function. This patch updates the code to use `llvm::CallBase::getCalledOperand` instead.
This resolves the same issue addressed in llvm#124286 for call operations and refactors the common conversion code for both call and invoke instructions.
This resolves the same issue addressed in llvm#124286 for call operations and refactors the common conversion code for both call and invoke instructions.
This resolves the same issue addressed in llvm#124286, but for invoke operations. The issue arose from duplicated logic for both imports. This PR also refactors the common import code for call and invoke instructions to mitigate issues in the future.
This resolves the same issue addressed in #124286, but for invoke operations. The issue arose from duplicated logic for both imports. This PR also refactors the common import code for call and invoke instructions to mitigate issues in the future.
…s (#124828) This resolves the same issue addressed in llvm/llvm-project#124286, but for invoke operations. The issue arose from duplicated logic for both imports. This PR also refactors the common import code for call and invoke instructions to mitigate issues in the future.
Previously, an indirect call was incorrectly generated when
llvm::CallBase::getCalledFunction
returned null due to a type mismatch between the call and the function. This patch updates the code to usellvm::CallBase::getCalledOperand
instead.