Skip to content

Commit add224c

Browse files
committed
[LoongArch] Custom lowering ISD::BUILD_VECTOR
1 parent f2cbd1f commit add224c

File tree

8 files changed

+1112
-23
lines changed

8 files changed

+1112
-23
lines changed

llvm/lib/Target/LoongArch/LoongArchISelDAGToDAG.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,63 @@ void LoongArchDAGToDAGISel::Select(SDNode *Node) {
7777
return;
7878
}
7979
case ISD::BITCAST: {
80-
if (VT.is128BitVector() || VT.is512BitVector()) {
80+
if (VT.is128BitVector() || VT.is256BitVector()) {
8181
ReplaceUses(SDValue(Node, 0), Node->getOperand(0));
8282
CurDAG->RemoveDeadNode(Node);
8383
return;
8484
}
8585
break;
8686
}
87+
case ISD::BUILD_VECTOR: {
88+
// Select appropriate [x]vrepli.[bhwd] instructions for constant splats of
89+
// 128/256-bit when LSX/LASX is enabled.
90+
BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Node);
91+
APInt SplatValue, SplatUndef;
92+
unsigned SplatBitSize;
93+
bool HasAnyUndefs;
94+
unsigned Op;
95+
EVT ViaVecTy;
96+
bool Is128Vec = BVN->getValueType(0).is128BitVector();
97+
bool Is256Vec = BVN->getValueType(0).is256BitVector();
98+
99+
if (!Subtarget->hasExtLSX() || (!Is128Vec && !Is256Vec))
100+
break;
101+
if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
102+
HasAnyUndefs, 8))
103+
break;
104+
105+
switch (SplatBitSize) {
106+
default:
107+
break;
108+
case 8:
109+
Op = Is256Vec ? LoongArch::PseudoXVREPLI_B : LoongArch::PseudoVREPLI_B;
110+
ViaVecTy = Is256Vec ? MVT::v32i8 : MVT::v16i8;
111+
break;
112+
case 16:
113+
Op = Is256Vec ? LoongArch::PseudoXVREPLI_H : LoongArch::PseudoVREPLI_H;
114+
ViaVecTy = Is256Vec ? MVT::v16i16 : MVT::v8i16;
115+
break;
116+
case 32:
117+
Op = Is256Vec ? LoongArch::PseudoXVREPLI_W : LoongArch::PseudoVREPLI_W;
118+
ViaVecTy = Is256Vec ? MVT::v8i32 : MVT::v4i32;
119+
break;
120+
case 64:
121+
Op = Is256Vec ? LoongArch::PseudoXVREPLI_D : LoongArch::PseudoVREPLI_D;
122+
ViaVecTy = Is256Vec ? MVT::v4i64 : MVT::v2i64;
123+
break;
124+
}
125+
126+
SDNode *Res;
127+
// If we have a signed 10 bit integer, we can splat it directly.
128+
if (SplatValue.isSignedIntN(10)) {
129+
SDValue Imm = CurDAG->getTargetConstant(SplatValue, DL,
130+
ViaVecTy.getVectorElementType());
131+
Res = CurDAG->getMachineNode(Op, DL, ViaVecTy, Imm);
132+
ReplaceNode(Node, Res);
133+
return;
134+
}
135+
break;
136+
}
87137
}
88138

89139
// Select the default instruction.

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
243243
setOperationAction(ISD::BITCAST, VT, Legal);
244244
setOperationAction(ISD::UNDEF, VT, Legal);
245245

246-
// FIXME: For BUILD_VECTOR, it is temporarily set to `Legal` here, and it
247-
// will be `Custom` handled in the future.
248-
setOperationAction(ISD::BUILD_VECTOR, VT, Legal);
249246
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
250247
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Legal);
248+
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
251249
}
252250
for (MVT VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64}) {
253251
setOperationAction({ISD::ADD, ISD::SUB}, VT, Legal);
@@ -274,10 +272,9 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
274272
setOperationAction(ISD::BITCAST, VT, Legal);
275273
setOperationAction(ISD::UNDEF, VT, Legal);
276274

277-
// FIXME: Same as above.
278-
setOperationAction(ISD::BUILD_VECTOR, VT, Legal);
279275
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
280276
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Legal);
277+
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
281278
}
282279
for (MVT VT : {MVT::v4i64, MVT::v8i32, MVT::v16i16, MVT::v32i8}) {
283280
setOperationAction({ISD::ADD, ISD::SUB}, VT, Legal);
@@ -382,10 +379,105 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
382379
return lowerWRITE_REGISTER(Op, DAG);
383380
case ISD::INSERT_VECTOR_ELT:
384381
return lowerINSERT_VECTOR_ELT(Op, DAG);
382+
case ISD::BUILD_VECTOR:
383+
return lowerBUILD_VECTOR(Op, DAG);
385384
}
386385
return SDValue();
387386
}
388387

388+
static bool isConstantOrUndef(const SDValue Op) {
389+
if (Op->isUndef())
390+
return true;
391+
if (isa<ConstantSDNode>(Op))
392+
return true;
393+
if (isa<ConstantFPSDNode>(Op))
394+
return true;
395+
return false;
396+
}
397+
398+
static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
399+
for (unsigned i = 0; i < Op->getNumOperands(); ++i)
400+
if (isConstantOrUndef(Op->getOperand(i)))
401+
return true;
402+
return false;
403+
}
404+
405+
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
406+
SelectionDAG &DAG) const {
407+
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
408+
EVT ResTy = Op->getValueType(0);
409+
SDLoc DL(Op);
410+
APInt SplatValue, SplatUndef;
411+
unsigned SplatBitSize;
412+
bool HasAnyUndefs;
413+
bool Is128Vec = ResTy.is128BitVector();
414+
bool Is256Vec = ResTy.is256BitVector();
415+
416+
if ((!Subtarget.hasExtLSX() || !Is128Vec) &&
417+
(!Subtarget.hasExtLASX() || !Is256Vec))
418+
return SDValue();
419+
420+
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
421+
/*MinSplatBits=*/8) &&
422+
SplatBitSize <= 64) {
423+
// We can only cope with 8, 16, 32, or 64-bit elements.
424+
if (SplatBitSize != 8 && SplatBitSize != 16 && SplatBitSize != 32 &&
425+
SplatBitSize != 64)
426+
return SDValue();
427+
428+
EVT ViaVecTy;
429+
430+
switch (SplatBitSize) {
431+
default:
432+
return SDValue();
433+
case 8:
434+
ViaVecTy = Is128Vec ? MVT::v16i8 : MVT::v32i8;
435+
break;
436+
case 16:
437+
ViaVecTy = Is128Vec ? MVT::v8i16 : MVT::v16i16;
438+
break;
439+
case 32:
440+
ViaVecTy = Is128Vec ? MVT::v4i32 : MVT::v8i32;
441+
break;
442+
case 64:
443+
ViaVecTy = Is128Vec ? MVT::v2i64 : MVT::v4i64;
444+
break;
445+
}
446+
447+
// SelectionDAG::getConstant will promote SplatValue appropriately.
448+
SDValue Result = DAG.getConstant(SplatValue, DL, ViaVecTy);
449+
450+
// Bitcast to the type we originally wanted.
451+
if (ViaVecTy != ResTy)
452+
Result = DAG.getNode(ISD::BITCAST, SDLoc(Node), ResTy, Result);
453+
454+
return Result;
455+
}
456+
457+
if (DAG.isSplatValue(Op, /*AllowUndefs=*/false))
458+
return Op;
459+
460+
if (!isConstantOrUndefBUILD_VECTOR(Node)) {
461+
// Use INSERT_VECTOR_ELT operations rather than expand to stores.
462+
// The resulting code is the same length as the expansion, but it doesn't
463+
// use memory operations.
464+
EVT ResTy = Node->getValueType(0);
465+
466+
assert(ResTy.isVector());
467+
468+
unsigned NumElts = ResTy.getVectorNumElements();
469+
SDValue Vector = DAG.getUNDEF(ResTy);
470+
for (unsigned i = 0; i < NumElts; ++i) {
471+
Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector,
472+
Node->getOperand(i),
473+
DAG.getConstant(i, DL, Subtarget.getGRLenVT()));
474+
}
475+
return Vector;
476+
}
477+
478+
return SDValue();
479+
}
480+
389481
SDValue
390482
LoongArchTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
391483
SelectionDAG &DAG) const {

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ class LoongArchTargetLowering : public TargetLowering {
277277
SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
278278
SDValue lowerWRITE_REGISTER(SDValue Op, SelectionDAG &DAG) const;
279279
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
280+
SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
280281

281282
bool isFPImmLegal(const APFloat &Imm, EVT VT,
282283
bool ForCodeSize) const override;

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ def lasxsplati32
3333
def lasxsplati64
3434
: PatFrag<(ops node:$e0),
3535
(v4i64 (build_vector node:$e0, node:$e0, node:$e0, node:$e0))>;
36+
def lasxsplatf32
37+
: PatFrag<(ops node:$e0),
38+
(v8f32 (build_vector node:$e0, node:$e0, node:$e0, node:$e0,
39+
node:$e0, node:$e0, node:$e0, node:$e0))>;
40+
def lasxsplatf64
41+
: PatFrag<(ops node:$e0),
42+
(v4f64 (build_vector node:$e0, node:$e0, node:$e0, node:$e0))>;
3643

3744
//===----------------------------------------------------------------------===//
3845
// Instruction class templates
@@ -1411,6 +1418,12 @@ def : Pat<(loongarch_vreplve v8i32:$xj, GRLenVT:$rk),
14111418
def : Pat<(loongarch_vreplve v4i64:$xj, GRLenVT:$rk),
14121419
(XVREPLVE_D v4i64:$xj, GRLenVT:$rk)>;
14131420

1421+
// XVREPL128VEI_{W/D}
1422+
def : Pat<(lasxsplatf32 FPR32:$fj),
1423+
(XVREPL128VEI_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32), 0)>;
1424+
def : Pat<(lasxsplatf64 FPR64:$fj),
1425+
(XVREPL128VEI_D (SUBREG_TO_REG (i64 0), FPR64:$fj, sub_64), 0)>;
1426+
14141427
// Loads/Stores
14151428
foreach vt = [v32i8, v16i16, v8i32, v4i64, v8f32, v4f64] in {
14161429
defm : LdPat<load, XVLD, vt>;

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,13 @@ def lsxsplati16 : PatFrag<(ops node:$e0),
141141
def lsxsplati32 : PatFrag<(ops node:$e0),
142142
(v4i32 (build_vector node:$e0, node:$e0,
143143
node:$e0, node:$e0))>;
144-
145144
def lsxsplati64 : PatFrag<(ops node:$e0),
146145
(v2i64 (build_vector node:$e0, node:$e0))>;
146+
def lsxsplatf32 : PatFrag<(ops node:$e0),
147+
(v4f32 (build_vector node:$e0, node:$e0,
148+
node:$e0, node:$e0))>;
149+
def lsxsplatf64 : PatFrag<(ops node:$e0),
150+
(v2f64 (build_vector node:$e0, node:$e0))>;
147151

148152
def to_valid_timm : SDNodeXForm<timm, [{
149153
auto CN = cast<ConstantSDNode>(N);
@@ -1498,6 +1502,12 @@ def : Pat<(loongarch_vreplve v4i32:$vj, GRLenVT:$rk),
14981502
def : Pat<(loongarch_vreplve v2i64:$vj, GRLenVT:$rk),
14991503
(VREPLVE_D v2i64:$vj, GRLenVT:$rk)>;
15001504

1505+
// VREPLVEI_{W/D}
1506+
def : Pat<(lsxsplatf32 FPR32:$fj),
1507+
(VREPLVEI_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32), 0)>;
1508+
def : Pat<(lsxsplatf64 FPR64:$fj),
1509+
(VREPLVEI_D (SUBREG_TO_REG (i64 0), FPR64:$fj, sub_64), 0)>;
1510+
15011511
// Loads/Stores
15021512
foreach vt = [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64] in {
15031513
defm : LdPat<load, VLD, vt>;

0 commit comments

Comments
 (0)