Skip to content

Commit 9d7f119

Browse files
committed
[RISCV] Move vnclipu patterns into DAGCombiner.
I plan to add support for multiple layers of vnclipu. For example, i32->i8 using 2 vnclipu instructions. First clipping to 65535, then clipping to 255. Similar for signed vnclip. This scales poorly if we need to add patterns with 2 or 3 truncates. Instead, move the code to DAGCombiner with new ISD opcodes to represent VCLIP(U). This patch just moves the existing patterns into DAG combine. Support for multiple truncates will as a follow up. A similar patch will be made for the signed vnclip patterns.
1 parent e10966f commit 9d7f119

File tree

5 files changed

+174
-20
lines changed

5 files changed

+174
-20
lines changed

llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
373373
}
374374
} // namespace RISCVFPRndMode
375375

376+
namespace RISCVVXRndMode {
377+
enum RoundingMode {
378+
RNU = 0,
379+
RNE = 1,
380+
RDN = 2,
381+
ROD = 3,
382+
};
383+
} // namespace RISCVVXRndMode
384+
376385
//===----------------------------------------------------------------------===//
377386
// Floating-point Immediates
378387
//

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
59605960
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59615961
"not a RISC-V target specific op");
59625962
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5963-
128 &&
5963+
130 &&
59645964
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59655965
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59665966
21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
59865986
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59875987
"not a RISC-V target specific op");
59885988
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5989-
128 &&
5989+
130 &&
59905990
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59915991
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59925992
21 &&
@@ -16139,6 +16139,66 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
1613916139
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
1614016140
}
1614116141

16142+
// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
16143+
// value for the truncated type.
16144+
static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
16145+
const RISCVSubtarget &Subtarget) {
16146+
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
16147+
16148+
MVT VT = N->getSimpleValueType(0);
16149+
16150+
SDValue Mask = N->getOperand(1);
16151+
SDValue VL = N->getOperand(2);
16152+
16153+
SDValue Src = N->getOperand(0);
16154+
16155+
// Src must be a UMIN or UMIN_VL.
16156+
if (Src.getOpcode() != ISD::UMIN &&
16157+
!(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
16158+
Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
16159+
return SDValue();
16160+
16161+
auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
16162+
// Peek through conversion between fixed and scalable vectors.
16163+
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
16164+
isNullConstant(Op.getOperand(2)) &&
16165+
Op.getOperand(1).getValueType().isFixedLengthVector() &&
16166+
Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
16167+
Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
16168+
isNullConstant(Op.getOperand(1).getOperand(1)))
16169+
Op = Op.getOperand(1).getOperand(0);
16170+
16171+
if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
16172+
return true;
16173+
16174+
if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
16175+
Op.getOperand(2) == VL) {
16176+
if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
16177+
SplatVal =
16178+
Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
16179+
return true;
16180+
}
16181+
}
16182+
16183+
return false;
16184+
};
16185+
16186+
APInt C;
16187+
if (!IsSplat(Src.getOperand(1), C))
16188+
return SDValue();
16189+
16190+
if (!C.isMask(VT.getScalarSizeInBits()))
16191+
return SDValue();
16192+
16193+
SDLoc DL(N);
16194+
// Rounding mode here is arbitrary since we aren't shifting out any bits.
16195+
return DAG.getNode(
16196+
RISCVISD::VNCLIPU_VL, DL, VT,
16197+
{Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
16198+
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
16199+
VL});
16200+
}
16201+
1614216202
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1614316203
DAGCombinerInfo &DCI) const {
1614416204
SelectionDAG &DAG = DCI.DAG;
@@ -16356,7 +16416,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1635616416
}
1635716417
return SDValue();
1635816418
case RISCVISD::TRUNCATE_VECTOR_VL:
16359-
return combineTruncOfSraSext(N, DAG);
16419+
if (SDValue V = combineTruncOfSraSext(N, DAG))
16420+
return V;
16421+
return combineTruncToVnclipu(N, DAG, Subtarget);
1636016422
case ISD::TRUNCATE:
1636116423
return performTRUNCATECombine(N, DAG, Subtarget);
1636216424
case ISD::SELECT:
@@ -19975,6 +20037,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1997520037
NODE_NAME_CASE(UADDSAT_VL)
1997620038
NODE_NAME_CASE(SSUBSAT_VL)
1997720039
NODE_NAME_CASE(USUBSAT_VL)
20040+
NODE_NAME_CASE(VNCLIP_VL)
20041+
NODE_NAME_CASE(VNCLIPU_VL)
1997820042
NODE_NAME_CASE(FADD_VL)
1997920043
NODE_NAME_CASE(FSUB_VL)
1998020044
NODE_NAME_CASE(FMUL_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
273273
// Rounding averaging adds of unsigned integers.
274274
AVGCEILU_VL,
275275

276+
// Operands are (source, shift, merge, mask, roundmode, vl)
277+
VNCLIPU_VL,
278+
VNCLIP_VL,
279+
276280
MULHS_VL,
277281
MULHU_VL,
278282
FADD_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
11961196
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
11971197
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
11981198
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
1199-
1200-
def : Pat<(vti.Vector (riscv_trunc_vector_vl
1201-
(wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
1202-
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
1203-
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
1204-
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
1205-
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
12061199
}
12071200
}
12081201

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
3636
SDTCisSameNumEltsAs<0, 4>,
3737
SDTCisVT<5, XLenVT>]>;
3838

39+
// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
40+
def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
41+
SDTCisSameAs<0, 3>,
42+
SDTCisSameNumEltsAs<0, 1>,
43+
SDTCisVec<1>,
44+
SDTCisOpSmallerThanOp<2, 1>,
45+
SDTCisSameAs<0, 2>,
46+
SDTCisSameNumEltsAs<0, 4>,
47+
SDTCVecEltisVT<4, i1>,
48+
SDTCisVT<5, XLenVT>,
49+
SDTCisVT<6, XLenVT>]>;
50+
3951
def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
4052
SDTCisVec<0>, SDTCisFP<0>,
4153
SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
120132
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
121133
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
122134

135+
def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
136+
def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
137+
123138
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
124139
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
125140
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
635650
op2_reg_class:$rs2,
636651
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
637652

653+
multiclass VPatBinaryRM_VL_V<SDNode vop,
654+
string instruction_name,
655+
string suffix,
656+
ValueType result_type,
657+
ValueType op1_type,
658+
ValueType op2_type,
659+
ValueType mask_type,
660+
int sew,
661+
LMULInfo vlmul,
662+
VReg result_reg_class,
663+
VReg op1_reg_class,
664+
VReg op2_reg_class> {
665+
def : Pat<(result_type (vop
666+
(op1_type op1_reg_class:$rs1),
667+
(op2_type op2_reg_class:$rs2),
668+
(result_type result_reg_class:$merge),
669+
(mask_type V0),
670+
(XLenVT timm:$roundmode),
671+
VLOpFrag)),
672+
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
673+
result_reg_class:$merge,
674+
op1_reg_class:$rs1,
675+
op2_reg_class:$rs2,
676+
(mask_type V0),
677+
(XLenVT timm:$roundmode),
678+
GPR:$vl, sew, TAIL_AGNOSTIC)>;
679+
}
680+
638681
class VPatBinaryVL_V_RM<SDPatternOperator vop,
639682
string instruction_name,
640683
string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
795838
xop_kind:$rs2,
796839
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
797840

841+
multiclass VPatBinaryRM_VL_XI<SDNode vop,
842+
string instruction_name,
843+
string suffix,
844+
ValueType result_type,
845+
ValueType vop1_type,
846+
ValueType vop2_type,
847+
ValueType mask_type,
848+
int sew,
849+
LMULInfo vlmul,
850+
VReg result_reg_class,
851+
VReg vop_reg_class,
852+
ComplexPattern SplatPatKind,
853+
DAGOperand xop_kind> {
854+
def : Pat<(result_type (vop
855+
(vop1_type vop_reg_class:$rs1),
856+
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
857+
(result_type result_reg_class:$merge),
858+
(mask_type V0),
859+
(XLenVT timm:$roundmode),
860+
VLOpFrag)),
861+
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
862+
result_reg_class:$merge,
863+
vop_reg_class:$rs1,
864+
xop_kind:$rs2,
865+
(mask_type V0),
866+
(XLenVT timm:$roundmode),
867+
GPR:$vl, sew, TAIL_AGNOSTIC)>;
868+
}
869+
798870
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
799871
list<VTypeInfo> vtilist = AllIntegerVectors,
800872
bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
893965
}
894966
}
895967

968+
multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
969+
foreach VtiToWti = AllWidenableIntVectors in {
970+
defvar vti = VtiToWti.Vti;
971+
defvar wti = VtiToWti.Wti;
972+
defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
973+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
974+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
975+
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
976+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
977+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
978+
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
979+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
980+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
981+
!cast<ComplexPattern>(SplatPat#_#uimm5),
982+
uimm5>;
983+
}
984+
}
985+
896986
class VPatBinaryVL_VF<SDPatternOperator vop,
897987
string instruction_name,
898988
ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
23762466
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
23772467
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
23782468

2469+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
2470+
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
2471+
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
2472+
23792473
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
23802474
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
23812475
defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
24102504
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
24112505
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
24122506
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
2413-
2414-
def : Pat<(vti.Vector (riscv_trunc_vector_vl
2415-
(wti.Vector (riscv_umin_vl
2416-
(wti.Vector wti.RegClass:$rs1),
2417-
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
2418-
(wti.Vector undef), (wti.Mask V0), VLOpFrag)),
2419-
(vti.Mask V0), VLOpFrag)),
2420-
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
2421-
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
2422-
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
24232507
}
24242508
}
24252509

0 commit comments

Comments
 (0)