Commit 9d469b5

rofirrim and wangpc-pp authored
[RISCV] Implement trampolines for rv64 (#96309)
This implementation is based on what the X86 target does, but emits the instructions that GCC emits for rv64.

---------

Co-authored-by: Pengcheng Wang <[email protected]>
1 parent 6bb6300 commit 9d469b5

File tree

3 files changed: +214 -0 lines changed


llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 131 additions & 0 deletions
@@ -37,6 +37,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/MC/MCCodeEmitter.h"
+#include "llvm/MC/MCInstBuilder.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -625,6 +627,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
                      Subtarget.is64Bit() ? Legal : Custom);
 
+  if (Subtarget.is64Bit()) {
+    setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
+    setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
+  }
+
   setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
   if (Subtarget.is64Bit())
@@ -7402,6 +7409,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
                            Op.getOperand(2), Flags, DL);
   }
+  case ISD::INIT_TRAMPOLINE:
+    return lowerINIT_TRAMPOLINE(Op, DAG);
+  case ISD::ADJUST_TRAMPOLINE:
+    return lowerADJUST_TRAMPOLINE(Op, DAG);
   }
 }
 
@@ -7417,6 +7428,126 @@ SDValue RISCVTargetLowering::emitFlushICache(SelectionDAG &DAG, SDValue InChain,
   return CallResult.second;
 }
 
+SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
+                                                  SelectionDAG &DAG) const {
+  if (!Subtarget.is64Bit())
+    llvm::report_fatal_error("Trampolines only implemented for RV64");
+
+  // Create an MCCodeEmitter to encode instructions.
+  TargetLoweringObjectFile *TLO = getTargetMachine().getObjFileLowering();
+  assert(TLO);
+  MCContext &MCCtx = TLO->getContext();
+
+  std::unique_ptr<MCCodeEmitter> CodeEmitter(
+      createRISCVMCCodeEmitter(*getTargetMachine().getMCInstrInfo(), MCCtx));
+
+  SDValue Root = Op.getOperand(0);
+  SDValue Trmp = Op.getOperand(1); // trampoline
+  SDLoc dl(Op);
+
+  const Value *TrmpAddr = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
+
+  // We store in the trampoline buffer the following instructions and data.
+  // Offset:
+  //      0: auipc   t2, 0
+  //      4: ld      t0, 24(t2)
+  //      8: ld      t2, 16(t2)
+  //     12: jalr    t0
+  //     16: <StaticChainOffset>
+  //     24: <FunctionAddressOffset>
+  //     32:
+
+  constexpr unsigned StaticChainOffset = 16;
+  constexpr unsigned FunctionAddressOffset = 24;
+
+  const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
+  assert(STI);
+  auto GetEncoding = [&](const MCInst &MC) {
+    SmallVector<char, 4> CB;
+    SmallVector<MCFixup> Fixups;
+    CodeEmitter->encodeInstruction(MC, CB, Fixups, *STI);
+    uint32_t Encoding = support::endian::read32le(CB.data());
+    return Encoding;
+  };
+
+  SDValue OutChains[6];
+
+  uint32_t Encodings[] = {
+      // auipc t2, 0
+      // Loads the current PC into t2.
+      GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
+      // ld t0, 24(t2)
+      // Loads the function address into t0. Note that we are using offsets
+      // pc-relative to the first instruction of the trampoline.
+      GetEncoding(
+          MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
+              FunctionAddressOffset)),
+      // ld t2, 16(t2)
+      // Load the value of the static chain.
+      GetEncoding(
+          MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
+              StaticChainOffset)),
+      // jalr t0
+      // Jump to the function.
+      GetEncoding(MCInstBuilder(RISCV::JALR)
+                      .addReg(RISCV::X0)
+                      .addReg(RISCV::X5)
+                      .addImm(0))};
+
+  // Store encoded instructions.
+  for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
+    SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
+                                         DAG.getConstant(Idx * 4, dl, MVT::i64))
+                           : Trmp;
+    OutChains[Idx] = DAG.getTruncStore(
+        Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
+        MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
+  }
+
+  // Now store the variable part of the trampoline.
+  SDValue FunctionAddress = Op.getOperand(2);
+  SDValue StaticChain = Op.getOperand(3);
+
+  // Store the given static chain and function pointer in the trampoline buffer.
+  struct OffsetValuePair {
+    const unsigned Offset;
+    const SDValue Value;
+    SDValue Addr = SDValue(); // Used to cache the address.
+  } OffsetValues[] = {
+      {StaticChainOffset, StaticChain},
+      {FunctionAddressOffset, FunctionAddress},
+  };
+  for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
+    SDValue Addr =
+        DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
+                    DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
+    OffsetValue.Addr = Addr;
+    OutChains[Idx + 4] =
+        DAG.getStore(Root, dl, OffsetValue.Value, Addr,
+                     MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
+  }
+
+  SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
+
+  // The end of instructions of trampoline is the same as the static chain
+  // address that we computed earlier.
+  SDValue EndOfTrmp = OffsetValues[0].Addr;
+
+  // Call clear cache on the trampoline instructions.
+  SDValue Chain = DAG.getNode(ISD::CLEAR_CACHE, dl, MVT::Other, StoreToken,
+                              Trmp, EndOfTrmp);
+
+  return Chain;
+}
+
+SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  if (!Subtarget.is64Bit())
+    llvm::report_fatal_error("Trampolines only implemented for RV64");
+
+  return Op.getOperand(0);
+}
+
 static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
                              SelectionDAG &DAG, unsigned Flags) {
   return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
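
Aside: the four fixed words that the Encodings array above produces can be cross-checked without LLVM at all, straight from the RISC-V U-type and I-type instruction formats. The following is a minimal standalone C++ sketch, illustrative rather than part of the patch: the names (Rv64Trampoline, UType, IType) are invented for this example, and the expected constants were hand-derived from the ISA encoding. It also mirrors the 32-byte buffer layout described in the "Offset:" comment of the patch.

// Illustrative sketch, not LLVM code. Names here are hypothetical.
#include <cstddef>
#include <cstdint>

// Mirror of the trampoline buffer filled by lowerINIT_TRAMPOLINE;
// field offsets follow the "Offset:" comment in the patch.
struct Rv64Trampoline {
  uint32_t Insns[4];        //  0: auipc, 4: ld, 8: ld, 12: jalr
  uint64_t StaticChain;     // 16: loaded into t2 by the second ld
  uint64_t FunctionAddress; // 24: loaded into t0 by the first ld
};
static_assert(offsetof(Rv64Trampoline, StaticChain) == 16, "");
static_assert(offsetof(Rv64Trampoline, FunctionAddress) == 24, "");
static_assert(sizeof(Rv64Trampoline) == 32, "");

// U-type: imm[31:12] | rd | opcode.
constexpr uint32_t UType(uint32_t Imm20, uint32_t Rd, uint32_t Opcode) {
  return (Imm20 << 12) | (Rd << 7) | Opcode;
}
// I-type: imm[11:0] | rs1 | funct3 | rd | opcode.
constexpr uint32_t IType(uint32_t Imm12, uint32_t Rs1, uint32_t Funct3,
                         uint32_t Rd, uint32_t Opcode) {
  return (Imm12 << 20) | (Rs1 << 15) | (Funct3 << 12) | (Rd << 7) | Opcode;
}

constexpr uint32_t X0 = 0, T0 = 5, T2 = 7; // x0, x5 (t0), x7 (t2)

// The same four instructions the patch encodes with MCInstBuilder.
constexpr uint32_t Auipc = UType(0, T2, 0x17);            // auipc t2, 0
constexpr uint32_t LdT0 = IType(24, T2, 0b011, T0, 0x03); // ld t0, 24(t2)
constexpr uint32_t LdT2 = IType(16, T2, 0b011, T2, 0x03); // ld t2, 16(t2)
constexpr uint32_t Jalr = IType(0, T0, 0b000, X0, 0x67);  // jalr x0, t0, 0

// Expected words, computed by hand from the ISA encoding formats.
static_assert(Auipc == 0x00000397, "");
static_assert(LdT0 == 0x0183b283, "");
static_assert(LdT2 == 0x0103b383, "");
static_assert(Jalr == 0x00028067, "");

Note the design choice visible in both the sketch and the patch: jalr writes its return address to x0 (i.e. discards it), so the trampoline jumps to the target rather than calling it, and the target's ra still points at the trampoline's original caller.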

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 3 additions & 0 deletions
@@ -992,6 +992,9 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue expandUnalignedRVVLoad(SDValue Op, SelectionDAG &DAG) const;
   SDValue expandUnalignedRVVStore(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue lowerINIT_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerADJUST_TRAMPOLINE(SDValue Op, SelectionDAG &DAG) const;
+
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
       const SmallVector<CCValAssign, 16> &ArgLocs) const;
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64 %s
+; RUN: llc -mtriple=riscv64-unknown-linux-gnu -verify-machineinstrs < %s \
+; RUN:   | FileCheck -check-prefix=RV64-LINUX %s
+
+declare void @llvm.init.trampoline(ptr, ptr, ptr)
+declare ptr @llvm.adjust.trampoline(ptr)
+declare i64 @f(ptr nest, i64)
+
+define i64 @test0(i64 %n, ptr %p) nounwind {
+; RV64-LABEL: test0:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -64
+; RV64-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s0, 48(sp) # 8-byte Folded Spill
+; RV64-NEXT:    sd s1, 40(sp) # 8-byte Folded Spill
+; RV64-NEXT:    mv s0, a0
+; RV64-NEXT:    lui a0, %hi(f)
+; RV64-NEXT:    addi a0, a0, %lo(f)
+; RV64-NEXT:    sd a0, 32(sp)
+; RV64-NEXT:    li a0, 919
+; RV64-NEXT:    lui a2, %hi(.LCPI0_0)
+; RV64-NEXT:    ld a2, %lo(.LCPI0_0)(a2)
+; RV64-NEXT:    lui a3, 6203
+; RV64-NEXT:    addi a3, a3, 643
+; RV64-NEXT:    sw a0, 8(sp)
+; RV64-NEXT:    sw a3, 12(sp)
+; RV64-NEXT:    sd a2, 16(sp)
+; RV64-NEXT:    sd a1, 24(sp)
+; RV64-NEXT:    addi a1, sp, 24
+; RV64-NEXT:    addi a0, sp, 8
+; RV64-NEXT:    addi s1, sp, 8
+; RV64-NEXT:    call __clear_cache
+; RV64-NEXT:    mv a0, s0
+; RV64-NEXT:    jalr s1
+; RV64-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s0, 48(sp) # 8-byte Folded Reload
+; RV64-NEXT:    ld s1, 40(sp) # 8-byte Folded Reload
+; RV64-NEXT:    addi sp, sp, 64
+; RV64-NEXT:    ret
+;
+; RV64-LINUX-LABEL: test0:
+; RV64-LINUX:       # %bb.0:
+; RV64-LINUX-NEXT:    addi sp, sp, -64
+; RV64-LINUX-NEXT:    sd ra, 56(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    sd s0, 48(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    sd s1, 40(sp) # 8-byte Folded Spill
+; RV64-LINUX-NEXT:    mv s0, a0
+; RV64-LINUX-NEXT:    lui a0, %hi(f)
+; RV64-LINUX-NEXT:    addi a0, a0, %lo(f)
+; RV64-LINUX-NEXT:    sd a0, 32(sp)
+; RV64-LINUX-NEXT:    li a0, 919
+; RV64-LINUX-NEXT:    lui a2, %hi(.LCPI0_0)
+; RV64-LINUX-NEXT:    ld a2, %lo(.LCPI0_0)(a2)
+; RV64-LINUX-NEXT:    lui a3, 6203
+; RV64-LINUX-NEXT:    addi a3, a3, 643
+; RV64-LINUX-NEXT:    sw a0, 8(sp)
+; RV64-LINUX-NEXT:    sw a3, 12(sp)
+; RV64-LINUX-NEXT:    sd a2, 16(sp)
+; RV64-LINUX-NEXT:    sd a1, 24(sp)
+; RV64-LINUX-NEXT:    addi a1, sp, 24
+; RV64-LINUX-NEXT:    addi a0, sp, 8
+; RV64-LINUX-NEXT:    addi s1, sp, 8
+; RV64-LINUX-NEXT:    li a2, 0
+; RV64-LINUX-NEXT:    call __riscv_flush_icache
+; RV64-LINUX-NEXT:    mv a0, s0
+; RV64-LINUX-NEXT:    jalr s1
+; RV64-LINUX-NEXT:    ld ra, 56(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    ld s0, 48(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    ld s1, 40(sp) # 8-byte Folded Reload
+; RV64-LINUX-NEXT:    addi sp, sp, 64
+; RV64-LINUX-NEXT:    ret
+  %alloca = alloca [32 x i8], align 8
+  call void @llvm.init.trampoline(ptr %alloca, ptr @f, ptr %p)
+  %tramp = call ptr @llvm.adjust.trampoline(ptr %alloca)
+  %ret = call i64 %tramp(i64 %n)
+  ret i64 %ret
+
+}
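
As a usage sketch (illustrative, not from the patch), the IR this test exercises can also be built programmatically with LLVM's C++ IRBuilder. The snippet below assumes recent LLVM headers and mirrors @test0: a 32-byte alloca handed to llvm.init.trampoline, then llvm.adjust.trampoline to recover the callable address. Note how the two RUN lines above differ only in the lowering of ISD::CLEAR_CACHE: the bare riscv64 triple emits a call to __clear_cache, while the Linux triple calls __riscv_flush_icache.

// Hedged sketch: builds the same IR shape as @test0 via IRBuilder.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("trampoline-demo", Ctx);
  IRBuilder<> B(Ctx);
  Type *I64 = B.getInt64Ty();

  // declare i64 @f(ptr nest, i64)
  Function *F = Function::Create(
      FunctionType::get(I64, {B.getPtrTy(), I64}, /*isVarArg=*/false),
      Function::ExternalLinkage, "f", M);
  F->addParamAttr(0, Attribute::Nest);

  // define i64 @test0(i64 %n, ptr %p)
  Function *Test = Function::Create(
      FunctionType::get(I64, {I64, B.getPtrTy()}, /*isVarArg=*/false),
      Function::ExternalLinkage, "test0", M);
  B.SetInsertPoint(BasicBlock::Create(Ctx, "", Test));

  // %alloca = alloca [32 x i8], align 8 -- the trampoline buffer.
  AllocaInst *Buf =
      B.CreateAlloca(ArrayType::get(B.getInt8Ty(), 32), nullptr, "alloca");
  Buf->setAlignment(Align(8));

  // llvm.init.trampoline(buffer, target function, static chain value).
  B.CreateIntrinsic(Intrinsic::init_trampoline, {},
                    {Buf, F, Test->getArg(1)});
  // llvm.adjust.trampoline returns the callable address.
  Value *Tramp = B.CreateIntrinsic(Intrinsic::adjust_trampoline, {}, {Buf});

  // Call it as i64(i64); the nest argument is supplied by the trampoline.
  Value *Ret = B.CreateCall(FunctionType::get(I64, {I64}, false), Tramp,
                            {Test->getArg(0)});
  B.CreateRet(Ret);

  verifyModule(M, &errs());
  M.print(outs(), nullptr);
  return 0;
}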
