|
53 | 53 | #include "llvm/Support/CodeGen.h"
|
54 | 54 | #include "llvm/Support/CommandLine.h"
|
55 | 55 | #include "llvm/Support/ErrorHandling.h"
|
| 56 | +#include "llvm/Support/NVPTXAddrSpace.h" |
56 | 57 | #include "llvm/Support/raw_ostream.h"
|
57 | 58 | #include "llvm/Target/TargetMachine.h"
|
58 | 59 | #include "llvm/Target/TargetOptions.h"
|
@@ -678,8 +679,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
|
678 | 679 | setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
|
679 | 680 | setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
|
680 | 681 |
|
681 |
| - setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i32, Custom); |
682 |
| - setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom); |
| 682 | + setOperationAction(ISD::DYNAMIC_STACKALLOC, {MVT::i32, MVT::i64}, Custom); |
| 683 | + setOperationAction({ISD::STACKRESTORE, ISD::STACKSAVE}, MVT::Other, Custom); |
683 | 684 |
|
684 | 685 | // TRAP can be lowered to PTX trap
|
685 | 686 | setOperationAction(ISD::TRAP, MVT::Other, Legal);
|
@@ -972,6 +973,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
|
972 | 973 | MAKE_CASE(NVPTXISD::PRMT)
|
973 | 974 | MAKE_CASE(NVPTXISD::FCOPYSIGN)
|
974 | 975 | MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
|
| 976 | + MAKE_CASE(NVPTXISD::STACKRESTORE) |
| 977 | + MAKE_CASE(NVPTXISD::STACKSAVE) |
975 | 978 | MAKE_CASE(NVPTXISD::SETP_F16X2)
|
976 | 979 | MAKE_CASE(NVPTXISD::SETP_BF16X2)
|
977 | 980 | MAKE_CASE(NVPTXISD::Dummy)
|
@@ -2298,6 +2301,54 @@ SDValue NVPTXTargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
|
2298 | 2301 | return DAG.getNode(NVPTXISD::DYNAMIC_STACKALLOC, DL, RetTypes, AllocOps);
|
2299 | 2302 | }
|
2300 | 2303 |
|
| 2304 | +SDValue NVPTXTargetLowering::LowerSTACKRESTORE(SDValue Op, |
| 2305 | + SelectionDAG &DAG) const { |
| 2306 | + SDLoc DL(Op.getNode()); |
| 2307 | + if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) { |
| 2308 | + const Function &Fn = DAG.getMachineFunction().getFunction(); |
| 2309 | + |
| 2310 | + DiagnosticInfoUnsupported NoStackRestore( |
| 2311 | + Fn, |
| 2312 | + "Support for stackrestore requires PTX ISA version >= 7.3 and target " |
| 2313 | + ">= sm_52.", |
| 2314 | + DL.getDebugLoc()); |
| 2315 | + DAG.getContext()->diagnose(NoStackRestore); |
| 2316 | + return Op.getOperand(0); |
| 2317 | + } |
| 2318 | + |
| 2319 | + const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL); |
| 2320 | + SDValue Chain = Op.getOperand(0); |
| 2321 | + SDValue Ptr = Op.getOperand(1); |
| 2322 | + SDValue ASC = DAG.getAddrSpaceCast(DL, LocalVT, Ptr, ADDRESS_SPACE_GENERIC, |
| 2323 | + ADDRESS_SPACE_LOCAL); |
| 2324 | + return DAG.getNode(NVPTXISD::STACKRESTORE, DL, MVT::Other, {Chain, ASC}); |
| 2325 | +} |
| 2326 | + |
| 2327 | +SDValue NVPTXTargetLowering::LowerSTACKSAVE(SDValue Op, |
| 2328 | + SelectionDAG &DAG) const { |
| 2329 | + SDLoc DL(Op.getNode()); |
| 2330 | + if (STI.getPTXVersion() < 73 || STI.getSmVersion() < 52) { |
| 2331 | + const Function &Fn = DAG.getMachineFunction().getFunction(); |
| 2332 | + |
| 2333 | + DiagnosticInfoUnsupported NoStackSave( |
| 2334 | + Fn, |
| 2335 | + "Support for stacksave requires PTX ISA version >= 7.3 and target >= " |
| 2336 | + "sm_52.", |
| 2337 | + DL.getDebugLoc()); |
| 2338 | + DAG.getContext()->diagnose(NoStackSave); |
| 2339 | + auto Ops = {DAG.getConstant(0, DL, Op.getValueType()), Op.getOperand(0)}; |
| 2340 | + return DAG.getMergeValues(Ops, DL); |
| 2341 | + } |
| 2342 | + |
| 2343 | + const MVT LocalVT = getPointerTy(DAG.getDataLayout(), ADDRESS_SPACE_LOCAL); |
| 2344 | + SDValue Chain = Op.getOperand(0); |
| 2345 | + SDValue SS = |
| 2346 | + DAG.getNode(NVPTXISD::STACKSAVE, DL, {LocalVT, MVT::Other}, Chain); |
| 2347 | + SDValue ASC = DAG.getAddrSpaceCast( |
| 2348 | + DL, Op.getValueType(), SS, ADDRESS_SPACE_LOCAL, ADDRESS_SPACE_GENERIC); |
| 2349 | + return DAG.getMergeValues({ASC, SDValue(SS.getNode(), 1)}, DL); |
| 2350 | +} |
| 2351 | + |
2301 | 2352 | // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
|
2302 | 2353 | // (see LegalizeDAG.cpp). This is slow and uses local memory.
|
2303 | 2354 | // We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5
|
@@ -2909,6 +2960,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
|
2909 | 2960 | return LowerVectorArith(Op, DAG);
|
2910 | 2961 | case ISD::DYNAMIC_STACKALLOC:
|
2911 | 2962 | return LowerDYNAMIC_STACKALLOC(Op, DAG);
|
| 2963 | + case ISD::STACKRESTORE: |
| 2964 | + return LowerSTACKRESTORE(Op, DAG); |
| 2965 | + case ISD::STACKSAVE: |
| 2966 | + return LowerSTACKSAVE(Op, DAG); |
2912 | 2967 | case ISD::CopyToReg:
|
2913 | 2968 | return LowerCopyToReg_128(Op, DAG);
|
2914 | 2969 | default:
|
|
0 commit comments