Skip to content

Commit 4992f78

Browse files
committed
[MLIR][LLVMIR] Relax mismatching calls
LLVM IR currently [accepts](https://godbolt.org/z/nqnEsW1ja): ``` define void @incompatible_call_and_callee_types() { call void @callee(i64 0) ret void } define void @callee({ptr, i64}, i32) { ret void } ``` This currently fails to import. Even though these constructs are dangerous and probably indicate some ODR violation (or optimization bug), they are "valid" and should be imported into LLVM IR dialect. This PR implements that by using an indirect call to represent it. Translation already works nicely and outputs the same source llvm IR file. The error is now a warning, the tests in `mlir/test/Target/LLVMIR/Import/import-failure.ll` already use `CHECK` lines, so no need to add extra diagnostic tests.
1 parent e4f2191 commit 4992f78

File tree

3 files changed

+68
-19
lines changed

3 files changed

+68
-19
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,12 @@ class ModuleImport {
362362
/// Converts the callee's function type. For direct calls, it converts the
363363
/// actual function type, which may differ from the called operand type in
364364
/// variadic functions. For indirect calls, it converts the function type
365-
/// associated with the call instruction. Returns failure when the call and
366-
/// the callee are not compatible or when nested type conversions failed.
367-
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst);
365+
/// associated with the call instruction. When the call and the callee are not
366+
/// compatible (or when nested type conversions failed), emit a warning but
367+
/// attempt translation using a bitcast and an indirect call (in order
368+
/// represent valid and verified LLVM IR).
369+
FailureOr<LLVMFunctionType> convertFunctionType(llvm::CallBase *callInst,
370+
Value &castResult);
368371
/// Returns the callee name, or an empty symbol if the call is not direct.
369372
FlatSymbolRefAttr convertCalleeName(llvm::CallBase *callInst);
370373
/// Converts the parameter and result attributes attached to `func` and adds

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,8 +1668,8 @@ ModuleImport::convertCallOperands(llvm::CallBase *callInst,
16681668
/// Checks if `callType` and `calleeType` are compatible and can be represented
16691669
/// in MLIR.
16701670
static LogicalResult
1671-
verifyFunctionTypeCompatibility(LLVMFunctionType callType,
1672-
LLVMFunctionType calleeType) {
1671+
checkFunctionTypeCompatibility(LLVMFunctionType callType,
1672+
LLVMFunctionType calleeType) {
16731673
if (callType.getReturnType() != calleeType.getReturnType())
16741674
return failure();
16751675

@@ -1695,7 +1695,7 @@ verifyFunctionTypeCompatibility(LLVMFunctionType callType,
16951695
}
16961696

16971697
FailureOr<LLVMFunctionType>
1698-
ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
1698+
ModuleImport::convertFunctionType(llvm::CallBase *callInst, Value &castResult) {
16991699
auto castOrFailure = [](Type convertedType) -> FailureOr<LLVMFunctionType> {
17001700
auto funcTy = dyn_cast_or_null<LLVMFunctionType>(convertedType);
17011701
if (!funcTy)
@@ -1718,11 +1718,17 @@ ModuleImport::convertFunctionType(llvm::CallBase *callInst) {
17181718
if (failed(calleeType))
17191719
return failure();
17201720

1721-
// Compare the types to avoid constructing illegal call/invoke operations.
1722-
if (failed(verifyFunctionTypeCompatibility(*callType, *calleeType))) {
1721+
// Compare the types, if they are not compatible, avoid illegal call/invoke
1722+
// operations by casting to the callsite type and issuing an indirect call.
1723+
// LLVM IR currently supports this usage.
1724+
if (failed(checkFunctionTypeCompatibility(*callType, *calleeType))) {
17231725
Location loc = translateLoc(callInst->getDebugLoc());
1724-
return emitError(loc) << "incompatible call and callee types: " << *callType
1725-
<< " and " << *calleeType;
1726+
FlatSymbolRefAttr calleeSym = convertCalleeName(callInst);
1727+
castResult = builder.create<LLVM::AddressOfOp>(
1728+
loc, LLVM::LLVMPointerType::get(context), calleeSym);
1729+
emitWarning(loc) << "incompatible call and callee types: " << *callType
1730+
<< " and " << *calleeType;
1731+
return callType;
17261732
}
17271733

17281734
return calleeType;
@@ -1839,16 +1845,29 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
18391845
/*operand_attrs=*/nullptr)
18401846
.getOperation();
18411847
}
1842-
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(callInst);
1848+
Value castResult;
1849+
FailureOr<LLVMFunctionType> funcTy =
1850+
convertFunctionType(callInst, castResult);
18431851
if (failed(funcTy))
18441852
return failure();
18451853

1846-
FlatSymbolRefAttr callee = convertCalleeName(callInst);
1847-
auto callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
1854+
FlatSymbolRefAttr callee = nullptr;
1855+
// If no cast is needed, use the original callee name. Otherwise patch
1856+
// operands to include the indirect call target. Build indirect call by
1857+
// passing using a nullptr `callee`.
1858+
if (!castResult)
1859+
callee = convertCalleeName(callInst);
1860+
else
1861+
operands->insert(operands->begin(), castResult);
1862+
CallOp callOp = builder.create<CallOp>(loc, *funcTy, callee, *operands);
1863+
18481864
if (failed(convertCallAttributes(callInst, callOp)))
18491865
return failure();
1850-
// Handle parameter and result attributes.
1851-
convertParameterAttributes(callInst, callOp, builder);
1866+
1867+
// Handle parameter and result attributes. Don't bother if there's a
1868+
// type mismatch.
1869+
if (!castResult)
1870+
convertParameterAttributes(callInst, callOp, builder);
18521871
return callOp.getOperation();
18531872
}();
18541873

@@ -1913,11 +1932,20 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
19131932
unwindArgs)))
19141933
return failure();
19151934

1916-
FailureOr<LLVMFunctionType> funcTy = convertFunctionType(invokeInst);
1935+
Value castResult;
1936+
FailureOr<LLVMFunctionType> funcTy =
1937+
convertFunctionType(invokeInst, castResult);
19171938
if (failed(funcTy))
19181939
return failure();
19191940

1920-
FlatSymbolRefAttr calleeName = convertCalleeName(invokeInst);
1941+
FlatSymbolRefAttr calleeName = nullptr;
1942+
// If no cast is needed, use the original callee name. Otherwise patch
1943+
// operands to include the indirect call target. Build indirect call by
1944+
// passing using a nullptr `callee`.
1945+
if (!castResult)
1946+
calleeName = convertCalleeName(invokeInst);
1947+
else
1948+
operands->insert(operands->begin(), castResult);
19211949

19221950
// Create the invoke operation. Normal destination block arguments will be
19231951
// added later on to handle the case in which the operation result is
@@ -1929,8 +1957,10 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
19291957
if (failed(convertInvokeAttributes(invokeInst, invokeOp)))
19301958
return failure();
19311959

1932-
// Handle parameter and result attributes.
1933-
convertParameterAttributes(invokeInst, invokeOp, builder);
1960+
// Handle parameter and result attributes. Don't bother if there's a
1961+
// type mismatch.
1962+
if (!castResult)
1963+
convertParameterAttributes(invokeInst, invokeOp, builder);
19341964

19351965
if (!invokeInst->getType()->isVoidTy())
19361966
mapValue(inst, invokeOp.getResults().front());

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,3 +720,19 @@ bb2:
720720
declare void @g(...)
721721

722722
declare i32 @__gxx_personality_v0(...)
723+
724+
; // -----
725+
726+
; CHECK-LABEL: llvm.func @incompatible_call_and_callee_types
727+
define void @incompatible_call_and_callee_types() {
728+
; CHECK: %[[CST:.*]] = llvm.mlir.constant(0 : i64) : i64
729+
; CHECK: %[[TARGET:.*]] = llvm.mlir.addressof @callee : !llvm.ptr
730+
; CHECK: llvm.call %[[TARGET]](%[[CST]]) : !llvm.ptr, (i64) -> ()
731+
call void @callee(i64 0)
732+
; CHECK: llvm.return
733+
ret void
734+
}
735+
736+
define void @callee({ptr, i64}, i32) {
737+
ret void
738+
}

0 commit comments

Comments
 (0)