Skip to content

Commit 4cc77c8

Browse files
committed
[NVPTX] Convert calls to indirect when call signature mismatches function signature
1 parent 941841b commit 4cc77c8

File tree

3 files changed

+137
-3
lines changed

3 files changed

+137
-3
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1657,6 +1657,33 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
16571657
return RetVal;
16581658
}
16591659

1660+
static bool shouldConvertToIndirectCall(bool IsVarArg, unsigned ParamCount,
1661+
NVPTXTargetLowering::ArgListTy &Args,
1662+
const CallBase *CB,
1663+
GlobalAddressSDNode *Func) {
1664+
if (!Func)
1665+
return false;
1666+
auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal());
1667+
if (!CalleeFunc)
1668+
return false;
1669+
1670+
auto ActualReturnType = CalleeFunc->getReturnType();
1671+
if (CB->getType() != ActualReturnType)
1672+
return true;
1673+
1674+
if (IsVarArg)
1675+
return false;
1676+
1677+
auto ActualNumParams = CalleeFunc->getFunctionType()->getNumParams();
1678+
if (ParamCount != ActualNumParams)
1679+
return true;
1680+
for (const Argument &I : CalleeFunc->args())
1681+
if (I.getType() != Args[I.getArgNo()].Ty)
1682+
return true;
1683+
1684+
return false;
1685+
}
1686+
16601687
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16611688
SmallVectorImpl<SDValue> &InVals) const {
16621689

@@ -1971,10 +1998,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
19711998
VADeclareParam->getVTList(), DeclareParamOps);
19721999
}
19732000

2001+
// If the param count, type of any param, or return type of the callsite
2002+
// mismatches with that of the function signature, convert the callsite to an
2003+
// indirect call.
2004+
bool ConvertToIndirectCall =
2005+
shouldConvertToIndirectCall(CLI.IsVarArg, ParamCount, Args, CB, Func);
2006+
19742007
// Both indirect calls and libcalls have nullptr Func. In order to distinguish
19752008
// between them we must rely on the call site value which is valid for
19762009
// indirect calls but is always null for libcalls.
1977-
bool isIndirectCall = !Func && CB;
2010+
bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
19782011

19792012
if (isa<ExternalSymbolSDNode>(Callee)) {
19802013
Function* CalleeFunc = nullptr;
@@ -2026,6 +2059,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
20262059
Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
20272060
InGlue = Chain.getValue(1);
20282061

2062+
if (ConvertToIndirectCall) {
2063+
// Copy the function ptr to a ptx register and use the register to call the
2064+
// function.
2065+
EVT DestVT = Callee.getValueType();
2066+
MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
2067+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2068+
unsigned DestReg =
2069+
RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
2070+
auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
2071+
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
2072+
}
2073+
20292074
// Ops to print out the function name
20302075
SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
20312076
SDValue CallVoidOps[] = { Chain, Callee, InGlue };

llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ target triple = "nvptx64-nvidia-cuda"
1717
; CHECK: st.param.b16 [param2+0], %rs1;
1818
; CHECK: st.param.b16 [param2+2], %rs2;
1919
; CHECK: .param .align 2 .b8 retval0[4];
20-
; CHECK: call.uni (retval0),
21-
; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
20+
; CHECK-NEXT: prototype_0 : .callprototype (.param .align 2 .b8 _[4]) _ (.param .b32 _, .param .b32 _, .param .align 2 .b8 _[4]);
21+
; CHECK-NEXT: call (retval0),
2222
define weak_odr void @foo() {
2323
entry:
2424
%call.i.i.i = tail call %"class.complex" @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32 0, i32 0, ptr byval(%"class.complex") null)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
3+
4+
%struct.64 = type <{ i64 }>
5+
declare i64 @callee(ptr %p);
6+
declare i64 @callee_variadic(ptr %p, ...);
7+
8+
; Return type mismatch: the call below uses %struct.64 as its return type,
; while @callee is declared to return i64. The directives expect a
; .callprototype line followed by a call through a %rd register instead of a
; call by symbol name.
define %struct.64 @test_return_type_mismatch(ptr %p) {
9+
; CHECK-LABEL: test_return_type_mismatch(
10+
; CHECK: .param .align 1 .b8 retval0[8];
11+
; CHECK-NEXT: prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
12+
; CHECK-NEXT: call (retval0),
13+
; CHECK-NEXT: %rd
14+
; CHECK-NEXT: (
15+
; CHECK-NEXT: param0
16+
; CHECK-NEXT: )
17+
; CHECK-NEXT: , prototype_0;
18+
%ret = call %struct.64 @callee(ptr %p)
19+
ret %struct.64 %ret
20+
}
21+
22+
; Parameter type mismatch: the call below passes an i64 where @callee is
; declared to take a ptr. The directives expect a .callprototype line followed
; by a call through a %rd register instead of a call by symbol name.
define i64 @test_param_type_mismatch(ptr %p) {
23+
; CHECK-LABEL: test_param_type_mismatch(
24+
; CHECK: .param .b64 retval0;
25+
; CHECK-NEXT: prototype_1 : .callprototype (.param .b64 _) _ (.param .b64 _);
26+
; CHECK-NEXT: call (retval0),
27+
; CHECK-NEXT: %rd
28+
; CHECK-NEXT: (
29+
; CHECK-NEXT: param0
30+
; CHECK-NEXT: )
31+
; CHECK-NEXT: , prototype_1;
32+
%ret = call i64 @callee(i64 7)
33+
ret i64 %ret
34+
}
35+
36+
; Parameter count mismatch: the call below passes two arguments, while @callee
; is declared with a single ptr parameter. The directives expect a
; two-parameter .callprototype line followed by a call through a %rd register
; instead of a call by symbol name.
define i64 @test_param_count_mismatch(ptr %p) {
37+
; CHECK-LABEL: test_param_count_mismatch(
38+
; CHECK: .param .b64 retval0;
39+
; CHECK-NEXT: prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
40+
; CHECK-NEXT: call (retval0),
41+
; CHECK-NEXT: %rd
42+
; CHECK-NEXT: (
43+
; CHECK-NEXT: param0,
44+
; CHECK-NEXT: param1
45+
; CHECK-NEXT: )
46+
; CHECK-NEXT: , prototype_2;
47+
%ret = call i64 @callee(ptr %p, i64 7)
48+
ret i64 %ret
49+
}
50+
51+
; Return type mismatch on a variadic call: the call site type is
; %struct.64 (ptr, ...), while @callee_variadic is declared to return i64.
; The directives expect a .callprototype line followed by a call through a
; %rd register instead of a call by symbol name.
define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
52+
; CHECK-LABEL: test_return_type_mismatch_variadic(
53+
; CHECK: .param .align 1 .b8 retval0[8];
54+
; CHECK-NEXT: prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
55+
; CHECK-NEXT: call (retval0),
56+
; CHECK-NEXT: %rd
57+
; CHECK-NEXT: (
58+
; CHECK-NEXT: param0
59+
; CHECK-NEXT: )
60+
; CHECK-NEXT: , prototype_3;
61+
%ret = call %struct.64 (ptr, ...) @callee_variadic(ptr %p)
62+
ret %struct.64 %ret
63+
}
64+
65+
; Variadic call whose return type matches the declaration of @callee_variadic.
; The fixed argument is a ptr, as declared; the extra i64 is passed through
; the variadic part. The directives expect the call to stay a direct
; call.uni to callee_variadic by name, with two parameters.
define i64 @test_param_type_mismatch_variadic(ptr %p) {
66+
; CHECK-LABEL: test_param_type_mismatch_variadic(
67+
; CHECK: .param .b64 retval0;
68+
; CHECK-NEXT: call.uni (retval0),
69+
; CHECK-NEXT: callee_variadic
70+
; CHECK-NEXT: (
71+
; CHECK-NEXT: param0,
72+
; CHECK-NEXT: param1
73+
; CHECK-NEXT: )
74+
%ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
75+
ret i64 %ret
76+
}
77+
78+
; Variadic call that passes one argument more than the single fixed parameter
; declared for @callee_variadic; the return type matches the declaration.
; The directives expect the call to stay a direct call.uni to callee_variadic
; by name, with two parameters.
; NOTE(review): the call instruction here is identical to the one in
; test_param_type_mismatch_variadic - confirm whether a distinct argument list
; was intended.
define i64 @test_param_count_mismatch_variadic(ptr %p) {
79+
; CHECK-LABEL: test_param_count_mismatch_variadic(
80+
; CHECK: .param .b64 retval0;
81+
; CHECK-NEXT: call.uni (retval0),
82+
; CHECK-NEXT: callee_variadic
83+
; CHECK-NEXT: (
84+
; CHECK-NEXT: param0,
85+
; CHECK-NEXT: param1
86+
; CHECK-NEXT: )
87+
%ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
88+
ret i64 %ret
89+
}

0 commit comments

Comments
 (0)