Skip to content

[MLIR] Fix import of calls with mismatched variadic types #124286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 24, 2025
Merged

Conversation

xlauko
Copy link
Contributor

@xlauko xlauko commented Jan 24, 2025

Previously, an indirect call was incorrectly generated when llvm::CallBase::getCalledFunction returned null due to a type mismatch between the call and the function. This patch updates the code to use llvm::CallBase::getCalledOperand instead.

@llvmbot
Copy link
Member

llvmbot commented Jan 24, 2025

@llvm/pr-subscribers-mlir

Author: Henrich Lauko (xlauko)

Changes

Previously, an indirect call was incorrectly generated when llvm::CallBase::getCalledFunction returned null due to a type mismatch between the call and the function. This patch updates the code to use llvm::CallBase::getCalledOperand instead.


Full diff: https://github.com/llvm/llvm-project/pull/124286.diff

2 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+45-31)
  • (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+35)
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..b8f66357d9250e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
-  if (!callInst->getCalledFunction()) {
-    if (!allowInlineAsm ||
-        !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
-      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-      if (failed(called))
-        return failure();
-      operands.push_back(*called);
-    }
+  bool isInlineAsm = callInst->isInlineAsm();
+  if (isInlineAsm && !allowInlineAsm)
+    return failure();
+
+  // Cannot use isIndirectCall() here because we need to handle Constant callees
+  // that are not considered indirect calls by LLVM.  However, in MLIR, they are
+  // treated as indirect calls to constant operands that need to be converted.
+  // Skip the callee operand if it's inline assembly, as it's handled separately
+  // in InlineAsmOp.
+  if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+    if (failed(called))
+      return failure();
+    operands.push_back(*called);
   }
+
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto *callInst = cast<llvm::CallInst>(inst);
+    auto callInst = cast<llvm::CallInst>(inst);
+    auto calledOperand = callInst->getCalledOperand();
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
@@ -1601,14 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                           /*allowInlineAsm=*/true)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
-    if (!funcTy)
-      return failure();
-
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      auto resultTy = convertType(callInst->getType());
+      if (!resultTy)
+        return failure();
       auto callOp = builder.create<InlineAsmOp>(
-          loc, funcTy.getReturnType(), operands,
+          loc, resultTy, operands,
           builder.getStringAttr(asmI->getAsmString()),
           builder.getStringAttr(asmI->getConstraintString()),
           /*has_side_effects=*/true,
@@ -1619,27 +1625,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       else
         mapNoResultOp(inst, callOp);
     } else {
-      CallOp callOp;
+      auto funcTy = dyn_cast<LLVMFunctionType>([&] () -> Type {
+        // Retrieve the real function type. For direct calls, use the callee's
+        // function type, as it may differ from the operand type in the case of
+        // variadic functions. For indirect calls, use the call function type.
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+          return convertType(callee->getFunctionType());
+        return convertType(callInst->getFunctionType());
+      }() );
+
+      if (!funcTy)
+        return failure();
 
-      if (llvm::Function *callee = callInst->getCalledFunction()) {
-        callOp = builder.create<CallOp>(
-            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-            operands);
-      } else {
-        callOp = builder.create<CallOp>(loc, funcTy, operands);
-      }
+      auto callOp = [&]() -> CallOp {
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+          auto name = SymbolRefAttr::get(context, callee->getName());
+          return builder.create<CallOp>(loc, funcTy, name, operands);
+        }
+        return builder.create<CallOp>(loc, funcTy, operands);
+      }();
+
+      // Handle function attributes.
       callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
       callOp.setTailCallKind(
           convertTailCallKindFromLLVM(callInst->getTailCallKind()));
       setFastmathFlagsAttr(inst, callOp);
 
-      // Handle function attributes.
-      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-        callOp.setConvergent(true);
-      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-        callOp.setNoUnwind(true);
-      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-        callOp.setWillReturn(true);
+      callOp.setConvergent(callInst->isConvergent());
+      callOp.setNoUnwind(callInst->doesNotThrow());
+      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
 
       llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
       ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..47f821c8d29909 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,41 @@ define void @varargs_call(i32 %0) {
 
 ; // -----
 
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+  ; CHECK:  llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+  call void @varargs(i32 %0)
+  ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+  ; CHECK:  llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+  call void @varargs()
+  ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @undef_call
+define void @undef_call() {
+  ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : !llvm.ptr
+  ; CHECK-NEXT: %[[CONST:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+  ; CHECK-NEXT: llvm.call %[[UNDEF]](%[[CONST]]) : !llvm.ptr, (i32) -> ()
+  call void undef(i32 0)
+  ret void
+}
+; // -----
+
 ; CHECK: llvm.func @f()
 declare void @f()
 

@llvmbot
Copy link
Member

llvmbot commented Jan 24, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Henrich Lauko (xlauko)

Changes

Previously, an indirect call was incorrectly generated when llvm::CallBase::getCalledFunction returned null due to a type mismatch between the call and the function. This patch updates the code to use llvm::CallBase::getCalledOperand instead.


Full diff: https://github.com/llvm/llvm-project/pull/124286.diff

2 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+45-31)
  • (modified) mlir/test/Target/LLVMIR/Import/instructions.ll (+35)
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index f6826a2362bfdf..b8f66357d9250e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -1495,15 +1495,22 @@ LogicalResult ModuleImport::convertCallTypeAndOperands(
   if (!callInst->getType()->isVoidTy())
     types.push_back(convertType(callInst->getType()));
 
-  if (!callInst->getCalledFunction()) {
-    if (!allowInlineAsm ||
-        !isa<llvm::InlineAsm>(callInst->getCalledOperand())) {
-      FailureOr<Value> called = convertValue(callInst->getCalledOperand());
-      if (failed(called))
-        return failure();
-      operands.push_back(*called);
-    }
+  bool isInlineAsm = callInst->isInlineAsm();
+  if (isInlineAsm && !allowInlineAsm)
+    return failure();
+
+  // Cannot use isIndirectCall() here because we need to handle Constant callees
+  // that are not considered indirect calls by LLVM.  However, in MLIR, they are
+  // treated as indirect calls to constant operands that need to be converted.
+  // Skip the callee operand if it's inline assembly, as it's handled separately
+  // in InlineAsmOp.
+  if (!isa<llvm::Function>(callInst->getCalledOperand()) && !isInlineAsm) {
+    FailureOr<Value> called = convertValue(callInst->getCalledOperand());
+    if (failed(called))
+      return failure();
+    operands.push_back(*called);
   }
+
   SmallVector<llvm::Value *> args(callInst->args());
   FailureOr<SmallVector<Value>> arguments = convertValues(args);
   if (failed(arguments))
@@ -1593,7 +1600,8 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
     return success();
   }
   if (inst->getOpcode() == llvm::Instruction::Call) {
-    auto *callInst = cast<llvm::CallInst>(inst);
+    auto callInst = cast<llvm::CallInst>(inst);
+    auto calledOperand = callInst->getCalledOperand();
 
     SmallVector<Type> types;
     SmallVector<Value> operands;
@@ -1601,14 +1609,12 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
                                           /*allowInlineAsm=*/true)))
       return failure();
 
-    auto funcTy =
-        dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
-    if (!funcTy)
-      return failure();
-
-    if (auto asmI = dyn_cast<llvm::InlineAsm>(callInst->getCalledOperand())) {
+    if (auto asmI = dyn_cast<llvm::InlineAsm>(calledOperand)) {
+      auto resultTy = convertType(callInst->getType());
+      if (!resultTy)
+        return failure();
       auto callOp = builder.create<InlineAsmOp>(
-          loc, funcTy.getReturnType(), operands,
+          loc, resultTy, operands,
           builder.getStringAttr(asmI->getAsmString()),
           builder.getStringAttr(asmI->getConstraintString()),
           /*has_side_effects=*/true,
@@ -1619,27 +1625,35 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
       else
         mapNoResultOp(inst, callOp);
     } else {
-      CallOp callOp;
+      auto funcTy = dyn_cast<LLVMFunctionType>([&] () -> Type {
+        // Retrieve the real function type. For direct calls, use the callee's
+        // function type, as it may differ from the operand type in the case of
+        // variadic functions. For indirect calls, use the call function type.
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand))
+          return convertType(callee->getFunctionType());
+        return convertType(callInst->getFunctionType());
+      }() );
+
+      if (!funcTy)
+        return failure();
 
-      if (llvm::Function *callee = callInst->getCalledFunction()) {
-        callOp = builder.create<CallOp>(
-            loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
-            operands);
-      } else {
-        callOp = builder.create<CallOp>(loc, funcTy, operands);
-      }
+      auto callOp = [&]() -> CallOp {
+        if (auto callee = dyn_cast<llvm::Function>(calledOperand)) {
+          auto name = SymbolRefAttr::get(context, callee->getName());
+          return builder.create<CallOp>(loc, funcTy, name, operands);
+        }
+        return builder.create<CallOp>(loc, funcTy, operands);
+      }();
+
+      // Handle function attributes.
       callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
       callOp.setTailCallKind(
           convertTailCallKindFromLLVM(callInst->getTailCallKind()));
       setFastmathFlagsAttr(inst, callOp);
 
-      // Handle function attributes.
-      if (callInst->hasFnAttr(llvm::Attribute::Convergent))
-        callOp.setConvergent(true);
-      if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
-        callOp.setNoUnwind(true);
-      if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
-        callOp.setWillReturn(true);
+      callOp.setConvergent(callInst->isConvergent());
+      callOp.setNoUnwind(callInst->doesNotThrow());
+      callOp.setWillReturn(callInst->hasFnAttr(llvm::Attribute::WillReturn));
 
       llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
       ModRefInfo othermem = convertModRefInfoFromLLVM(
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index 7377e2584110b5..47f821c8d29909 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -570,6 +570,41 @@ define void @varargs_call(i32 %0) {
 
 ; // -----
 
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @varargs_call
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+define void @varargs_call(i32 %0) {
+  ; CHECK:  llvm.call @varargs(%[[ARG1]]) vararg(!llvm.func<void (...)>) : (i32) -> ()
+  call void @varargs(i32 %0)
+  ret void
+}
+
+; // -----
+
+; CHECK: @varargs(...)
+declare void @varargs(...)
+
+; CHECK-LABEL: @empty_varargs_call
+define void @empty_varargs_call() {
+  ; CHECK:  llvm.call @varargs() vararg(!llvm.func<void (...)>) : () -> ()
+  call void @varargs()
+  ret void
+}
+
+; // -----
+
+; CHECK-LABEL: @undef_call
+define void @undef_call() {
+  ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : !llvm.ptr
+  ; CHECK-NEXT: %[[CONST:[0-9]+]] = llvm.mlir.constant(0 : i32) : i32
+  ; CHECK-NEXT: llvm.call %[[UNDEF]](%[[CONST]]) : !llvm.ptr, (i32) -> ()
+  call void undef(i32 0)
+  ret void
+}
+; // -----
+
 ; CHECK: llvm.func @f()
 declare void @f()
 

Copy link

github-actions bot commented Jan 24, 2025

✅ With the latest revision this PR passed the undef deprecator.

Copy link

github-actions bot commented Jan 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@xlauko
Copy link
Contributor Author

xlauko commented Jan 24, 2025

Previously, the translation for the following LLVM IR:

define i32 @bar() {
entry:
  %call = call i32 @foo()
  ret i32 %call
}

declare i32 @foo(...)

was generating an indirect call like this in MLIR:

llvm.func @bar() -> i32 {
    %0 = llvm.mlir.addressof @foo : !llvm.ptr
    %1 = llvm.call %0() : !llvm.ptr, () -> i32
    llvm.return %1 : i32
}
llvm.func @foo(...) -> i32

However, this translation was incorrect. The fixed translation now generates the correct MLIR:

llvm.func @bar() -> i32 {
    %0 = llvm.call @foo() vararg(!llvm.func<i32 (...)>) : () -> i32
    llvm.return %0 : i32
}
llvm.func @foo(...) -> i32

@xlauko xlauko force-pushed the main branch 3 times, most recently from adc300e to 3fce659 Compare January 24, 2025 15:43
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix. LGTM % minor nits.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for the fix!

LGTM modulo some ultra nits.

Previously, an indirect call was incorrectly generated when `llvm::CallBase::getCalledFunction` returned null due to a type mismatch between the call and the function. This patch updates the code to use `llvm::CallBase::getCalledOperand` instead.
@gysit gysit merged commit 95d993a into llvm:main Jan 24, 2025
8 checks passed
xlauko added a commit to xlauko/llvm-project that referenced this pull request Jan 28, 2025
This resolves the same issue addressed in llvm#124286 for call operations and refactors the common conversion code for both call and invoke instructions.
xlauko added a commit to xlauko/llvm-project that referenced this pull request Jan 29, 2025
This resolves the same issue addressed in llvm#124286 for call operations and refactors the common conversion code for both call and invoke instructions.
xlauko added a commit to xlauko/llvm-project that referenced this pull request Jan 29, 2025
This resolves the same issue addressed in llvm#124286, but for invoke
operations. The issue arose from duplicated logic for both imports. This
PR also refactors the common import code for call and invoke
instructions to mitigate issues in the future.
gysit pushed a commit that referenced this pull request Jan 29, 2025
This resolves the same issue addressed in
#124286, but for invoke
operations. The issue arose from duplicated logic for both imports. This
PR also refactors the common import code for call and invoke
instructions to mitigate issues in the future.
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 29, 2025
…s (#124828)

This resolves the same issue addressed in
llvm/llvm-project#124286, but for invoke
operations. The issue arose from duplicated logic for both imports. This
PR also refactors the common import code for call and invoke
instructions to mitigate issues in the future.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants