Skip to content

Commit 07d5f49

Browse files
authored
[RISCV] Add patterns for fixed vector vwsll (#87316)
Fixed vectors have their sext/zext operands legalized to _VL nodes, so we need to handle them in the patterns. This adds a riscv_ext_vl_oneuse pattern since we don't care about the type of extension used for the shift amount, and extends Low8BitsSplatPat to handle other _VL nodes. We don't actually need to check the mask or VL there since none of the _VL nodes have passthru operands. The remaining test cases that are widening from i8->i64 need to be handled by extending combineBinOp_VLToVWBinOp_VL. This also fixes Low8BitsSplatPat incorrectly checking the vector size instead of the element size to determine if the splat value might have been truncated below 8 bits.
1 parent fb635be commit 07d5f49

File tree

4 files changed

+130
-126
lines changed

4 files changed

+130
-126
lines changed

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3287,24 +3287,24 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
32873287
}
32883288

32893289
bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
3290-
// Truncates are custom lowered during legalization.
3291-
auto IsTrunc = [this](SDValue N) {
3292-
if (N->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
3290+
auto IsExtOrTrunc = [](SDValue N) {
3291+
switch (N->getOpcode()) {
3292+
case ISD::SIGN_EXTEND:
3293+
case ISD::ZERO_EXTEND:
3294+
// There's no passthru on these _VL nodes so any VL/mask is ok, since any
3295+
// inactive elements will be undef.
3296+
case RISCVISD::TRUNCATE_VECTOR_VL:
3297+
case RISCVISD::VSEXT_VL:
3298+
case RISCVISD::VZEXT_VL:
3299+
return true;
3300+
default:
32933301
return false;
3294-
SDValue VL;
3295-
selectVLOp(N->getOperand(2), VL);
3296-
// Any vmset_vl is ok, since any bits past VL are undefined and we can
3297-
// assume they are set.
3298-
return N->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
3299-
isa<ConstantSDNode>(VL) &&
3300-
cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
3302+
}
33013303
};
33023304

3303-
// We can have multiple nested truncates, so unravel them all if needed.
3304-
while (N->getOpcode() == ISD::SIGN_EXTEND ||
3305-
N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
3306-
if (!N.hasOneUse() ||
3307-
N.getValueType().getSizeInBits().getKnownMinValue() < 8)
3305+
// We can have multiple nested nodes, so unravel them all if needed.
3306+
while (IsExtOrTrunc(N)) {
3307+
if (!N.hasOneUse() || N.getScalarValueSizeInBits() < 8)
33083308
return false;
33093309
N = N->getOperand(0);
33103310
}

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,9 @@ def SDT_RISCVVEXTEND_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
387387
SDTCisVT<3, XLenVT>]>;
388388
def riscv_sext_vl : SDNode<"RISCVISD::VSEXT_VL", SDT_RISCVVEXTEND_VL>;
389389
def riscv_zext_vl : SDNode<"RISCVISD::VZEXT_VL", SDT_RISCVVEXTEND_VL>;
390+
def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
391+
[(riscv_sext_vl node:$A, node:$B, node:$C),
392+
(riscv_zext_vl node:$A, node:$B, node:$C)]>;
390393

391394
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
392395
SDTypeProfile<1, 3, [SDTCisVec<0>,
@@ -535,6 +538,11 @@ def riscv_zext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
535538
return N->hasOneUse();
536539
}]>;
537540

541+
def riscv_ext_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
542+
(riscv_ext_vl node:$A, node:$B, node:$C), [{
543+
return N->hasOneUse();
544+
}]>;
545+
538546
def riscv_fpextend_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C),
539547
(riscv_fpextend_vl node:$A, node:$B, node:$C), [{
540548
return N->hasOneUse();

llvm/lib/Target/RISCV/RISCVInstrInfoZvk.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,19 @@ foreach vtiToWti = AllWidenableIntVectors in {
629629
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
630630
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
631631

632+
def : Pat<(riscv_shl_vl
633+
(wti.Vector (riscv_zext_vl_oneuse
634+
(vti.Vector vti.RegClass:$rs2),
635+
(vti.Mask V0), VLOpFrag)),
636+
(wti.Vector (riscv_ext_vl_oneuse
637+
(vti.Vector vti.RegClass:$rs1),
638+
(vti.Mask V0), VLOpFrag)),
639+
(wti.Vector wti.RegClass:$merge),
640+
(vti.Mask V0), VLOpFrag),
641+
(!cast<Instruction>("PseudoVWSLL_VV_"#vti.LMul.MX#"_MASK")
642+
wti.RegClass:$merge, vti.RegClass:$rs2, vti.RegClass:$rs1,
643+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
644+
632645
def : Pat<(riscv_shl_vl
633646
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
634647
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
@@ -638,6 +651,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
638651
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
639652
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
640653

654+
def : Pat<(riscv_shl_vl
655+
(wti.Vector (riscv_zext_vl_oneuse
656+
(vti.Vector vti.RegClass:$rs2),
657+
(vti.Mask V0), VLOpFrag)),
658+
(wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1))),
659+
(wti.Vector wti.RegClass:$merge),
660+
(vti.Mask V0), VLOpFrag),
661+
(!cast<Instruction>("PseudoVWSLL_VX_"#vti.LMul.MX#"_MASK")
662+
wti.RegClass:$merge, vti.RegClass:$rs2, GPR:$rs1,
663+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
664+
641665
def : Pat<(riscv_shl_vl
642666
(wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs2))),
643667
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
@@ -647,6 +671,17 @@ foreach vtiToWti = AllWidenableIntVectors in {
647671
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
648672
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
649673

674+
def : Pat<(riscv_shl_vl
675+
(wti.Vector (riscv_zext_vl_oneuse
676+
(vti.Vector vti.RegClass:$rs2),
677+
(vti.Mask V0), VLOpFrag)),
678+
(wti.Vector (SplatPat_uimm5 uimm5:$rs1)),
679+
(wti.Vector wti.RegClass:$merge),
680+
(vti.Mask V0), VLOpFrag),
681+
(!cast<Instruction>("PseudoVWSLL_VI_"#vti.LMul.MX#"_MASK")
682+
wti.RegClass:$merge, vti.RegClass:$rs2, uimm5:$rs1,
683+
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
684+
650685
def : Pat<(riscv_vwsll_vl
651686
(vti.Vector vti.RegClass:$rs2),
652687
(vti.Vector vti.RegClass:$rs1),

0 commit comments

Comments
 (0)