Skip to content

Commit ec8fe59

Browse files
authored
[RISCV] Move vnclipu patterns into DAGCombiner. (#93596)
I plan to add support for multiple layers of vnclipu. For example, i32->i8 using 2 vnclipu instructions. First clipping to 65535, then clipping to 255. Similar for signed vnclip. This scales poorly if we need to add patterns with 2 or 3 truncates. Instead, move the code to DAGCombiner with new ISD opcodes to represent VCLIP(U). This patch just moves the existing patterns into DAG combine. Support for multiple truncates will as a follow up. A similar patch series will be made for the signed vnclip.
1 parent 1cff741 commit ec8fe59

File tree

5 files changed

+174
-20
lines changed

5 files changed

+174
-20
lines changed

llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
373373
}
374374
} // namespace RISCVFPRndMode
375375

376+
namespace RISCVVXRndMode {
377+
enum RoundingMode {
378+
RNU = 0,
379+
RNE = 1,
380+
RDN = 2,
381+
ROD = 3,
382+
};
383+
} // namespace RISCVVXRndMode
384+
376385
//===----------------------------------------------------------------------===//
377386
// Floating-point Immediates
378387
//

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
59605960
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59615961
"not a RISC-V target specific op");
59625962
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5963-
128 &&
5963+
130 &&
59645964
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59655965
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59665966
21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
59865986
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
59875987
"not a RISC-V target specific op");
59885988
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5989-
128 &&
5989+
130 &&
59905990
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
59915991
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
59925992
21 &&
@@ -16183,6 +16183,66 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
1618316183
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
1618416184
}
1618516185

16186+
// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
16187+
// value for the truncated type.
16188+
static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
16189+
const RISCVSubtarget &Subtarget) {
16190+
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
16191+
16192+
MVT VT = N->getSimpleValueType(0);
16193+
16194+
SDValue Mask = N->getOperand(1);
16195+
SDValue VL = N->getOperand(2);
16196+
16197+
SDValue Src = N->getOperand(0);
16198+
16199+
// Src must be a UMIN or UMIN_VL.
16200+
if (Src.getOpcode() != ISD::UMIN &&
16201+
!(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
16202+
Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
16203+
return SDValue();
16204+
16205+
auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
16206+
// Peek through conversion between fixed and scalable vectors.
16207+
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
16208+
isNullConstant(Op.getOperand(2)) &&
16209+
Op.getOperand(1).getValueType().isFixedLengthVector() &&
16210+
Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
16211+
Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
16212+
isNullConstant(Op.getOperand(1).getOperand(1)))
16213+
Op = Op.getOperand(1).getOperand(0);
16214+
16215+
if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
16216+
return true;
16217+
16218+
if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
16219+
Op.getOperand(2) == VL) {
16220+
if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
16221+
SplatVal =
16222+
Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
16223+
return true;
16224+
}
16225+
}
16226+
16227+
return false;
16228+
};
16229+
16230+
APInt C;
16231+
if (!IsSplat(Src.getOperand(1), C))
16232+
return SDValue();
16233+
16234+
if (!C.isMask(VT.getScalarSizeInBits()))
16235+
return SDValue();
16236+
16237+
SDLoc DL(N);
16238+
// Rounding mode here is arbitrary since we aren't shifting out any bits.
16239+
return DAG.getNode(
16240+
RISCVISD::VNCLIPU_VL, DL, VT,
16241+
{Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
16242+
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
16243+
VL});
16244+
}
16245+
1618616246
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1618716247
DAGCombinerInfo &DCI) const {
1618816248
SelectionDAG &DAG = DCI.DAG;
@@ -16400,7 +16460,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1640016460
}
1640116461
return SDValue();
1640216462
case RISCVISD::TRUNCATE_VECTOR_VL:
16403-
return combineTruncOfSraSext(N, DAG);
16463+
if (SDValue V = combineTruncOfSraSext(N, DAG))
16464+
return V;
16465+
return combineTruncToVnclipu(N, DAG, Subtarget);
1640416466
case ISD::TRUNCATE:
1640516467
return performTRUNCATECombine(N, DAG, Subtarget);
1640616468
case ISD::SELECT:
@@ -20019,6 +20081,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2001920081
NODE_NAME_CASE(UADDSAT_VL)
2002020082
NODE_NAME_CASE(SSUBSAT_VL)
2002120083
NODE_NAME_CASE(USUBSAT_VL)
20084+
NODE_NAME_CASE(VNCLIP_VL)
20085+
NODE_NAME_CASE(VNCLIPU_VL)
2002220086
NODE_NAME_CASE(FADD_VL)
2002320087
NODE_NAME_CASE(FSUB_VL)
2002420088
NODE_NAME_CASE(FMUL_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
273273
// Rounding averaging adds of unsigned integers.
274274
AVGCEILU_VL,
275275

276+
// Operands are (source, shift, merge, mask, roundmode, vl)
277+
VNCLIPU_VL,
278+
VNCLIP_VL,
279+
276280
MULHS_VL,
277281
MULHU_VL,
278282
FADD_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
11961196
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
11971197
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
11981198
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
1199-
1200-
def : Pat<(vti.Vector (riscv_trunc_vector_vl
1201-
(wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
1202-
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
1203-
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
1204-
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
1205-
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
12061199
}
12071200
}
12081201

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
3636
SDTCisSameNumEltsAs<0, 4>,
3737
SDTCisVT<5, XLenVT>]>;
3838

39+
// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
40+
def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
41+
SDTCisSameAs<0, 3>,
42+
SDTCisSameNumEltsAs<0, 1>,
43+
SDTCisVec<1>,
44+
SDTCisOpSmallerThanOp<2, 1>,
45+
SDTCisSameAs<0, 2>,
46+
SDTCisSameNumEltsAs<0, 4>,
47+
SDTCVecEltisVT<4, i1>,
48+
SDTCisVT<5, XLenVT>,
49+
SDTCisVT<6, XLenVT>]>;
50+
3951
def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
4052
SDTCisVec<0>, SDTCisFP<0>,
4153
SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
120132
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
121133
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
122134

135+
def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
136+
def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
137+
123138
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
124139
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
125140
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
635650
op2_reg_class:$rs2,
636651
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
637652

653+
multiclass VPatBinaryRM_VL_V<SDNode vop,
654+
string instruction_name,
655+
string suffix,
656+
ValueType result_type,
657+
ValueType op1_type,
658+
ValueType op2_type,
659+
ValueType mask_type,
660+
int sew,
661+
LMULInfo vlmul,
662+
VReg result_reg_class,
663+
VReg op1_reg_class,
664+
VReg op2_reg_class> {
665+
def : Pat<(result_type (vop
666+
(op1_type op1_reg_class:$rs1),
667+
(op2_type op2_reg_class:$rs2),
668+
(result_type result_reg_class:$merge),
669+
(mask_type V0),
670+
(XLenVT timm:$roundmode),
671+
VLOpFrag)),
672+
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
673+
result_reg_class:$merge,
674+
op1_reg_class:$rs1,
675+
op2_reg_class:$rs2,
676+
(mask_type V0),
677+
(XLenVT timm:$roundmode),
678+
GPR:$vl, sew, TAIL_AGNOSTIC)>;
679+
}
680+
638681
class VPatBinaryVL_V_RM<SDPatternOperator vop,
639682
string instruction_name,
640683
string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
795838
xop_kind:$rs2,
796839
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
797840

841+
multiclass VPatBinaryRM_VL_XI<SDNode vop,
842+
string instruction_name,
843+
string suffix,
844+
ValueType result_type,
845+
ValueType vop1_type,
846+
ValueType vop2_type,
847+
ValueType mask_type,
848+
int sew,
849+
LMULInfo vlmul,
850+
VReg result_reg_class,
851+
VReg vop_reg_class,
852+
ComplexPattern SplatPatKind,
853+
DAGOperand xop_kind> {
854+
def : Pat<(result_type (vop
855+
(vop1_type vop_reg_class:$rs1),
856+
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
857+
(result_type result_reg_class:$merge),
858+
(mask_type V0),
859+
(XLenVT timm:$roundmode),
860+
VLOpFrag)),
861+
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
862+
result_reg_class:$merge,
863+
vop_reg_class:$rs1,
864+
xop_kind:$rs2,
865+
(mask_type V0),
866+
(XLenVT timm:$roundmode),
867+
GPR:$vl, sew, TAIL_AGNOSTIC)>;
868+
}
869+
798870
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
799871
list<VTypeInfo> vtilist = AllIntegerVectors,
800872
bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
893965
}
894966
}
895967

968+
multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
969+
foreach VtiToWti = AllWidenableIntVectors in {
970+
defvar vti = VtiToWti.Vti;
971+
defvar wti = VtiToWti.Wti;
972+
defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
973+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
974+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
975+
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
976+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
977+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
978+
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
979+
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
980+
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
981+
!cast<ComplexPattern>(SplatPat#_#uimm5),
982+
uimm5>;
983+
}
984+
}
985+
896986
class VPatBinaryVL_VF<SDPatternOperator vop,
897987
string instruction_name,
898988
ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
23762466
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
23772467
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
23782468

2469+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
2470+
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
2471+
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
2472+
23792473
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
23802474
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
23812475
defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
24102504
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
24112505
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
24122506
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
2413-
2414-
def : Pat<(vti.Vector (riscv_trunc_vector_vl
2415-
(wti.Vector (riscv_umin_vl
2416-
(wti.Vector wti.RegClass:$rs1),
2417-
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
2418-
(wti.Vector undef), (wti.Mask V0), VLOpFrag)),
2419-
(vti.Mask V0), VLOpFrag)),
2420-
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
2421-
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
2422-
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
24232507
}
24242508
}
24252509

0 commit comments

Comments
 (0)