Skip to content

[RISCV] Support llvm.readsteadycounter intrinsic #82322

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,10 +625,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit())
setOperationAction(ISD::Constant, MVT::i64, Custom);

// TODO: On M-mode only targets, the cycle[h] CSR may not be present.
// TODO: On M-mode only targets, the cycle[h]/time[h] CSR may not be present.
// Unfortunately this can't be determined just from the ISA naming string.
setOperationAction(ISD::READCYCLECOUNTER, MVT::i64,
Subtarget.is64Bit() ? Legal : Custom);
setOperationAction(ISD::READSTEADYCOUNTER, MVT::i64,
Subtarget.is64Bit() ? Legal : Custom);

setOperationAction({ISD::TRAP, ISD::DEBUGTRAP}, MVT::Other, Legal);
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
Expand Down Expand Up @@ -11724,13 +11726,27 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
Results.push_back(Result);
break;
}
case ISD::READCYCLECOUNTER: {
assert(!Subtarget.is64Bit() &&
"READCYCLECOUNTER only has custom type legalization on riscv32");
case ISD::READCYCLECOUNTER:
case ISD::READSTEADYCOUNTER: {
assert(!Subtarget.is64Bit() && "READCYCLECOUNTER/READSTEADYCOUNTER only "
"has custom type legalization on riscv32");

SDValue LoCounter, HiCounter;
MVT XLenVT = Subtarget.getXLenVT();
if (N->getOpcode() == ISD::READCYCLECOUNTER) {
LoCounter = DAG.getConstant(
RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding, DL, XLenVT);
HiCounter = DAG.getConstant(
RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding, DL, XLenVT);
} else {
LoCounter = DAG.getConstant(
RISCVSysReg::lookupSysRegByName("TIME")->Encoding, DL, XLenVT);
HiCounter = DAG.getConstant(
RISCVSysReg::lookupSysRegByName("TIMEH")->Encoding, DL, XLenVT);
}
SDVTList VTs = DAG.getVTList(MVT::i32, MVT::i32, MVT::Other);
SDValue RCW =
DAG.getNode(RISCVISD::READ_CYCLE_WIDE, DL, VTs, N->getOperand(0));
SDValue RCW = DAG.getNode(RISCVISD::READ_COUNTER_WIDE, DL, VTs,
N->getOperand(0), LoCounter, HiCounter);

Results.push_back(
DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, RCW, RCW.getValue(1)));
Expand Down Expand Up @@ -16902,29 +16918,30 @@ RISCVTargetLowering::getTargetConstantFromLoad(LoadSDNode *Ld) const {
return CNodeLo->getConstVal();
}

static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
MachineBasicBlock *BB) {
assert(MI.getOpcode() == RISCV::ReadCycleWide && "Unexpected instruction");
static MachineBasicBlock *emitReadCounterWidePseudo(MachineInstr &MI,
MachineBasicBlock *BB) {
assert(MI.getOpcode() == RISCV::ReadCounterWide && "Unexpected instruction");

// To read the 64-bit cycle CSR on a 32-bit target, we read the two halves.
// To read a 64-bit counter CSR on a 32-bit target, we read the two halves.
// Should the count have wrapped while it was being read, we need to try
// again.
// ...
// For example:
// ```
// read:
// rdcycleh x3 # load high word of cycle
// rdcycle x2 # load low word of cycle
// rdcycleh x4 # load high word of cycle
// bne x3, x4, read # check if high word reads match, otherwise try again
// ...
// csrrs x3, counterh # load high word of counter
// csrrs x2, counter # load low word of counter
// csrrs x4, counterh # load high word of counter
// bne x3, x4, read # check if high word reads match, otherwise try again
// ```

MachineFunction &MF = *BB->getParent();
const BasicBlock *LLVM_BB = BB->getBasicBlock();
const BasicBlock *LLVMBB = BB->getBasicBlock();
MachineFunction::iterator It = ++BB->getIterator();

MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *LoopMBB = MF.CreateMachineBasicBlock(LLVMBB);
MF.insert(It, LoopMBB);

MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVM_BB);
MachineBasicBlock *DoneMBB = MF.CreateMachineBasicBlock(LLVMBB);
MF.insert(It, DoneMBB);

// Transfer the remainder of BB and its successor edges to DoneMBB.
Expand All @@ -16938,17 +16955,19 @@ static MachineBasicBlock *emitReadCycleWidePseudo(MachineInstr &MI,
Register ReadAgainReg = RegInfo.createVirtualRegister(&RISCV::GPRRegClass);
Register LoReg = MI.getOperand(0).getReg();
Register HiReg = MI.getOperand(1).getReg();
int64_t LoCounter = MI.getOperand(2).getImm();
int64_t HiCounter = MI.getOperand(3).getImm();
DebugLoc DL = MI.getDebugLoc();

const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), HiReg)
.addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
.addImm(HiCounter)
.addReg(RISCV::X0);
BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), LoReg)
.addImm(RISCVSysReg::lookupSysRegByName("CYCLE")->Encoding)
.addImm(LoCounter)
.addReg(RISCV::X0);
BuildMI(LoopMBB, DL, TII->get(RISCV::CSRRS), ReadAgainReg)
.addImm(RISCVSysReg::lookupSysRegByName("CYCLEH")->Encoding)
.addImm(HiCounter)
.addReg(RISCV::X0);

BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
Expand Down Expand Up @@ -17527,10 +17546,10 @@ RISCVTargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
switch (MI.getOpcode()) {
default:
llvm_unreachable("Unexpected instr type to insert");
case RISCV::ReadCycleWide:
case RISCV::ReadCounterWide:
assert(!Subtarget.is64Bit() &&
"ReadCycleWrite is only to be used on riscv32");
return emitReadCycleWidePseudo(MI, BB);
"ReadCounterWide is only to be used on riscv32");
return emitReadCounterWidePseudo(MI, BB);
case RISCV::Select_GPR_Using_CC_GPR:
case RISCV::Select_FPR16_Using_CC_GPR:
case RISCV::Select_FPR16INX_Using_CC_GPR:
Expand Down Expand Up @@ -19202,7 +19221,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FCLASS)
NODE_NAME_CASE(FMAX)
NODE_NAME_CASE(FMIN)
NODE_NAME_CASE(READ_CYCLE_WIDE)
NODE_NAME_CASE(READ_COUNTER_WIDE)
NODE_NAME_CASE(BREV8)
NODE_NAME_CASE(ORC_B)
NODE_NAME_CASE(ZIP)
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ enum NodeType : unsigned {
// Floating point fmax and fmin matching the RISC-V instruction semantics.
FMAX, FMIN,

// READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target
// (returns (Lo, Hi)). It takes a chain operand.
READ_CYCLE_WIDE,
// A read of the 64-bit counter CSR on a 32-bit target (returns (Lo, Hi)).
// It takes a chain operand.
READ_COUNTER_WIDE,

// brev8, orc.b, zip, and unzip from Zbb and Zbkb. All operands are i32 or
// XLenVT.
BREV8,
Expand Down
33 changes: 20 additions & 13 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ def SDT_RISCVReadCSR : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDT_RISCVWriteCSR : SDTypeProfile<0, 2, [SDTCisInt<0>, SDTCisInt<1>]>;
def SDT_RISCVSwapCSR : SDTypeProfile<1, 2, [SDTCisInt<0>, SDTCisInt<1>,
SDTCisInt<2>]>;
def SDT_RISCVReadCycleWide : SDTypeProfile<2, 0, [SDTCisVT<0, i32>,
SDTCisVT<1, i32>]>;
def SDT_RISCVReadCounterWide : SDTypeProfile<2, 2, [SDTCisVT<0, i32>,
SDTCisVT<1, i32>,
SDTCisInt<2>,
SDTCisInt<3>]>;
def SDT_RISCVIntUnaryOpW : SDTypeProfile<1, 1, [
SDTCisSameAs<0, 1>, SDTCisVT<0, i64>
]>;
Expand Down Expand Up @@ -77,9 +79,9 @@ def riscv_write_csr : SDNode<"RISCVISD::WRITE_CSR", SDT_RISCVWriteCSR,
def riscv_swap_csr : SDNode<"RISCVISD::SWAP_CSR", SDT_RISCVSwapCSR,
[SDNPHasChain]>;

def riscv_read_cycle_wide : SDNode<"RISCVISD::READ_CYCLE_WIDE",
SDT_RISCVReadCycleWide,
[SDNPHasChain, SDNPSideEffect]>;
def riscv_read_counter_wide : SDNode<"RISCVISD::READ_COUNTER_WIDE",
SDT_RISCVReadCounterWide,
[SDNPHasChain, SDNPSideEffect]>;

def riscv_add_lo : SDNode<"RISCVISD::ADD_LO", SDTIntBinOp>;
def riscv_hi : SDNode<"RISCVISD::HI", SDTIntUnaryOp>;
Expand Down Expand Up @@ -363,7 +365,7 @@ def CSRSystemRegister : AsmOperandClass {
let DiagnosticType = "InvalidCSRSystemRegister";
}

def csr_sysreg : RISCVOp {
def csr_sysreg : RISCVOp, ImmLeaf<XLenVT, "return isUInt<12>(Imm);"> {
let ParserMatchClass = CSRSystemRegister;
let PrintMethod = "printCSRSystemRegister";
let DecoderMethod = "decodeUImmOperand<12>";
Expand Down Expand Up @@ -1827,16 +1829,21 @@ def : StPat<truncstorei32, SW, GPR, i64>;
def : StPat<store, SD, GPR, i64>;
} // Predicates = [IsRV64]

// On RV64, we can directly read these 64-bit counter CSRs.
let Predicates = [IsRV64] in {
/// readcyclecounter
// On RV64, we can directly read the 64-bit "cycle" CSR.
let Predicates = [IsRV64] in
def : Pat<(i64 (readcyclecounter)), (CSRRS CYCLE.Encoding, (XLenVT X0))>;
// On RV32, ReadCycleWide will be expanded to the suggested loop reading both
// halves of the 64-bit "cycle" CSR.
/// readsteadycounter
def : Pat<(i64 (readsteadycounter)), (CSRRS TIME.Encoding, (XLenVT X0))>;
}

// On RV32, ReadCounterWide will be expanded to the suggested loop reading both
// halves of 64-bit counter CSRs.
let Predicates = [IsRV32], usesCustomInserter = 1, hasNoSchedulingInfo = 1 in
def ReadCycleWide : Pseudo<(outs GPR:$lo, GPR:$hi), (ins),
[(set GPR:$lo, GPR:$hi, (riscv_read_cycle_wide))],
"", "">;
def ReadCounterWide : Pseudo<(outs GPR:$lo, GPR:$hi), (ins i32imm:$csr_lo, i32imm:$csr_hi),
[(set GPR:$lo, GPR:$hi,
(riscv_read_counter_wide csr_sysreg:$csr_lo, csr_sysreg:$csr_hi))],
"", "">;

/// traps

Expand Down
28 changes: 28 additions & 0 deletions llvm/test/CodeGen/RISCV/readsteadycounter.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=riscv32 -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=RV32I %s
; RUN: llc -mtriple=riscv64 -verify-machineinstrs < %s \
; RUN: | FileCheck -check-prefix=RV64I %s

; Verify that we lower @llvm.readsteadycounter() correctly.

declare i64 @llvm.readsteadycounter()

define i64 @test_builtin_readsteadycounter() nounwind {
; RV32I-LABEL: test_builtin_readsteadycounter:
; RV32I: # %bb.0:
; RV32I-NEXT: .LBB0_1: # =>This Inner Loop Header: Depth=1
; RV32I-NEXT: rdtimeh a1
; RV32I-NEXT: rdtime a0
; RV32I-NEXT: rdtimeh a2
; RV32I-NEXT: bne a1, a2, .LBB0_1
; RV32I-NEXT: # %bb.2:
; RV32I-NEXT: ret
;
; RV64I-LABEL: test_builtin_readsteadycounter:
; RV64I: # %bb.0:
; RV64I-NEXT: rdtime a0
; RV64I-NEXT: ret
%1 = tail call i64 @llvm.readsteadycounter()
ret i64 %1
}