Skip to content

Commit b72c3e3

Browse files
committed
[MLIR][LLVMIR] Relax mismatching calls
LLVM IR currently [accepts](https://godbolt.org/z/nqnEsW1ja): ``` define void @incompatible_call_and_callee_types() { call void @callee(i64 0) ret void } define void @callee({ptr, i64}, i32) { ret void } ``` This currently fails to import. Even though these constructs are dangerous and probably indicate some ODR violation (or optimization bug), they are "valid" and should be imported into LLVM IR dialect. This PR implements that by using an indirect call to represent it. Translation already works nicely and outputs the same source llvm IR file. The error is now a warning, the tests in `mlir/test/Target/LLVMIR/Import/import-failure.ll` already use `CHECK` lines, so no need to add extra diagnostic tests.
1 parent 7f922f1 commit b72c3e3

File tree

3 files changed

+68
-19
lines changed

3 files changed

+68
-19
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,12 @@ class ModuleImport {
362362
/// Converts the callee's function type. For direct calls, it converts the
363363
/// actual function type, which may differ from the called operand type in
364364
/// variadic functions. For indirect calls, it converts the function type
365-
/// associated with the call instruction. Returns failure when the call and
366-
/// the callee are not compatible or when nested type conversions failed.
367-
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
365+
/// associated with the call instruction. When the call and the callee are not
366+
/// compatible (or when nested type conversions failed), emit a warning but
367+
/// attempt translation using a bitcast and an indirect call (in order
368+
/// represent valid and verified LLVM IR).
369+
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
370+
Value &castResult);
368371
/// Returns the callee name, or an empty symbol if the call is not direct.
369372
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
370373
/// Converts the parameter and result attributes attached to `func` and adds

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,8 +1721,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
17211721
/// Checks if `callType` and `calleeType` are compatible and can be represented
17221722
/// in MLIR.
17231723
static LogicalResult
1724-
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
1725-
LLVMFunctionType calleeType) {
1724+
checkFunctionTypeCompatibility(LLVMFunctionType callType,
1725+
LLVMFunctionType calleeType) {
17261726
if (callType.getReturnType() != calleeType.getReturnType())
17271727
return failure();
17281728

@@ -1748,7 +1748,7 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
17481748
}
17491749

17501750
FailureOr<LLVMFunctionType>
1751-
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1751+
ModuleImport::convertFunctionType(llvm::CallBase *callInst, Value &castResult) {
17521752
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
17531753
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
17541754
if (!funcTy)
@@ -1771,11 +1771,17 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
17711771
if (failed(calleeType))
17721772
return failure();
17731773

1774-
// Compare the types to avoid constructing illegal call/invoke operations.
1775-
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
1774+
// Compare the types, if they are not compatible, avoid illegal call/invoke
1775+
// operations by casting to the callsite type and issuing an indirect call.
1776+
// LLVM IR currently supports this usage.
1777+
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
17761778
Location loc = translateLoc(callInst->getDebugLoc());
1777-
return emitError(loc) << "incompatible call and callee types: " << *callType
1778-
<< " and " << *calleeType;
1779+
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
1780+
castResult = builder.create<LLVM::AddressOfOp>(
1781+
loc, LLVM::LLVMPointerType::get(context), calleeSym);
1782+
emitWarning(loc) << "incompatible call and callee types: " << *callType
1783+
<< " and " << *calleeType;
1784+
return callType;
17791785
}
17801786

17811787
return calleeType;
@@ -1892,16 +1898,29 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
18921898
/*operand_attrs=*/nullptr)
18931899
.getOperation();
18941900
}
1895-
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
1901+
Value castResult;
1902+
FailureOr<LLVMFunctionType> funcTy =
1903+
convertFunctionType(callInst, castResult);
18961904
if (failed(funcTy))
18971905
return failure();
18981906

1899-
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1900-
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
1907+
FlatSymbolRefAttr callee = nullptr;
1908+
// If no cast is needed, use the original callee name. Otherwise patch
1909+
// operands to include the indirect call target. Build indirect call by
1910+
// passing using a nullptr `callee`.
1911+
if (!castResult)
1912+
callee = convertCalleeName(callInst);
1913+
else
1914+
operands->insert(operands->begin(), castResult);
1915+
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
1916+
19011917
if (failed(convertCallAttributes(callInst, callOp)))
19021918
return failure();
1903-
// Handle parameter and result attributes.
1904-
convertParameterAttributes(callInst, callOp, builder);
1919+
1920+
// Handle parameter and result attributes. Don't bother if there's a
1921+
// type mismatch.
1922+
if (!castResult)
1923+
convertParameterAttributes(callInst, callOp, builder);
19051924
return callOp.getOperation();
19061925
}();
19071926

@@ -1966,11 +1985,20 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
19661985
unwindArgs)))
19671986
return failure();
19681987

1969-
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
1988+
Value castResult;
1989+
FailureOr<LLVMFunctionType> funcTy =
1990+
convertFunctionType(invokeInst, castResult);
19701991
if (failed(funcTy))
19711992
return failure();
19721993

1973-
FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
1994+
FlatSymbolRefAttr calleeName = nullptr;
1995+
// If no cast is needed, use the original callee name. Otherwise patch
1996+
// operands to include the indirect call target. Build indirect call by
1997+
// passing using a nullptr `callee`.
1998+
if (!castResult)
1999+
calleeName = convertCalleeName(invokeInst);
2000+
else
2001+
operands->insert(operands->begin(), castResult);
19742002

19752003
// Create the invoke operation. Normal destination block arguments will be
19762004
// added later on to handle the case in which the operation result is
@@ -1982,8 +2010,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
19822010
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
19832011
return failure();
19842012

1985-
// Handle parameter and result attributes.
1986-
convertParameterAttributes(invokeInst, invokeOp, builder);
2013+
// Handle parameter and result attributes. Don't bother if there's a
2014+
// type mismatch.
2015+
if (!castResult)
2016+
convertParameterAttributes(invokeInst, invokeOp, builder);
19872017

19882018
if (!invokeInst->getType()->isVoidTy())
19892019
mapValue(inst, invokeOp.getResults().front());

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,19 @@ bb2:
739739
declare void @g(...)
740740

741741
declare i32 @__gxx_personality_v0(...)
742+
743+
; // -----
744+
745+
; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
746+
define void @incompatible_call_and_callee_types() {
747+
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
748+
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
749+
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
750+
call void @callee(i64 0)
751+
; CHECK: llvm.return
752+
ret void
753+
}
754+
755+
define void @callee({ptr, i64}, i32) {
756+
ret void
757+
}

0 commit comments

Comments
 (0)