Skip to content

Commit 5743b28

Browse files
committed
[RISCV] Add codegen support for zicldst
1 parent 1f1ebfb commit 5743b28

File tree

5 files changed

+135
-0
lines changed

5 files changed

+135
-0
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "llvm/CodeGen/MachineInstrBuilder.h"
2929
#include "llvm/CodeGen/MachineJumpTableInfo.h"
3030
#include "llvm/CodeGen/MachineRegisterInfo.h"
31+
#include "llvm/CodeGen/SDPatternMatch.h"
3132
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
3233
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
3334
#include "llvm/CodeGen/ValueTypes.h"
@@ -20466,6 +20467,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2046620467
NODE_NAME_CASE(SF_VC_V_IVW_SE)
2046720468
NODE_NAME_CASE(SF_VC_V_VVW_SE)
2046820469
NODE_NAME_CASE(SF_VC_V_FVW_SE)
20470+
NODE_NAME_CASE(CLOAD)
20471+
NODE_NAME_CASE(CSTORE)
2046920472
}
2047020473
// clang-format on
2047120474
return nullptr;
@@ -22067,6 +22070,55 @@ SDValue RISCVTargetLowering::expandIndirectJTBranch(const SDLoc &dl,
2206722070
return TargetLowering::expandIndirectJTBranch(dl, Value, Addr, JTI, DAG);
2206822071
}
2206922072

22073+
SDValue RISCVTargetLowering::visitMaskedLoad(
22074+
SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, MachineMemOperand *MMO,
22075+
SDValue &NewLoad, SDValue Ptr, SDValue PassThru, SDValue Mask) const {
22076+
using namespace SDPatternMatch;
22077+
// @llvm.masked.load.v1*(ptr, alignment, mask, passthru)
22078+
// ->
22079+
// ptr_in = select (bit_cast_to_i1 mask), ptr, 0
22080+
// val, chain = CLOAD inchain, ptr_in, width
22081+
// res = select mask, (bit_cast_to_vt val), passthru
22082+
EVT VTy = PassThru.getValueType();
22083+
EVT Ty = VTy.getVectorElementType();
22084+
EVT XLenVT = Subtarget.getXLenVT();
22085+
SDVTList Tys = DAG.getVTList(XLenVT, MVT::Other);
22086+
SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
22087+
EVT PtrVT = Ptr.getValueType();
22088+
SDValue PtrIn =
22089+
DAG.getSelect(DL, PtrVT, ScalarMask, Ptr, DAG.getConstant(0, DL, PtrVT));
22090+
SDValue Ops[] = {Chain, PtrIn,
22091+
DAG.getConstant(Ty.getScalarSizeInBits(), DL, XLenVT)};
22092+
NewLoad = DAG.getMemIntrinsicNode(RISCVISD::CLOAD, DL, Tys, Ops, Ty, MMO);
22093+
SDValue Ret =
22094+
DAG.getBitcast(VTy, DAG.getNode(ISD::TRUNCATE, DL, Ty, NewLoad));
22095+
if (!PassThru.isUndef() && !sd_match(PassThru, m_Zero()))
22096+
Ret = DAG.getSelect(DL, VTy, Mask, Ret, PassThru);
22097+
return Ret;
22098+
}
22099+
22100+
SDValue RISCVTargetLowering::visitMaskedStore(SelectionDAG &DAG,
22101+
const SDLoc &DL, SDValue Chain,
22102+
MachineMemOperand *MMO,
22103+
SDValue Ptr, SDValue Val,
22104+
SDValue Mask) const {
22105+
// llvm.masked.store.v1*(Src0, Ptr, alignment, Mask)
22106+
// ->
22107+
// chain = CSTORE inchain, (bit_cast_to_scalar val), ptr_in, width
22108+
EVT Ty = Val.getValueType().getVectorElementType();
22109+
EVT XLenVT = Subtarget.getXLenVT();
22110+
SDVTList Tys = DAG.getVTList(MVT::Other);
22111+
SDValue ScalarMask = DAG.getBitcast(MVT::i1, Mask);
22112+
EVT PtrVT = Ptr.getValueType();
22113+
SDValue PtrIn =
22114+
DAG.getSelect(DL, PtrVT, ScalarMask, Ptr, DAG.getConstant(0, DL, PtrVT));
22115+
SDValue ScalarVal =
22116+
DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, DAG.getBitcast(Ty, Val));
22117+
SDValue Ops[] = {Chain, ScalarVal, PtrIn,
22118+
DAG.getConstant(Ty.getScalarSizeInBits(), DL, XLenVT)};
22119+
return DAG.getMemIntrinsicNode(RISCVISD::CSTORE, DL, Tys, Ops, Ty, MMO);
22120+
}
22121+
2207022122
namespace llvm::RISCVVIntrinsicsTable {
2207122123

2207222124
#define GET_RISCVVIntrinsicsTable_IMPL

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,10 @@ enum NodeType : unsigned {
477477
TH_LDD,
478478
TH_SWD,
479479
TH_SDD,
480+
481+
// Conditional load/store instructions
482+
CLOAD,
483+
CSTORE,
480484
};
481485
// clang-format on
482486
} // namespace RISCVISD
@@ -1045,6 +1049,13 @@ class RISCVTargetLowering : public TargetLowering {
10451049

10461050
SDValue emitFlushICache(SelectionDAG &DAG, SDValue InChain, SDValue Start,
10471051
SDValue End, SDValue Flags, SDLoc DL) const;
1052+
1053+
SDValue visitMaskedLoad(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
1054+
MachineMemOperand *MMO, SDValue &NewLoad, SDValue Ptr,
1055+
SDValue PassThru, SDValue Mask) const override;
1056+
SDValue visitMaskedStore(SelectionDAG &DAG, const SDLoc &DL, SDValue Chain,
1057+
MachineMemOperand *MMO, SDValue Ptr, SDValue Val,
1058+
SDValue Mask) const override;
10481059
};
10491060

10501061
/// As per the spec, the rules for passing vector arguments are as follows:

llvm/lib/Target/RISCV/RISCVInstrInfoZicldst.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,34 @@ def CLWU : CondLoad_ri<0b110, "clwu">, Sched<[WriteLDW, ReadMemBase]>;
4545
def CLD : CondLoad_ri<0b011, "cld">, Sched<[WriteLDD, ReadMemBase]>;
4646
def CSD : CondStore_rri<0b011, "csd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>;
4747
} // Predicates = [HasStdExtZicldst, IsRV64]
48+
49+
def SDTRVCLoad : SDTypeProfile<1, 2, [
50+
SDTCisVT<0, XLenVT>, SDTCisPtrTy<1>, SDTCisVT<2, XLenVT>
51+
]>;
52+
def SDTRVCStore : SDTypeProfile<0, 3, [
53+
SDTCisVT<0, XLenVT>, SDTCisPtrTy<1>, SDTCisVT<2, XLenVT>
54+
]>;
55+
56+
def riscv_cload : SDNode<"RISCVISD::CLOAD", SDTRVCLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
57+
def riscv_cstore : SDNode<"RISCVISD::CSTORE", SDTRVCStore, [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
58+
59+
let Predicates = [HasStdExtZicldst] in {
60+
61+
def : Pat<(XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 8)), (CLB GPR:$rs1, simm12:$imm12)>;
62+
def : Pat<(XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 16)), (CLH GPR:$rs1, simm12:$imm12)>;
63+
def : Pat<(XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 32)), (CLW GPR:$rs1, simm12:$imm12)>;
64+
def : Pat<(and (XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 8)), 0xFF), (CLBU GPR:$rs1, simm12:$imm12)>;
65+
def : Pat<(and (XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 16)), 0xFFFF), (CLHU GPR:$rs1, simm12:$imm12)>;
66+
def : Pat<(riscv_cstore (XLenVT GPR:$rs2), (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 8), (CSB GPR:$rs2, GPR:$rs1, simm12:$imm12)>;
67+
def : Pat<(riscv_cstore (XLenVT GPR:$rs2), (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 16), (CSH GPR:$rs2, GPR:$rs1, simm12:$imm12)>;
68+
def : Pat<(riscv_cstore (XLenVT GPR:$rs2), (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 32), (CSW GPR:$rs2, GPR:$rs1, simm12:$imm12)>;
69+
70+
} // Predicates = [HasStdExtZicldst]
71+
72+
let Predicates = [HasStdExtZicldst, IsRV64] in {
73+
74+
def : Pat<(and (XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 32)), 0xFFFFFFFF), (CLWU GPR:$rs1, simm12:$imm12)>;
75+
def : Pat<(XLenVT (riscv_cload (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 64)), (CLD GPR:$rs1, simm12:$imm12)>;
76+
def : Pat<(riscv_cstore (XLenVT GPR:$rs2), (AddrRegImm (XLenVT GPR:$rs1), simm12:$imm12), 64), (CSD GPR:$rs2, GPR:$rs1, simm12:$imm12)>;
77+
78+
} // Predicates = [HasStdExtZicldst, IsRV64]

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,3 +1989,22 @@ bool RISCVTTIImpl::areInlineCompatible(const Function *Caller,
19891989
// target-features.
19901990
return (CallerBits & CalleeBits) == CalleeBits;
19911991
}
1992+
1993+
bool RISCVTTIImpl::hasConditionalLoadStoreForType(Type *Ty) const {
1994+
if (!ST->hasStdExtZicldst())
1995+
return false;
1996+
if (!Ty)
1997+
return true;
1998+
if (!Ty->isIntOrPtrTy())
1999+
return false;
2000+
switch (DL.getTypeSizeInBits(Ty)) {
2001+
default:
2002+
return false;
2003+
case 8:
2004+
case 16:
2005+
case 32:
2006+
return true;
2007+
case 64:
2008+
return ST->is64Bit();
2009+
}
2010+
}

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,27 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
218218
}
219219

220220
bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
221+
if (ST->hasStdExtZicldst() && isa<FixedVectorType>(DataType) &&
222+
DataType->getScalarType()->isIntOrPtrTy() &&
223+
cast<FixedVectorType>(DataType)->getNumElements() == 1) {
224+
EVT DataTypeVT = TLI->getValueType(DL, DataType);
225+
EVT ElemType = DataTypeVT.getScalarType();
226+
if (!ST->enableUnalignedScalarMem() &&
227+
Alignment < ElemType.getStoreSize())
228+
return false;
229+
230+
switch (DL.getTypeSizeInBits(DataType)) {
231+
default:
232+
return false;
233+
case 8:
234+
case 16:
235+
case 32:
236+
return true;
237+
case 64:
238+
return ST->is64Bit();
239+
}
240+
}
241+
221242
if (!ST->hasVInstructions())
222243
return false;
223244

@@ -399,6 +420,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
399420
}
400421

401422
std::optional<unsigned> getMinPageSize() const { return 4096; }
423+
bool hasConditionalLoadStoreForType(Type *Ty) const;
402424
};
403425

404426
} // end namespace llvm

0 commit comments

Comments
 (0)