Skip to content

Commit 57183b6

Browse files
authored
[NVPTX] Add support for stacksave, stackrestore intrinsics (#114484)
Add support for the '`@llvm.stacksave`' and '`@llvm.stackrestore`' intrinsics to NVPTX. These are implemented with the `stacksave` and `stackrestore` PTX instructions respectively. See [PTX ISA 9.7.17. Stack Manipulation Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#stack-manipulation-instructions).
1 parent b24650e commit 57183b6

File tree

4 files changed

+182
-2
lines changed

4 files changed

+182
-2
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/Support/CodeGen.h"
5454
#include "llvm/Support/CommandLine.h"
5555
#include "llvm/Support/ErrorHandling.h"
56+
#include "llvm/Support/NVPTXAddrSpace.h"
5657
#include "llvm/Support/raw_ostream.h"
5758
#include "llvm/Target/TargetMachine.h"
5859
#include "llvm/Target/TargetOptions.h"
@@ -678,8 +679,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
678679
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
679680
setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
680681

681-
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i32, Custom);
682-
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
682+
setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
683+
setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);
683684

684685
// TRAP can be lowered to PTX trap
685686
setOperationAction(ISD::TRAP, MVT::Other, Legal);
@@ -972,6 +973,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
972973
MAKE_CASE(NVPTXISD::PRMT)
973974
MAKE_CASE(NVPTXISD::FCOPYSIGN)
974975
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
976+
MAKE_CASE(NVPTXISD::STACKRESTORE)
977+
MAKE_CASE(NVPTXISD::STACKSAVE)
975978
MAKE_CASE(NVPTXISD::SETP_F16X2)
976979
MAKE_CASE(NVPTXISD::SETP_BF16X2)
977980
MAKE_CASE(NVPTXISD::Dummy)
@@ -2298,6 +2301,54 @@ SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
22982301
return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
22992302
}
23002303

2304+
// Custom lowering for ISD::STACKRESTORE.
//
// Operand 0 of Op is the incoming chain and operand 1 is the pointer to
// restore. When the subtarget reports a PTX version below 73 or an SM version
// below 52, an "unsupported" diagnostic is emitted and only the incoming chain
// is returned. Otherwise the pointer is address-space-cast from generic to
// local and wrapped in an NVPTXISD::STACKRESTORE node whose only result is a
// chain (MVT::Other).
SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
2305+
SelectionDAG &DAG) const {
2306+
SDLoc DL(Op.getNode());
2307+
// Version gate: mirrors the ">= 7.3" / ">= sm_52" wording of the diagnostic.
if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2308+
const Function &Fn = DAG.getMachineFunction().getFunction();
2309+
2310+
DiagnosticInfoUnsupported NoStackRestore(
2311+
Fn,
2312+
"Support for stackrestore requires PTX ISA version >= 7.3 and target "
2313+
">= sm_52.",
2314+
DL.getDebugLoc());
2315+
DAG.getContext()->diagnose(NoStackRestore);
2316+
// Return just the chain (operand 0); no NVPTXISD node is created here.
return Op.getOperand(0);
2317+
}
2318+
2319+
// Pointer value type used for ADDRESS_SPACE_LOCAL under the current
// data layout.
const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2320+
SDValue Chain = Op.getOperand(0);
2321+
SDValue Ptr = Op.getOperand(1);
2322+
// Cast the incoming pointer from the generic to the local address space
// before handing it to the target node.
SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
2323+
ADDRESS_SPACE_LOCAL);
2324+
return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
2325+
}
2326+
2327+
// Custom lowering for ISD::STACKSAVE.
//
// Operand 0 of Op is the incoming chain. When the subtarget reports a PTX
// version below 73 or an SM version below 52, an "unsupported" diagnostic is
// emitted and the node is replaced by the pair {constant 0 of Op's value type,
// incoming chain}. Otherwise an NVPTXISD::STACKSAVE node yielding a
// local-address-space pointer plus a chain is created, and the pointer is
// address-space-cast from local to generic before being returned together with
// the new chain.
SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
2328+
SelectionDAG &DAG) const {
2329+
SDLoc DL(Op.getNode());
2330+
// Version gate: mirrors the ">= 7.3" / ">= sm_52" wording of the diagnostic.
if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
2331+
const Function &Fn = DAG.getMachineFunction().getFunction();
2332+
2333+
DiagnosticInfoUnsupported NoStackSave(
2334+
Fn,
2335+
"Support for stacksave requires PTX ISA version >= 7.3 and target >= "
2336+
"sm_52.",
2337+
DL.getDebugLoc());
2338+
DAG.getContext()->diagnose(NoStackSave);
2339+
// Supply both results (value, chain): a zero constant stands in for the
// saved stack pointer.
auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
2340+
return DAG.getMergeValues(Ops, DL);
2341+
}
2342+
2343+
// Pointer value type used for ADDRESS_SPACE_LOCAL under the current
// data layout.
const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
2344+
SDValue Chain = Op.getOperand(0);
2345+
// Result 0 of SS is the local-space pointer, result 1 is the chain.
SDValue SS =
2346+
DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
2347+
// Convert the local-space pointer to the generic address space, using the
// value type of the original node.
SDValue ASC = DAG.getAddrSpaceCast(
2348+
DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
2349+
// Return the cast pointer together with the chain result of SS.
return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
2350+
}
2351+
23012352
// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
23022353
// (see LegalizeDAG.cpp). This is slow and uses local memory.
23032354
// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
@@ -2909,6 +2960,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
29092960
return LowerVectorArith(Op, DAG);
29102961
case ISD::DYNAMIC_STACKALLOC:
29112962
return LowerDYNAMIC_STACKALLOC(Op, DAG);
2963+
case ISD::STACKRESTORE:
2964+
return LowerSTACKRESTORE(Op, DAG);
2965+
case ISD::STACKSAVE:
2966+
return LowerSTACKSAVE(Op, DAG);
29122967
case ISD::CopyToReg:
29132968
return LowerCopyToReg_128(Op, DAG);
29142969
default:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ enum NodeType : unsigned {
6363
PRMT,
6464
FCOPYSIGN,
6565
DYNAMIC_STACKALLOC,
66+
STACKRESTORE,
67+
STACKSAVE,
6668
BrxStart,
6769
BrxItem,
6870
BrxEnd,
@@ -526,6 +528,8 @@ class NVPTXTargetLowering : public TargetLowering {
526528
SmallVectorImpl<SDValue> &InVals) const override;
527529

528530
SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
531+
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
532+
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
529533

530534
std::string
531535
getPrototype(const DataLayout &DL, Type *, const ArgListTy &,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3860,6 +3860,44 @@ foreach a_type = ["s", "u"] in {
38603860
}
38613861
}
38623862

3863+
//
3864+
// Stack Manipulation
3865+
//
3866+
3867+
// Type profile for the stackrestore node: no results, one operand, with the
// SDTCisInt<0> constraint applied.
def SDTStackRestore : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
3868+
3869+
// Selection DAG node for NVPTXISD::STACKRESTORE; chained and marked as having
// side effects.
def stackrestore :
3870+
SDNode<"NVPTXISD::STACKRESTORE", SDTStackRestore,
3871+
[SDNPHasChain, SDNPSideEffect]>;
3872+
3873+
// Selection DAG node for NVPTXISD::STACKSAVE; uses the SDTIntLeaf profile and
// is chained and marked as having side effects.
def stacksave :
3874+
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
3875+
[SDNPHasChain, SDNPSideEffect]>;
3876+
3877+
// 32-bit form: matches stackrestore on an i32 register operand and prints
// "stackrestore.u32". Gated on hasPTX<73> and hasSM<52>.
def STACKRESTORE_32 :
3878+
NVPTXInst<(outs), (ins Int32Regs:$ptr),
3879+
"stackrestore.u32 \t$ptr;",
3880+
[(stackrestore (i32 Int32Regs:$ptr))]>,
3881+
Requires<[hasPTX<73>, hasSM<52>]>;
3882+
3883+
// 32-bit form: matches an i32 stacksave, writes the result to $dst and prints
// "stacksave.u32". Gated on hasPTX<73> and hasSM<52>.
def STACKSAVE_32 :
3884+
NVPTXInst<(outs Int32Regs:$dst), (ins),
3885+
"stacksave.u32 \t$dst;",
3886+
[(set Int32Regs:$dst, (i32 stacksave))]>,
3887+
Requires<[hasPTX<73>, hasSM<52>]>;
3888+
3889+
// 64-bit form: matches stackrestore on an i64 register operand and prints
// "stackrestore.u64". Gated on hasPTX<73> and hasSM<52>.
def STACKRESTORE_64 :
3890+
NVPTXInst<(outs), (ins Int64Regs:$ptr),
3891+
"stackrestore.u64 \t$ptr;",
3892+
[(stackrestore (i64 Int64Regs:$ptr))]>,
3893+
Requires<[hasPTX<73>, hasSM<52>]>;
3894+
3895+
// 64-bit form: matches an i64 stacksave, writes the result to $dst and prints
// "stacksave.u64". Gated on hasPTX<73> and hasSM<52>.
def STACKSAVE_64 :
3896+
NVPTXInst<(outs Int64Regs:$dst), (ins),
3897+
"stacksave.u64 \t$dst;",
3898+
[(set Int64Regs:$dst, (i64 stacksave))]>,
3899+
Requires<[hasPTX<73>, hasSM<52>]>;
3900+
38633901
include "NVPTXIntrinsics.td"
38643902

38653903
//-----------------------------------
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-32
3+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-64
4+
; RUN: llc < %s -march=nvptx64 -nvptx-short-ptr -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-MIXED
5+
; RUN: %if ptxas && ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | %ptxas-verify %}
6+
7+
target triple = "nvptx64-nvidia-cuda"
8+
9+
define ptr @test_save() {
10+
; CHECK-32-LABEL: test_save(
11+
; CHECK-32: {
12+
; CHECK-32-NEXT: .reg .b32 %r<3>;
13+
; CHECK-32-EMPTY:
14+
; CHECK-32-NEXT: // %bb.0:
15+
; CHECK-32-NEXT: stacksave.u32 %r1;
16+
; CHECK-32-NEXT: cvta.local.u32 %r2, %r1;
17+
; CHECK-32-NEXT: st.param.b32 [func_retval0], %r2;
18+
; CHECK-32-NEXT: ret;
19+
;
20+
; CHECK-64-LABEL: test_save(
21+
; CHECK-64: {
22+
; CHECK-64-NEXT: .reg .b64 %rd<3>;
23+
; CHECK-64-EMPTY:
24+
; CHECK-64-NEXT: // %bb.0:
25+
; CHECK-64-NEXT: stacksave.u64 %rd1;
26+
; CHECK-64-NEXT: cvta.local.u64 %rd2, %rd1;
27+
; CHECK-64-NEXT: st.param.b64 [func_retval0], %rd2;
28+
; CHECK-64-NEXT: ret;
29+
;
30+
; CHECK-MIXED-LABEL: test_save(
31+
; CHECK-MIXED: {
32+
; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
33+
; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
34+
; CHECK-MIXED-EMPTY:
35+
; CHECK-MIXED-NEXT: // %bb.0:
36+
; CHECK-MIXED-NEXT: stacksave.u32 %r1;
37+
; CHECK-MIXED-NEXT: cvt.u64.u32 %rd1, %r1;
38+
; CHECK-MIXED-NEXT: cvta.local.u64 %rd2, %rd1;
39+
; CHECK-MIXED-NEXT: st.param.b64 [func_retval0], %rd2;
40+
; CHECK-MIXED-NEXT: ret;
41+
%1 = call ptr @llvm.stacksave()
42+
ret ptr %1
43+
}
44+
45+
46+
define void @test_restore(ptr %p) {
47+
; CHECK-32-LABEL: test_restore(
48+
; CHECK-32: {
49+
; CHECK-32-NEXT: .reg .b32 %r<3>;
50+
; CHECK-32-EMPTY:
51+
; CHECK-32-NEXT: // %bb.0:
52+
; CHECK-32-NEXT: ld.param.u32 %r1, [test_restore_param_0];
53+
; CHECK-32-NEXT: cvta.to.local.u32 %r2, %r1;
54+
; CHECK-32-NEXT: stackrestore.u32 %r2;
55+
; CHECK-32-NEXT: ret;
56+
;
57+
; CHECK-64-LABEL: test_restore(
58+
; CHECK-64: {
59+
; CHECK-64-NEXT: .reg .b64 %rd<3>;
60+
; CHECK-64-EMPTY:
61+
; CHECK-64-NEXT: // %bb.0:
62+
; CHECK-64-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
63+
; CHECK-64-NEXT: cvta.to.local.u64 %rd2, %rd1;
64+
; CHECK-64-NEXT: stackrestore.u64 %rd2;
65+
; CHECK-64-NEXT: ret;
66+
;
67+
; CHECK-MIXED-LABEL: test_restore(
68+
; CHECK-MIXED: {
69+
; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
70+
; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
71+
; CHECK-MIXED-EMPTY:
72+
; CHECK-MIXED-NEXT: // %bb.0:
73+
; CHECK-MIXED-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
74+
; CHECK-MIXED-NEXT: cvta.to.local.u64 %rd2, %rd1;
75+
; CHECK-MIXED-NEXT: cvt.u32.u64 %r1, %rd2;
76+
; CHECK-MIXED-NEXT: stackrestore.u32 %r1;
77+
; CHECK-MIXED-NEXT: ret;
78+
call void @llvm.stackrestore(ptr %p)
79+
ret void
80+
}
81+
82+
declare ptr @llvm.stacksave()
83+
declare void @llvm.stackrestore(ptr)

0 commit comments

Comments
 (0)