Skip to content

Commit 1e5b136

Browse files
committed
[RISCV] Implement trampolines for rv64
This implementation is heavily based on what the X86 target does, but it emits the instructions that GCC emits for rv64.
1 parent 14dfdc0 commit 1e5b136

File tree

3 files changed

+209
-0
lines changed

3 files changed

+209
-0
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
633633
setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
634634
Subtarget.is64Bit() ? Legal : Custom);
635635

636+
if (Subtarget.is64Bit()) {
637+
setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
638+
setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
639+
}
640+
636641
setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
637642
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
638643
if (Subtarget.is64Bit())
@@ -7264,6 +7269,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72647269
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
72657270
Op.getOperand(2), Flags, DL);
72667271
}
7272+
case ISD::INIT_TRAMPOLINE:
7273+
return lowerINIT_TRAMPOLINE(Op, DAG);
7274+
case ISD::ADJUST_TRAMPOLINE:
7275+
return lowerADJUST_TRAMPOLINE(Op, DAG);
72677276
}
72687277
}
72697278

@@ -7279,6 +7288,123 @@ SDValue RISCVTargetLowering::emitFlushICache(SelectionDAG &DAG, SDValue InChain,
72797288
return CallResult.second;
72807289
}
72817290

7291+
SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
                                                  SelectionDAG &DAG) const {
  // Only the RV64 layout below is implemented; the ISD action is only set
  // for 64-bit subtargets, so this is a defensive check.
  if (!Subtarget.is64Bit())
    llvm::report_fatal_error("Trampolines only implemented for RV64");

  SDValue Root = Op.getOperand(0);
  SDValue Trmp = Op.getOperand(1); // trampoline buffer address
  SDLoc dl(Op);

  const Value *TrmpAddr = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();

  // We store in the trampoline buffer the following instructions and data.
  // Offset:
  //  0: auipc t2, 0
  //  4: ld    t0, 24(t2)
  //  8: ld    t2, 16(t2)
  // 12: jalr  t0
  // 16: <StaticChainOffset>
  // 24: <FunctionAddressOffset>
  // 32:

  constexpr unsigned StaticChainOffset = 16;
  constexpr unsigned FunctionAddressOffset = 24;

  // Encoding constants shamelessly taken from GCC.
  constexpr unsigned Opcode_AUIPC = 0x17;
  constexpr unsigned Opcode_LD = 0x3003;
  constexpr unsigned Opcode_JALR = 0x67;
  constexpr unsigned ShiftField_RD = 7;
  constexpr unsigned ShiftField_RS1 = 15;
  constexpr unsigned ShiftField_IMM = 20;
  constexpr unsigned Reg_X5 = 0x5; // x5/t0 (holds the address to the function)
  constexpr unsigned Reg_X7 = 0x7; // x7/t2 (holds the static chain)

  // The four fixed instructions of the trampoline, in buffer order. Offsets
  // in the loads are pc-relative to the first instruction (the auipc).
  constexpr uint32_t FixedInstrs[] = {
      // auipc t2, 0 -- loads the current PC into t2.
      Opcode_AUIPC | (Reg_X7 << ShiftField_RD),
      // ld t0, 24(t2) -- loads the function address into t0.
      Opcode_LD | (Reg_X5 << ShiftField_RD) | (Reg_X7 << ShiftField_RS1) |
          (FunctionAddressOffset << ShiftField_IMM),
      // ld t2, 16(t2) -- loads the value of the static chain into t2.
      Opcode_LD | (Reg_X7 << ShiftField_RD) | (Reg_X7 << ShiftField_RS1) |
          (StaticChainOffset << ShiftField_IMM),
      // jalr t0 -- jumps to the function.
      Opcode_JALR | (Reg_X5 << ShiftField_RS1)};

  SDValue OutChains[6];

  // Store the fixed part of the trampoline: the four instructions above.
  for (unsigned Idx = 0; Idx < std::size(FixedInstrs); ++Idx) {
    const unsigned ByteOffset = Idx * 4;
    SDValue Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                               DAG.getConstant(ByteOffset, dl, MVT::i64));
    OutChains[Idx] = DAG.getTruncStore(
        Root, dl, DAG.getConstant(FixedInstrs[Idx], dl, MVT::i64), Addr,
        MachinePointerInfo(TrmpAddr, ByteOffset), MVT::i32);
  }

  // Now store the variable part of the trampoline.
  SDValue FunctionAddress = Op.getOperand(2);
  SDValue StaticChain = Op.getOperand(3);

  // Store the given static chain in the trampoline buffer.
  SDValue Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                             DAG.getConstant(StaticChainOffset, dl, MVT::i64));
  OutChains[4] = DAG.getStore(Root, dl, StaticChain, Addr,
                              MachinePointerInfo(TrmpAddr, StaticChainOffset));

  // Store the given function address in the trampoline buffer.
  Addr = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                     DAG.getConstant(FunctionAddressOffset, dl, MVT::i64));
  OutChains[5] =
      DAG.getStore(Root, dl, FunctionAddress, Addr,
                   MachinePointerInfo(TrmpAddr, FunctionAddressOffset));

  SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);

  // Compute end of trampoline (32 bytes past its start).
  SDValue EndOfTrmp = DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
                                  DAG.getConstant(32, dl, MVT::i64));

  // The buffer mixes freshly-written code and data, so flush the instruction
  // cache over the whole trampoline before it can be executed.
  SDValue Chain = DAG.getNode(ISD::CLEAR_CACHE, dl, MVT::Other, StoreToken,
                              Trmp, EndOfTrmp);

  return Chain;
}
7399+
7400+
SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
                                                    SelectionDAG &DAG) const {
  // Trampoline lowering is only implemented on RV64 (see
  // lowerINIT_TRAMPOLINE); bail out loudly otherwise.
  if (!Subtarget.is64Bit())
    llvm::report_fatal_error("Trampolines only implemented for RV64");

  // On RISC-V the trampoline buffer is directly callable, so adjusting a
  // trampoline is a no-op: simply forward the trampoline address.
  SDValue TrampolineAddr = Op.getOperand(0);
  return TrampolineAddr;
}
7407+
72827408
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
72837409
SelectionDAG &DAG, unsigned Flags) {
72847410
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,9 @@ class RISCVTargetLowering : public TargetLowering {
998998
SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
999999
SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
10001000

1001+
SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
1002+
SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
1003+
10011004
bool isEligibleForTailCallOptimization(
10021005
CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
10031006
const SmallVector<CCValAssign, 16> &ArgLocs) const;
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
3+
; RUN: | FileCheck -check-prefix=RV64 %s
4+
; RUN: llc -mtriple=riscv64-unknown-linux-gnu -verify-machineinstrs < %s \
5+
; RUN: | FileCheck -check-prefix=RV64-LINUX %s
6+
7+
declare void @llvm.init.trampoline(ptr, ptr, ptr)
8+
declare ptr @llvm.adjust.trampoline(ptr)
9+
declare i64 @f(ptr nest, i64)
10+
11+
; Build a 32-byte trampoline on the stack for @f with static chain %p,
; flush the icache over it (__clear_cache on bare metal,
; __riscv_flush_icache with a2=0 on Linux), then call through it.
define i64 @test0(i64 %n, ptr %p) nounwind {
; RV64-LABEL: test0:
; RV64: # %bb.0:
; RV64-NEXT: addi sp, sp, -64
; RV64-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
; RV64-NEXT: sd s0, 48(sp) # 8-byte Folded Spill
; RV64-NEXT: sd s1, 40(sp) # 8-byte Folded Spill
; RV64-NEXT: mv s0, a0
; RV64-NEXT: lui a0, %hi(.LCPI0_0)
; RV64-NEXT: ld a0, %lo(.LCPI0_0)(a0)
; RV64-NEXT: lui a2, %hi(f)
; RV64-NEXT: addi a2, a2, %lo(f)
; RV64-NEXT: sd a2, 32(sp)
; RV64-NEXT: sd a1, 24(sp)
; RV64-NEXT: sd a0, 16(sp)
; RV64-NEXT: lui a0, 6203
; RV64-NEXT: addi a0, a0, 643
; RV64-NEXT: slli a0, a0, 32
; RV64-NEXT: addi a0, a0, 919
; RV64-NEXT: sd a0, 8(sp)
; RV64-NEXT: addi a1, sp, 40
; RV64-NEXT: addi a0, sp, 8
; RV64-NEXT: addi s1, sp, 8
; RV64-NEXT: call __clear_cache
; RV64-NEXT: mv a0, s0
; RV64-NEXT: jalr s1
; RV64-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
; RV64-NEXT: ld s0, 48(sp) # 8-byte Folded Reload
; RV64-NEXT: ld s1, 40(sp) # 8-byte Folded Reload
; RV64-NEXT: addi sp, sp, 64
; RV64-NEXT: ret
;
; RV64-LINUX-LABEL: test0:
; RV64-LINUX: # %bb.0:
; RV64-LINUX-NEXT: addi sp, sp, -64
; RV64-LINUX-NEXT: sd ra, 56(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: sd s0, 48(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: sd s1, 40(sp) # 8-byte Folded Spill
; RV64-LINUX-NEXT: mv s0, a0
; RV64-LINUX-NEXT: lui a0, %hi(.LCPI0_0)
; RV64-LINUX-NEXT: ld a0, %lo(.LCPI0_0)(a0)
; RV64-LINUX-NEXT: lui a2, %hi(f)
; RV64-LINUX-NEXT: addi a2, a2, %lo(f)
; RV64-LINUX-NEXT: sd a2, 32(sp)
; RV64-LINUX-NEXT: sd a1, 24(sp)
; RV64-LINUX-NEXT: sd a0, 16(sp)
; RV64-LINUX-NEXT: lui a0, 6203
; RV64-LINUX-NEXT: addi a0, a0, 643
; RV64-LINUX-NEXT: slli a0, a0, 32
; RV64-LINUX-NEXT: addi a0, a0, 919
; RV64-LINUX-NEXT: sd a0, 8(sp)
; RV64-LINUX-NEXT: addi a1, sp, 40
; RV64-LINUX-NEXT: addi a0, sp, 8
; RV64-LINUX-NEXT: addi s1, sp, 8
; RV64-LINUX-NEXT: li a2, 0
; RV64-LINUX-NEXT: call __riscv_flush_icache
; RV64-LINUX-NEXT: mv a0, s0
; RV64-LINUX-NEXT: jalr s1
; RV64-LINUX-NEXT: ld ra, 56(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: ld s0, 48(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: ld s1, 40(sp) # 8-byte Folded Reload
; RV64-LINUX-NEXT: addi sp, sp, 64
; RV64-LINUX-NEXT: ret
  %alloca = alloca [32 x i8], align 8
  call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
  %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
  %ret = call i64 %tramp(i64 %n)
  ret i64 %ret

}

0 commit comments

Comments
 (0)