Skip to content

Commit 51102a4

Browse files
committed
[RISCV] Add patterns for fixed vector vwsll
Fixed vectors have their sext/zext operands legalized to _VL nodes, so we need to handle them in the patterns. This adds a riscv_ext_vl_oneuse pattern since we don't care about the type of extension used for the shift amount, and extends Low8BitsSplatPat to handle other _VL nodes. We don't actually need to check the mask at all there since none of the _VL nodes have passthru operands. The remaining test cases that are widening from i8->i64 need to be handled by extending combineBinOp_VLToVWBinOp_VL.
1 parent 59dd10f commit 51102a4

File tree

4 files changed

+128
-121
lines changed

4 files changed

+128
-121
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3287,22 +3287,25 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
32873287
}
32883288

32893289
bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
3290-
// Truncates are custom lowered during legalization.
3291-
auto IsTrunc = [this](SDValue N) {
3292-
if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
3290+
auto IsVLNode = [this](SDValue N) {
3291+
switch (N->getOpcode()) {
3292+
case RISCVISD::TRUNCATE_VECTOR_VL:
3293+
case RISCVISD::VSEXT_VL:
3294+
case RISCVISD::VZEXT_VL:
3295+
break;
3296+
default:
32933297
return false;
3298+
}
32943299
SDValue VL;
32953300
selectVLOp(N->getOperand(2), VL);
3296-
// Any vmset_vl is ok, since any bits past VL are undefined and we can
3297-
// assume they are set.
3298-
return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
3299-
isa<ConstantSDNode>(VL) &&
3300-
cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
3301+
// There's no passthru so any mask is ok, since any inactive elements will
3302+
// be undef.
3303+
return true;
33013304
};
33023305

3303-
// We can have multiple nested truncates, so unravel them all if needed.
3306+
// We can have multiple nested nodes, so unravel them all if needed.
33043307
while (N->getOpcode() == ISD::SIGN_EXTEND ||
3305-
N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
3308+
N->getOpcode() == ISD::ZERO_EXTEND || IsVLNode(N)) {
33063309
if (!N.hasOneUse() ||
33073310
N.getValueType().getSizeInBits().getKnownMinValue() < 8)
33083311
return false;

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
387387
SDTCisVT<3, XLenVT>]>;
388388
def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
389389
def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
390+
def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
391+
[(riscv_sext_vl node:$A, node:$B, node:$C),
392+
(riscv_zext_vl node:$A, node:$B, node:$C)]>;
390393

391394
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
392395
SDTypeProfile<1, 3, [SDTCisVec<0>,
@@ -535,6 +538,11 @@ def riscv_zext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
535538
return N->hasOneUse();
536539
}]>;
537540

541+
def riscv_ext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
542+
(riscv_ext_vl node:$A, node:$B, node:$C), [{
543+
return N->hasOneUse();
544+
}]>;
545+
538546
def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
539547
(riscv_fpextend_vl node:$A, node:$B, node:$C), [{
540548
return N->hasOneUse();

llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,19 @@ foreach vtiToWti = AllWidenableIntVectors in {
642642
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
643643
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
644644

645+
def : Pat<(riscv_shl_vl
646+
(wti.Vector (riscv_zext_vl_oneuse
647+
(vti.Vector vti.RegClass:$rs2),
648+
(vti.Mask V0), VLOpFrag)),
649+
(wti.Vector (riscv_ext_vl_oneuse
650+
(vti.Vector vti.RegClass:$rs1),
651+
(vti.Mask V0), VLOpFrag)),
652+
(wti.Vector wti.RegClass:$merge),
653+
(vti.Mask V0), VLOpFrag),
654+
(!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK")
655+
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
656+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
657+
645658
def : Pat<(riscv_shl_vl
646659
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
647660
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
@@ -651,6 +664,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
651664
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
652665
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
653666

667+
def : Pat<(riscv_shl_vl
668+
(wti.Vector (riscv_zext_vl_oneuse
669+
(vti.Vector vti.RegClass:$rs2),
670+
(vti.Mask V0), VLOpFrag)),
671+
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
672+
(wti.Vector wti.RegClass:$merge),
673+
(vti.Mask V0), VLOpFrag),
674+
(!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK")
675+
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
676+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
677+
654678
def : Pat<(riscv_shl_vl
655679
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
656680
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
@@ -660,6 +684,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
660684
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
661685
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
662686

687+
def : Pat<(riscv_shl_vl
688+
(wti.Vector (riscv_zext_vl_oneuse
689+
(vti.Vector vti.RegClass:$rs2),
690+
(vti.Mask V0), VLOpFrag)),
691+
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
692+
(wti.Vector wti.RegClass:$merge),
693+
(vti.Mask V0), VLOpFrag),
694+
(!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
695+
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
696+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
697+
663698
def : Pat<(riscv_vwsll_vl
664699
(vti.Vector vti.RegClass:$rs2),
665700
(vti.Vector vti.RegClass:$rs1),

0 commit comments

Comments
 (0)