Skip to content

Commit 4e18d62

Browse files
committed
Lower build_vector to broadcast load if possible
1 parent b122956 commit 4e18d62

File tree

5 files changed

+80
-4
lines changed

5 files changed

+80
-4
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,6 +1721,47 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
17211721
return false;
17221722
}
17231723

1724+
// Lower BUILD_VECTOR as broadcast load (if possible).
1725+
// For example:
1726+
// %a = load i8, ptr %ptr
1727+
// %b = build_vector %a, %a, %a, %a
1728+
// is lowered to :
1729+
// (VLDREPL_B $a0, 0)
1730+
static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
1731+
const SDLoc &DL,
1732+
SelectionDAG &DAG) {
1733+
MVT VT = BVOp->getSimpleValueType(0);
1734+
int NumOps = BVOp->getNumOperands();
1735+
1736+
assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
1737+
"Unsupported vector type for broadcast.");
1738+
1739+
SDValue IdentitySrc;
1740+
bool IsIdeneity = true;
1741+
1742+
for (int i = 0; i != NumOps; i++) {
1743+
SDValue Op = BVOp->getOperand(i);
1744+
if (Op.getOpcode() != ISD::LOAD || (IdentitySrc && Op != IdentitySrc)) {
1745+
IsIdeneity = false;
1746+
break;
1747+
}
1748+
IdentitySrc = BVOp->getOperand(0);
1749+
}
1750+
1751+
if (IsIdeneity) {
1752+
auto *LN = cast<LoadSDNode>(IdentitySrc);
1753+
SDVTList Tys =
1754+
LN->isIndexed()
1755+
? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
1756+
: DAG.getVTList(VT, MVT::Other);
1757+
SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
1758+
SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
1759+
DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
1760+
return BCast;
1761+
}
1762+
return SDValue();
1763+
}
1764+
17241765
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
17251766
SelectionDAG &DAG) const {
17261767
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1736,6 +1777,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
17361777
(!Subtarget.hasExtLASX() || !Is256Vec))
17371778
return SDValue();
17381779

1780+
if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
1781+
return Result;
1782+
17391783
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
17401784
/*MinSplatBits=*/8) &&
17411785
SplatBitSize <= 64) {
@@ -5171,6 +5215,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
51715215
NODE_NAME_CASE(VSRLI)
51725216
NODE_NAME_CASE(VBSLL)
51735217
NODE_NAME_CASE(VBSRL)
5218+
NODE_NAME_CASE(VLDREPL)
51745219
}
51755220
#undef NODE_NAME_CASE
51765221
return nullptr;

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
155155

156156
// Vector byte logicial left / right shift
157157
VBSLL,
158-
VBSRL
158+
VBSRL,
159+
160+
// Scalar load broadcast to vector
161+
VLDREPL
159162

160163
// Intrinsic operations end =============================================
161164
};

llvm/lib/Target/LoongArch/LoongArchInstrInfo.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
307307
}
308308
}
309309

310-
def simm9_lsl3 : Operand<GRLenVT> {
310+
def simm9_lsl3 : Operand<GRLenVT>,
311+
ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
311312
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
312313
let EncoderMethod = "getImmOpValueAsr<3>";
313314
let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
317318
let ParserMatchClass = SImmAsmOperand<10>;
318319
}
319320

320-
def simm10_lsl2 : Operand<GRLenVT> {
321+
def simm10_lsl2 : Operand<GRLenVT>,
322+
ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
321323
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
322324
let EncoderMethod = "getImmOpValueAsr<2>";
323325
let DecoderMethod = "decodeSImmOperand<10, 2>";
324326
}
325327

326-
def simm11_lsl1 : Operand<GRLenVT> {
328+
def simm11_lsl1 : Operand<GRLenVT>,
329+
ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
327330
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
328331
let EncoderMethod = "getImmOpValueAsr<1>";
329332
let DecoderMethod = "decodeSImmOperand<11, 1>";

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2161,6 +2161,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
21612161
def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
21622162
(XVLDX GPR:$rj, GPR:$rk)>;
21632163

2164+
// xvldrepl
21642165
def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
21652166
(XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
21662167
def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
@@ -2170,6 +2171,11 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
21702171
def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
21712172
(XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
21722173

2174+
defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
2175+
defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
2176+
defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
2177+
defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
2178+
21732179
// store
21742180
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
21752181
(XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
2626
def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
2727
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
2828
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
29+
def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
2930

3031
// Target nodes.
3132
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
6465
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
6566
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
6667

68+
def loongarch_vldrepl
69+
: SDNode<"LoongArchISD::VLDREPL",
70+
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
71+
6772
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
6873
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
6974
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
14331438
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
14341439
}
14351440

1441+
multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
1442+
def : Pat<(vt(loongarch_vldrepl BaseAddr:$rj)), (Inst BaseAddr:$rj, 0)>;
1443+
def : Pat<(vt(loongarch_vldrepl(AddrConstant GPR:$rj, ImmOpnd:$imm))),
1444+
(Inst GPR:$rj, ImmOpnd:$imm)>;
1445+
def : Pat<(vt(loongarch_vldrepl(AddLike BaseAddr:$rj, ImmOpnd:$imm))),
1446+
(Inst BaseAddr:$rj, ImmOpnd:$imm)>;
1447+
}
1448+
14361449
let Predicates = [HasExtLSX] in {
14371450

14381451
// VADD_{B/H/W/D}
@@ -2338,6 +2351,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
23382351
def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
23392352
(VLDX GPR:$rj, GPR:$rk)>;
23402353

2354+
// vldrepl
23412355
def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
23422356
(VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
23432357
def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
@@ -2347,6 +2361,11 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
23472361
def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
23482362
(VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
23492363

2364+
defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
2365+
defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
2366+
defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
2367+
defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
2368+
23502369
// store
23512370
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
23522371
(VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;

0 commit comments

Comments
 (0)