37
37
#include "llvm/IR/Instructions.h"
38
38
#include "llvm/IR/IntrinsicsRISCV.h"
39
39
#include "llvm/IR/PatternMatch.h"
40
+ #include "llvm/MC/MCCodeEmitter.h"
41
+ #include "llvm/MC/MCInstBuilder.h"
40
42
#include "llvm/Support/CommandLine.h"
41
43
#include "llvm/Support/Debug.h"
42
44
#include "llvm/Support/ErrorHandling.h"
@@ -625,6 +627,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
625
627
setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
626
628
Subtarget.is64Bit() ? Legal : Custom);
627
629
630
+ if (Subtarget.is64Bit()) {
631
+ setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
632
+ setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
633
+ }
634
+
628
635
setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
629
636
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
630
637
if (Subtarget.is64Bit())
@@ -7402,6 +7409,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
7402
7409
return emitFlushICache(DAG, Op.getOperand(0), Op.getOperand(1),
7403
7410
Op.getOperand(2), Flags, DL);
7404
7411
}
7412
+ case ISD::INIT_TRAMPOLINE:
7413
+ return lowerINIT_TRAMPOLINE(Op, DAG);
7414
+ case ISD::ADJUST_TRAMPOLINE:
7415
+ return lowerADJUST_TRAMPOLINE(Op, DAG);
7405
7416
}
7406
7417
}
7407
7418
@@ -7417,6 +7428,126 @@ SDValue RISCVTargetLowering::emitFlushICache(SelectionDAG &DAG, SDValue InChain,
7417
7428
return CallResult.second;
7418
7429
}
7419
7430
7431
+ SDValue RISCVTargetLowering::lowerINIT_TRAMPOLINE(SDValue Op,
7432
+ SelectionDAG &DAG) const {
7433
+ if (!Subtarget.is64Bit())
7434
+ llvm::report_fatal_error("Trampolines only implemented for RV64");
7435
+
7436
+ // Create an MCCodeEmitter to encode instructions.
7437
+ TargetLoweringObjectFile *TLO = getTargetMachine().getObjFileLowering();
7438
+ assert(TLO);
7439
+ MCContext &MCCtx = TLO->getContext();
7440
+
7441
+ std::unique_ptr<MCCodeEmitter> CodeEmitter(
7442
+ createRISCVMCCodeEmitter(*getTargetMachine().getMCInstrInfo(), MCCtx));
7443
+
7444
+ SDValue Root = Op.getOperand(0);
7445
+ SDValue Trmp = Op.getOperand(1); // trampoline
7446
+ SDLoc dl(Op);
7447
+
7448
+ const Value *TrmpAddr = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
7449
+
7450
+ // We store in the trampoline buffer the following instructions and data.
7451
+ // Offset:
7452
+ // 0: auipc t2, 0
7453
+ // 4: ld t0, 24(t2)
7454
+ // 8: ld t2, 16(t2)
7455
+ // 12: jalr t0
7456
+ // 16: <StaticChainOffset>
7457
+ // 24: <FunctionAddressOffset>
7458
+ // 32:
7459
+
7460
+ constexpr unsigned StaticChainOffset = 16;
7461
+ constexpr unsigned FunctionAddressOffset = 24;
7462
+
7463
+ const MCSubtargetInfo *STI = getTargetMachine().getMCSubtargetInfo();
7464
+ assert(STI);
7465
+ auto GetEncoding = [&](const MCInst &MC) {
7466
+ SmallVector<char, 4> CB;
7467
+ SmallVector<MCFixup> Fixups;
7468
+ CodeEmitter->encodeInstruction(MC, CB, Fixups, *STI);
7469
+ uint32_t Encoding = support::endian::read32le(CB.data());
7470
+ return Encoding;
7471
+ };
7472
+
7473
+ SDValue OutChains[6];
7474
+
7475
+ uint32_t Encodings[] = {
7476
+ // auipc t2, 0
7477
+ // Loads the current PC into t2.
7478
+ GetEncoding(MCInstBuilder(RISCV::AUIPC).addReg(RISCV::X7).addImm(0)),
7479
+ // ld t0, 24(t2)
7480
+ // Loads the function address into t0. Note that we are using offsets
7481
+ // pc-relative to the first instruction of the trampoline.
7482
+ GetEncoding(
7483
+ MCInstBuilder(RISCV::LD).addReg(RISCV::X5).addReg(RISCV::X7).addImm(
7484
+ FunctionAddressOffset)),
7485
+ // ld t2, 16(t2)
7486
+ // Load the value of the static chain.
7487
+ GetEncoding(
7488
+ MCInstBuilder(RISCV::LD).addReg(RISCV::X7).addReg(RISCV::X7).addImm(
7489
+ StaticChainOffset)),
7490
+ // jalr t0
7491
+ // Jump to the function.
7492
+ GetEncoding(MCInstBuilder(RISCV::JALR)
7493
+ .addReg(RISCV::X0)
7494
+ .addReg(RISCV::X5)
7495
+ .addImm(0))};
7496
+
7497
+ // Store encoded instructions.
7498
+ for (auto [Idx, Encoding] : llvm::enumerate(Encodings)) {
7499
+ SDValue Addr = Idx > 0 ? DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
7500
+ DAG.getConstant(Idx * 4, dl, MVT::i64))
7501
+ : Trmp;
7502
+ OutChains[Idx] = DAG.getTruncStore(
7503
+ Root, dl, DAG.getConstant(Encoding, dl, MVT::i64), Addr,
7504
+ MachinePointerInfo(TrmpAddr, Idx * 4), MVT::i32);
7505
+ }
7506
+
7507
+ // Now store the variable part of the trampoline.
7508
+ SDValue FunctionAddress = Op.getOperand(2);
7509
+ SDValue StaticChain = Op.getOperand(3);
7510
+
7511
+ // Store the given static chain and function pointer in the trampoline buffer.
7512
+ struct OffsetValuePair {
7513
+ const unsigned Offset;
7514
+ const SDValue Value;
7515
+ SDValue Addr = SDValue(); // Used to cache the address.
7516
+ } OffsetValues[] = {
7517
+ {StaticChainOffset, StaticChain},
7518
+ {FunctionAddressOffset, FunctionAddress},
7519
+ };
7520
+ for (auto [Idx, OffsetValue] : llvm::enumerate(OffsetValues)) {
7521
+ SDValue Addr =
7522
+ DAG.getNode(ISD::ADD, dl, MVT::i64, Trmp,
7523
+ DAG.getConstant(OffsetValue.Offset, dl, MVT::i64));
7524
+ OffsetValue.Addr = Addr;
7525
+ OutChains[Idx + 4] =
7526
+ DAG.getStore(Root, dl, OffsetValue.Value, Addr,
7527
+ MachinePointerInfo(TrmpAddr, OffsetValue.Offset));
7528
+ }
7529
+
7530
+ SDValue StoreToken = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
7531
+
7532
+ // The end of instructions of trampoline is the same as the static chain
7533
+ // address that we computed earlier.
7534
+ SDValue EndOfTrmp = OffsetValues[0].Addr;
7535
+
7536
+ // Call clear cache on the trampoline instructions.
7537
+ SDValue Chain = DAG.getNode(ISD::CLEAR_CACHE, dl, MVT::Other, StoreToken,
7538
+ Trmp, EndOfTrmp);
7539
+
7540
+ return Chain;
7541
+ }
7542
+
7543
+ SDValue RISCVTargetLowering::lowerADJUST_TRAMPOLINE(SDValue Op,
7544
+ SelectionDAG &DAG) const {
7545
+ if (!Subtarget.is64Bit())
7546
+ llvm::report_fatal_error("Trampolines only implemented for RV64");
7547
+
7548
+ return Op.getOperand(0);
7549
+ }
7550
+
7420
7551
static SDValue getTargetNode(GlobalAddressSDNode *N, const SDLoc &DL, EVT Ty,
7421
7552
SelectionDAG &DAG, unsigned Flags) {
7422
7553
return DAG.getTargetGlobalAddress(N->getGlobal(), DL, Ty, 0, Flags);
0 commit comments