[NVPTX] Convert calls to indirect when call signature mismatches function signature #107644
Conversation
@llvm/pr-subscribers-backend-nvptx

Author: Kevin McAfee (kalxr)

Changes

When at least one of the return type, parameter type, or parameter count mismatches between a call instruction and the callee, lower the call to an indirect call. The current behavior is to produce direct calls that may or may not be valid PTX. Consider the following example with mismatching return types:
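A minimal sketch of the IR in question (reconstructed to match the discussion below; the exact function bodies are assumptions, only the `%struct.64` vs. `i64` return-type split matters):

```llvm
%struct.64 = type <{ i64 }>
declare %struct.64 @callee(ptr)

define void @test(ptr %p) {
  ; Return type matches the declaration of @callee.
  %call1 = call %struct.64 @callee(ptr %p)
  ; Return type mismatches the declaration (i64 vs. %struct.64),
  ; even though both are 8 bytes wide.
  %call2 = call i64 @callee(ptr %p)
  ret void
}
```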
The return type of `callee` in PTX is `.b8 _[8]`. The return type of `%call1` will be the same, so the PTX has no problems. The return type of `%call2` will be `.b64`, so the types will not match and the PTX will be unacceptable to ptxas, despite all the types having the same size. The same is true for mismatching parameter types.

If we instead convert these calls to indirect calls, we will generate functional PTX when the types have the same size. If they do not have the same size then the PTX will be incorrect, though this will not necessarily be caught by ptxas. This change allows for more flexibility in the bitcode that can be lowered to functioning PTX, at the cost of sometimes producing PTX that is less clearly wrong than it would have been previously (i.e. incorrect indirect calls are not as obviously wrong as incorrect direct calls).

Full diff: https://github.com/llvm/llvm-project/pull/107644.diff

3 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb76ffdfd99d7b..726493ccaa2569 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1657,6 +1657,33 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
return RetVal;
}
+static bool shouldConvertToIndirectCall(bool IsVarArg, unsigned ParamCount,
+ NVPTXTargetLowering::ArgListTy &Args,
+ const CallBase *CB,
+ GlobalAddressSDNode *Func) {
+ if (!Func)
+ return false;
+ auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal());
+ if (!CalleeFunc)
+ return false;
+
+ auto ActualReturnType = CalleeFunc->getReturnType();
+ if (CB->getType() != ActualReturnType)
+ return true;
+
+ if (IsVarArg)
+ return false;
+
+ auto ActualNumParams = CalleeFunc->getFunctionType()->getNumParams();
+ if (ParamCount != ActualNumParams)
+ return true;
+ for (const Argument &I : CalleeFunc->args())
+ if (I.getType() != Args[I.getArgNo()].Ty)
+ return true;
+
+ return false;
+}
+
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {
@@ -1971,10 +1998,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
VADeclareParam->getVTList(), DeclareParamOps);
}
+ // If the param count, type of any param, or return type of the callsite
+ // mismatches with that of the function signature, convert the callsite to an
+ // indirect call.
+ bool ConvertToIndirectCall =
+ shouldConvertToIndirectCall(CLI.IsVarArg, ParamCount, Args, CB, Func);
+
// Both indirect calls and libcalls have nullptr Func. In order to distinguish
// between them we must rely on the call site value which is valid for
// indirect calls but is always null for libcalls.
- bool isIndirectCall = !Func && CB;
+ bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
if (isa<ExternalSymbolSDNode>(Callee)) {
Function* CalleeFunc = nullptr;
@@ -2026,6 +2059,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
InGlue = Chain.getValue(1);
+ if (ConvertToIndirectCall) {
+ // Copy the function ptr to a ptx register and use the register to call the
+ // function.
+ EVT DestVT = Callee.getValueType();
+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ unsigned DestReg =
+ RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
+ auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
+ Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
+ }
+
// Ops to print out the function name
SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue CallVoidOps[] = { Chain, Callee, InGlue };
diff --git a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
index c5f7bd1bd1ba20..bd723a296e620f 100644
--- a/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
+++ b/llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll
@@ -17,8 +17,8 @@ target triple = "nvptx64-nvidia-cuda"
; CHECK: st.param.b16 [param2+0], %rs1;
; CHECK: st.param.b16 [param2+2], %rs2;
; CHECK: .param .align 2 .b8 retval0[4];
-; CHECK: call.uni (retval0),
-; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
+; CHECK-NEXT: prototype_0 : .callprototype (.param .align 2 .b8 _[4]) _ (.param .b32 _, .param .b32 _, .param .align 2 .b8 _[4]);
+; CHECK-NEXT: call (retval0),
define weak_odr void @foo() {
entry:
%call.i.i.i = tail call %"class.complex" @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32 0, i32 0, ptr byval(%"class.complex") null)
diff --git a/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
new file mode 100644
index 00000000000000..2602c3b0d041b5
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll
@@ -0,0 +1,89 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
+
+%struct.64 = type <{ i64 }>
+declare i64 @callee(ptr %p);
+declare i64 @callee_variadic(ptr %p, ...);
+
+define %struct.64 @test_return_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch(
+; CHECK: .param .align 1 .b8 retval0[8];
+; CHECK-NEXT: prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT: call (retval0),
+; CHECK-NEXT: %rd
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: )
+; CHECK-NEXT: , prototype_0;
+ %ret = call %struct.64 @callee(ptr %p)
+ ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch(
+; CHECK: .param .b64 retval0;
+; CHECK-NEXT: prototype_1 : .callprototype (.param .b64 _) _ (.param .b64 _);
+; CHECK-NEXT: call (retval0),
+; CHECK-NEXT: %rd
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: )
+; CHECK-NEXT: , prototype_1;
+ %ret = call i64 @callee(i64 7)
+ ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch(
+; CHECK: .param .b64 retval0;
+; CHECK-NEXT: prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
+; CHECK-NEXT: call (retval0),
+; CHECK-NEXT: %rd
+; CHECK-NEXT: (
+; CHECK-NEXT: param0,
+; CHECK-NEXT: param1
+; CHECK-NEXT: )
+; CHECK-NEXT: , prototype_2;
+ %ret = call i64 @callee(ptr %p, i64 7)
+ ret i64 %ret
+}
+
+define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_return_type_mismatch_variadic(
+; CHECK: .param .align 1 .b8 retval0[8];
+; CHECK-NEXT: prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
+; CHECK-NEXT: call (retval0),
+; CHECK-NEXT: %rd
+; CHECK-NEXT: (
+; CHECK-NEXT: param0
+; CHECK-NEXT: )
+; CHECK-NEXT: , prototype_3;
+ %ret = call %struct.64 (ptr, ...) @callee_variadic(ptr %p)
+ ret %struct.64 %ret
+}
+
+define i64 @test_param_type_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_type_mismatch_variadic(
+; CHECK: .param .b64 retval0;
+; CHECK-NEXT: call.uni (retval0),
+; CHECK-NEXT: callee_variadic
+; CHECK-NEXT: (
+; CHECK-NEXT: param0,
+; CHECK-NEXT: param1
+; CHECK-NEXT: )
+ %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+ ret i64 %ret
+}
+
+define i64 @test_param_count_mismatch_variadic(ptr %p) {
+; CHECK-LABEL: test_param_count_mismatch_variadic(
+; CHECK: .param .b64 retval0;
+; CHECK-NEXT: call.uni (retval0),
+; CHECK-NEXT: callee_variadic
+; CHECK-NEXT: (
+; CHECK-NEXT: param0,
+; CHECK-NEXT: param1
+; CHECK-NEXT: )
+ %ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
+ ret i64 %ret
+}
Hmm. I wasn't aware that one is allowed to do this in LLVM. In your example all variants expect to see 64 bits of data returned, but what's supposed to happen if I decide to use a return type of a different size? LLVM IR appears to enforce a prototype match in some places, but apparently not on these call sites.

Before we proceed with the patch, I want to understand what's the right thing to do here.
I should have included the very relevant part of the PTX programming guide on matching call prototypes here. In short: we would generate PTX with undefined behavior. The same is true in some cases even when the type sizes match, because an indirect call whose prototype disagrees with the callee's actual signature is also undefined. I was also unable to find much documentation beyond what you mentioned. I think an argument could be made that since the behavior of LLVM IR with these mismatches doesn't seem particularly well defined, the issue is with the input and it is okay to generate code that is undefined as well.

I tested a bunch of these cases with other backends (x86, aarch64, riscv64) and they don't seem to do anything special; they just follow the calling convention. This patch brings NVPTX closer to that behavior by effectively falling back on a particular calling convention when we know that the more abstract direct call will fail. I don't know that this behavior is ideal, but that's part of the motivation behind the change.
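A concrete illustration of the size-mismatch case (a hypothetical sketch, not one of the patch's tests):

```llvm
declare i32 @callee_i32(ptr)

define i64 @size_mismatch(ptr %p) {
  ; The call site expects 64 bits of return data while @callee_i32 only
  ; produces 32. After conversion to an indirect call this lowers to PTX
  ; with undefined behavior, and ptxas will not necessarily reject it.
  %r = call i64 @callee_i32(ptr %p)
  ret i64 %r
}
```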
Unless we find a good reason to make an extra effort to make these mismatches work (in an undefined way, at that), my preferred choice would be to fail. Debugging such a mismatch in GPU code that leads to subtle errors at runtime will be a rather terrible experience, and I'd rather catch it at compile time. However, the fact that LLVM does not complain about such a mismatch suggests that I may be missing something.

@nikic -- would you happen to have thoughts/suggestions on this? Do we need to handle mismatches between the function prototype and the call site, or can we consider them to be an error?
This is a complex way to spell `CB->getFunctionType() != CalleeFunc->getFunctionType()`. Or even just `!CB->getCalledFunction()`.
`CB->getCalledFunction()` is a bit different, because `CB->getCalledOperand()` cannot be cast to a function in every case where `Func->getGlobal()` can; NVPTX/alias.ll has a case of this with function pointer aliasing. Comparing the function types directly makes sense though, thanks for the suggestion.
@Artem-B At a high level, calls with mismatching function types are usually undefined behavior, but not malformed, as they may legally occur in dead code. Generating an error for such calls is not appropriate (at least normally; maybe the PTX programming model is different).

When it comes to the details, it's a bit more complicated, and we don't actually know what the precise semantics are. If the called function is a declaration, the declaration may not match the actual definition. This may be UB in some cases, but there are at least some carve-outs, in particular calling an "unprototyped" function in C.

Generally, treating mismatches between the call-site function type and the callee's function type as indirect calls is a good baseline behavior, which is also why `CallBase::getCalledFunction()` returns null in such cases, so you usually get that behavior by default.
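For example (a minimal hypothetical sketch): the verifier accepts a mismatched call site sitting in dead code, so rejecting it outright at compile time would be too strict.

```llvm
%struct.64 = type <{ i64 }>
declare %struct.64 @callee(ptr)

define i64 @mismatch_in_dead_code(ptr %p) {
entry:
  br i1 false, label %dead, label %live

dead:
  ; UB if ever executed, but well-formed IR. The mismatched signature also
  ; makes CallBase::getCalledFunction() return null for this call.
  %r = call i64 @callee(ptr %p)
  ret i64 %r

live:
  ret i64 0
}
```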
@nikic Thank you. That was very helpful.
LGTM with a nit.