Skip to content

Commit caaba2a

Browse files
authored
[RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes (#100173)
These new opcodes drop the shift amount, rounding mode, and passthru. Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are added in isel patterns similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0. This should simplify #99418 a little.
1 parent 7868c04 commit caaba2a

File tree

3 files changed

+44
-106
lines changed

3 files changed

+44
-106
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
29972997
CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
29982998
CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
29992999
// Rounding mode here is arbitrary since we aren't shifting out any bits.
3000-
unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
3001-
Res = DAG.getNode(
3002-
ClipOpc, DL, CvtContainerVT,
3003-
{Res, DAG.getConstant(0, DL, CvtContainerVT),
3004-
DAG.getUNDEF(CvtContainerVT), Mask,
3005-
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
3006-
VL});
3000+
unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
3001+
: RISCVISD::TRUNCATE_VECTOR_VL_USAT;
3002+
Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
30073003
}
30083004

30093005
SDValue SplatZero = DAG.getNode(
@@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
1664316639
SDValue Val;
1664416640
unsigned ClipOpc;
1664516641
if ((Val = DetectUSatPattern(Src)))
16646-
ClipOpc = RISCVISD::VNCLIPU_VL;
16642+
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
1664716643
else if ((Val = DetectSSatPattern(Src)))
16648-
ClipOpc = RISCVISD::VNCLIP_VL;
16644+
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
1664916645
else
1665016646
return SDValue();
1665116647

@@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
1665416650
do {
1665516651
MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
1665616652
ValVT = ValVT.changeVectorElementType(ValEltVT);
16657-
// Rounding mode here is arbitrary since we aren't shifting out any bits.
16658-
Val = DAG.getNode(
16659-
ClipOpc, DL, ValVT,
16660-
{Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
16661-
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
16662-
VL});
16653+
Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
1666316654
} while (ValVT != VT);
1666416655

1666516656
return Val;
@@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2046320454
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
2046420455
NODE_NAME_CASE(READ_VLENB)
2046520456
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
20457+
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
20458+
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
2046620459
NODE_NAME_CASE(VSLIDEUP_VL)
2046720460
NODE_NAME_CASE(VSLIDE1UP_VL)
2046820461
NODE_NAME_CASE(VSLIDEDOWN_VL)
@@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2050620499
NODE_NAME_CASE(UADDSAT_VL)
2050720500
NODE_NAME_CASE(SSUBSAT_VL)
2050820501
NODE_NAME_CASE(USUBSAT_VL)
20509-
NODE_NAME_CASE(VNCLIP_VL)
20510-
NODE_NAME_CASE(VNCLIPU_VL)
2051120502
NODE_NAME_CASE(FADD_VL)
2051220503
NODE_NAME_CASE(FSUB_VL)
2051320504
NODE_NAME_CASE(FMUL_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,12 @@ enum NodeType : unsigned {
181181
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
182182
// mask and VL operand.
183183
TRUNCATE_VECTOR_VL,
184+
// Truncates a RVV integer vector by one power-of-two. If the value doesn't
185+
// fit in the destination type, the result is saturated. These correspond to
186+
// vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL
187+
// operand.
188+
TRUNCATE_VECTOR_VL_SSAT,
189+
TRUNCATE_VECTOR_VL_USAT,
184190
// Matches the semantics of vslideup/vslidedown. The first operand is the
185191
// pass-thru operand, the second is the source vector, the third is the XLenVT
186192
// index (either constant or non-constant), the fourth is the mask, the fifth
@@ -273,10 +279,6 @@ enum NodeType : unsigned {
273279
// Rounding averaging adds of unsigned integers.
274280
AVGCEILU_VL,
275281

276-
// Operands are (source, shift, merge, mask, roundmode, vl)
277-
VNCLIPU_VL,
278-
VNCLIP_VL,
279-
280282
MULHS_VL,
281283
MULHU_VL,
282284
FADD_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 30 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,6 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
132132
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
133133
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
134134

135-
def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
136-
def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
137-
138135
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
139136
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
140137
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
408405
[(riscv_sext_vl node:$A, node:$B, node:$C),
409406
(riscv_zext_vl node:$A, node:$B, node:$C)]>;
410407

408+
def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
409+
SDTCisSameNumEltsAs<0, 1>,
410+
SDTCisSameNumEltsAs<0, 2>,
411+
SDTCVecEltisVT<2, i1>,
412+
SDTCisVT<3, XLenVT>]>;
411413
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
412-
SDTypeProfile<1, 3, [SDTCisVec<0>,
413-
SDTCisSameNumEltsAs<0, 1>,
414-
SDTCisSameNumEltsAs<0, 2>,
415-
SDTCVecEltisVT<2, i1>,
416-
SDTCisVT<3, XLenVT>]>>;
414+
SDT_RISCVVTRUNCATE_VL>;
415+
def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
416+
SDT_RISCVVTRUNCATE_VL>;
417+
def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
418+
SDT_RISCVVTRUNCATE_VL>;
417419

418420
def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
419421
SDTCisInt<1>,
@@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
650652
op2_reg_class:$rs2,
651653
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
652654

653-
multiclass VPatBinaryRM_VL_V<SDNode vop,
654-
string instruction_name,
655-
string suffix,
656-
ValueType result_type,
657-
ValueType op1_type,
658-
ValueType op2_type,
659-
ValueType mask_type,
660-
int sew,
661-
LMULInfo vlmul,
662-
VReg result_reg_class,
663-
VReg op1_reg_class,
664-
VReg op2_reg_class> {
665-
def : Pat<(result_type (vop
666-
(op1_type op1_reg_class:$rs1),
667-
(op2_type op2_reg_class:$rs2),
668-
(result_type result_reg_class:$merge),
669-
(mask_type V0),
670-
(XLenVT timm:$roundmode),
671-
VLOpFrag)),
672-
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
673-
result_reg_class:$merge,
674-
op1_reg_class:$rs1,
675-
op2_reg_class:$rs2,
676-
(mask_type V0),
677-
(XLenVT timm:$roundmode),
678-
GPR:$vl, sew, TAIL_AGNOSTIC)>;
679-
}
680-
681655
class VPatBinaryVL_V_RM<SDPatternOperator vop,
682656
string instruction_name,
683657
string suffix,
@@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
838812
xop_kind:$rs2,
839813
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
840814

841-
multiclass VPatBinaryRM_VL_XI<SDNode vop,
842-
string instruction_name,
843-
string suffix,
844-
ValueType result_type,
845-
ValueType vop1_type,
846-
ValueType vop2_type,
847-
ValueType mask_type,
848-
int sew,
849-
LMULInfo vlmul,
850-
VReg result_reg_class,
851-
VReg vop_reg_class,
852-
ComplexPattern SplatPatKind,
853-
DAGOperand xop_kind> {
854-
def : Pat<(result_type (vop
855-
(vop1_type vop_reg_class:$rs1),
856-
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
857-
(result_type result_reg_class:$merge),
858-
(mask_type V0),
859-
(XLenVT timm:$roundmode),
860-
VLOpFrag)),
861-
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
862-
result_reg_class:$merge,
863-
vop_reg_class:$rs1,
864-
xop_kind:$rs2,
865-
(mask_type V0),
866-
(XLenVT timm:$roundmode),
867-
GPR:$vl, sew, TAIL_AGNOSTIC)>;
868-
}
869-
870815
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
871816
list<VTypeInfo> vtilist = AllIntegerVectors,
872817
bit isSEWAware = 0> {
@@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
965910
}
966911
}
967912

968-
multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
969-
foreach VtiToWti = AllWidenableIntVectors in {
970-
defvar vti = VtiToWti.Vti;
971-
defvar wti = VtiToWti.Wti;
972-
defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
973-
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
974-
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
975-
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
976-
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
977-
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
978-
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
979-
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
980-
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
981-
!cast<ComplexPattern>(SplatPat#_#uimm5),
982-
uimm5>;
983-
}
984-
}
985-
986913
class VPatBinaryVL_VF<SDPatternOperator vop,
987914
string instruction_name,
988915
ValueType result_type,
@@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
24682395
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
24692396

24702397
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
2471-
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
2472-
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
2398+
foreach vtiTowti = AllWidenableIntVectors in {
2399+
defvar vti = vtiTowti.Vti;
2400+
defvar wti = vtiTowti.Wti;
2401+
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
2402+
GetVTypePredicates<wti>.Predicates) in {
2403+
// Rounding mode here is arbitrary since we aren't shifting out any bits.
2404+
def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
2405+
(vti.Mask V0),
2406+
VLOpFrag)),
2407+
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
2408+
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
2409+
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
2410+
def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
2411+
(vti.Mask V0),
2412+
VLOpFrag)),
2413+
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
2414+
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
2415+
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
2416+
}
2417+
}
24732418

24742419
// 13. Vector Floating-Point Instructions
24752420

0 commit comments

Comments
 (0)