Skip to content

[NVPTX] Add support for stacksave, stackrestore intrinsics #114484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
Expand Down Expand Up @@ -667,8 +668,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);

setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i32, Custom);
setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom);
setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom);

// TRAP can be lowered to PTX trap
setOperationAction(ISD::TRAP, MVT::Other, Legal);
Expand Down Expand Up @@ -961,6 +962,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::PRMT)
MAKE_CASE(NVPTXISD::FCOPYSIGN)
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
MAKE_CASE(NVPTXISD::STACKRESTORE)
MAKE_CASE(NVPTXISD::STACKSAVE)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
MAKE_CASE(NVPTXISD::Dummy)
Expand Down Expand Up @@ -2287,6 +2290,54 @@ SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
}

SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op.getNode());
if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
const Function &Fn = DAG.getMachineFunction().getFunction();

DiagnosticInfoUnsupported NoStackRestore(
Fn,
"Support for stackrestore requires PTX ISA version >= 7.3 and target "
">= sm_52.",
DL.getDebugLoc());
DAG.getContext()->diagnose(NoStackRestore);
return Op.getOperand(0);
}

const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
SDValue Chain = Op.getOperand(0);
SDValue Ptr = Op.getOperand(1);
SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC,
ADDRESS_SPACE_LOCAL);
return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC});
}

SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op.getNode());
if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) {
const Function &Fn = DAG.getMachineFunction().getFunction();

DiagnosticInfoUnsupported NoStackSave(
Fn,
"Support for stacksave requires PTX ISA version >= 7.3 and target >= "
"sm_52.",
DL.getDebugLoc());
DAG.getContext()->diagnose(NoStackSave);
auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)};
return DAG.getMergeValues(Ops, DL);
}

const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL);
SDValue Chain = Op.getOperand(0);
SDValue SS =
DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain);
SDValue ASC = DAG.getAddrSpaceCast(
DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC);
return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL);
}

// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
// (see LegalizeDAG.cpp). This is slow and uses local memory.
// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
Expand Down Expand Up @@ -2871,6 +2922,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVectorArith(Op, DAG);
case ISD::DYNAMIC_STACKALLOC:
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::STACKRESTORE:
return LowerSTACKRESTORE(Op, DAG);
case ISD::STACKSAVE:
return LowerSTACKSAVE(Op, DAG);
case ISD::CopyToReg:
return LowerCopyToReg_128(Op, DAG);
default:
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ enum NodeType : unsigned {
PRMT,
FCOPYSIGN,
DYNAMIC_STACKALLOC,
STACKRESTORE,
STACKSAVE,
BrxStart,
BrxItem,
BrxEnd,
Expand Down Expand Up @@ -526,6 +528,8 @@ class NVPTXTargetLowering : public TargetLowering {
SmallVectorImpl<SDValue> &InVals) const override;

SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;

std::string
getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -3860,6 +3860,44 @@ foreach a_type = ["s", "u"] in {
}
}

//
// Stack Manipulation
//

def SDTStackRestore : SDTypeProfile<0, 1, [SDTCisInt<0>]>;

def stackrestore :
SDNode<"NVPTXISD::STACKRESTORE", SDTStackRestore,
[SDNPHasChain, SDNPSideEffect]>;

def stacksave :
SDNode<"NVPTXISD::STACKSAVE", SDTIntLeaf,
[SDNPHasChain, SDNPSideEffect]>;

def STACKRESTORE_32 :
NVPTXInst<(outs), (ins Int32Regs:$ptr),
"stackrestore.u32 \t$ptr;",
[(stackrestore (i32 Int32Regs:$ptr))]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKSAVE_32 :
NVPTXInst<(outs Int32Regs:$dst), (ins),
"stacksave.u32 \t$dst;",
[(set Int32Regs:$dst, (i32 stacksave))]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKRESTORE_64 :
NVPTXInst<(outs), (ins Int64Regs:$ptr),
"stackrestore.u64 \t$ptr;",
[(stackrestore (i64 Int64Regs:$ptr))]>,
Requires<[hasPTX<73>, hasSM<52>]>;

def STACKSAVE_64 :
NVPTXInst<(outs Int64Regs:$dst), (ins),
"stacksave.u64 \t$dst;",
[(set Int64Regs:$dst, (i64 stacksave))]>,
Requires<[hasPTX<73>, hasSM<52>]>;

include "NVPTXIntrinsics.td"

//-----------------------------------
Expand Down
83 changes: 83 additions & 0 deletions llvm/test/CodeGen/NVPTX/stacksaverestore.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -march=nvptx -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-32
; RUN: llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-64
; RUN: llc < %s -march=nvptx64 -nvptx-short-ptr -mcpu=sm_60 -mattr=+ptx73 | FileCheck %s --check-prefix=CHECK-MIXED
; RUN: %if ptxas && ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_60 -mattr=+ptx73 | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"

define ptr @test_save() {
; CHECK-32-LABEL: test_save(
; CHECK-32: {
; CHECK-32-NEXT: .reg .b32 %r<3>;
; CHECK-32-EMPTY:
; CHECK-32-NEXT: // %bb.0:
; CHECK-32-NEXT: stacksave.u32 %r1;
; CHECK-32-NEXT: cvta.local.u32 %r2, %r1;
; CHECK-32-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-32-NEXT: ret;
;
; CHECK-64-LABEL: test_save(
; CHECK-64: {
; CHECK-64-NEXT: .reg .b64 %rd<3>;
; CHECK-64-EMPTY:
; CHECK-64-NEXT: // %bb.0:
; CHECK-64-NEXT: stacksave.u64 %rd1;
; CHECK-64-NEXT: cvta.local.u64 %rd2, %rd1;
; CHECK-64-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-64-NEXT: ret;
;
; CHECK-MIXED-LABEL: test_save(
; CHECK-MIXED: {
; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
; CHECK-MIXED-EMPTY:
; CHECK-MIXED-NEXT: // %bb.0:
; CHECK-MIXED-NEXT: stacksave.u32 %r1;
; CHECK-MIXED-NEXT: cvt.u64.u32 %rd1, %r1;
; CHECK-MIXED-NEXT: cvta.local.u64 %rd2, %rd1;
; CHECK-MIXED-NEXT: st.param.b64 [func_retval0], %rd2;
; CHECK-MIXED-NEXT: ret;
%1 = call ptr @llvm.stacksave()
ret ptr %1
}


define void @test_restore(ptr %p) {
; CHECK-32-LABEL: test_restore(
; CHECK-32: {
; CHECK-32-NEXT: .reg .b32 %r<3>;
; CHECK-32-EMPTY:
; CHECK-32-NEXT: // %bb.0:
; CHECK-32-NEXT: ld.param.u32 %r1, [test_restore_param_0];
; CHECK-32-NEXT: cvta.to.local.u32 %r2, %r1;
; CHECK-32-NEXT: stackrestore.u32 %r2;
; CHECK-32-NEXT: ret;
;
; CHECK-64-LABEL: test_restore(
; CHECK-64: {
; CHECK-64-NEXT: .reg .b64 %rd<3>;
; CHECK-64-EMPTY:
; CHECK-64-NEXT: // %bb.0:
; CHECK-64-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
; CHECK-64-NEXT: cvta.to.local.u64 %rd2, %rd1;
; CHECK-64-NEXT: stackrestore.u64 %rd2;
; CHECK-64-NEXT: ret;
;
; CHECK-MIXED-LABEL: test_restore(
; CHECK-MIXED: {
; CHECK-MIXED-NEXT: .reg .b32 %r<2>;
; CHECK-MIXED-NEXT: .reg .b64 %rd<3>;
; CHECK-MIXED-EMPTY:
; CHECK-MIXED-NEXT: // %bb.0:
; CHECK-MIXED-NEXT: ld.param.u64 %rd1, [test_restore_param_0];
; CHECK-MIXED-NEXT: cvta.to.local.u64 %rd2, %rd1;
; CHECK-MIXED-NEXT: cvt.u32.u64 %r1, %rd2;
; CHECK-MIXED-NEXT: stackrestore.u32 %r1;
; CHECK-MIXED-NEXT: ret;
call void @llvm.stackrestore(ptr %p)
ret void
}

declare ptr @llvm.stacksave()
declare void @llvm.stackrestore(ptr)
Loading