Skip to content

[RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes #100173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 8 additions & 17 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
Res = DAG.getNode(
ClipOpc, DL, CvtContainerVT,
{Res, DAG.getConstant(0, DL, CvtContainerVT),
DAG.getUNDEF(CvtContainerVT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
: RISCVISD::TRUNCATE_VECTOR_VL_USAT;
Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
}

SDValue SplatZero = DAG.getNode(
Expand Down Expand Up @@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
if ((Val = DetectUSatPattern(Src)))
ClipOpc = RISCVISD::VNCLIPU_VL;
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
else if ((Val = DetectSSatPattern(Src)))
ClipOpc = RISCVISD::VNCLIP_VL;
ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
else
return SDValue();

Expand All @@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
do {
MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
ValVT = ValVT.changeVectorElementType(ValEltVT);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
Val = DAG.getNode(
ClipOpc, DL, ValVT,
{Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
} while (ValVT != VT);

return Val;
Expand Down Expand Up @@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
Expand Down Expand Up @@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UADDSAT_VL)
NODE_NAME_CASE(SSUBSAT_VL)
NODE_NAME_CASE(USUBSAT_VL)
NODE_NAME_CASE(VNCLIP_VL)
NODE_NAME_CASE(VNCLIPU_VL)
NODE_NAME_CASE(FADD_VL)
NODE_NAME_CASE(FSUB_VL)
NODE_NAME_CASE(FMUL_VL)
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
// Truncates an RVV integer vector by one power-of-two. If the value doesn't
// fit in the destination type, the result is saturated: the SSAT form
// saturates as signed and the USAT form as unsigned. These correspond to
// vnclip and vnclipu, respectively, with a shift of 0. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL_SSAT,
TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
Expand Down Expand Up @@ -273,10 +279,6 @@ enum NodeType : unsigned {
// Rounding averaging adds of unsigned integers.
AVGCEILU_VL,

// Operands are (source, shift, merge, mask, roundmode, vl)
VNCLIPU_VL,
VNCLIP_VL,

MULHS_VL,
MULHU_VL,
FADD_VL,
Expand Down
115 changes: 30 additions & 85 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,6 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;

def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;

def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
Expand Down Expand Up @@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
[(riscv_sext_vl node:$A, node:$B, node:$C),
(riscv_zext_vl node:$A, node:$B, node:$C)]>;

// Type profile used by the RISCVISD::TRUNCATE_VECTOR_VL* nodes: one result
// and three operands.
//   result 0  - a vector.
//   operand 1 - the source; constrained only to have the same number of
//               elements as the result (element width is not constrained
//               here).
//   operand 2 - the mask; same number of elements as the result, with i1
//               elements.
//   operand 3 - an XLenVT scalar (the VL operand of the *_VL nodes).
def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<2, i1>,
SDTCisVT<3, XLenVT>]>;
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
SDTypeProfile<1, 3, [SDTCisVec<0>,
SDTCisSameNumEltsAs<0, 1>,
SDTCisSameNumEltsAs<0, 2>,
SDTCVecEltisVT<2, i1>,
SDTCisVT<3, XLenVT>]>>;
SDT_RISCVVTRUNCATE_VL>;
def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
SDT_RISCVVTRUNCATE_VL>;
def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
SDT_RISCVVTRUNCATE_VL>;

def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisInt<1>,
Expand Down Expand Up @@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;

// Emits a single selection pattern for a masked, rounding-mode-carrying
// binary VL node whose second source is a vector register.
//
// The matched node `vop` has the operand order
//   (rs1, rs2, merge, mask, roundmode, vl)
// where the mask is pinned to V0, roundmode is an XLenVT target immediate,
// and vl is matched through VLOpFrag.
//
// The output is the pseudo whose name is built by pasting
//   instruction_name # "_" # suffix # "_" # vlmul.MX # "_MASK"
// and it receives (merge, rs1, rs2, mask, roundmode, vl, sew, TAIL_AGNOSTIC).
//
// Parameters:
//   vop               - SDNode to match.
//   instruction_name  - base name of the pseudo instruction.
//   suffix            - operand-form suffix inserted into the pseudo name.
//   result_type       - value type of the result and of the merge operand.
//   op1_type/op2_type - value types of the two source operands.
//   mask_type         - value type of the V0 mask operand.
//   sew               - forwarded unchanged to the pseudo (the visible caller
//                       passes vti.Log2SEW).
//   vlmul             - LMULInfo; only its MX string is used, for the name.
//   *_reg_class       - register classes for the result/merge and sources.
multiclass VPatBinaryRM_VL_V<SDNode vop,
string instruction_name,
string suffix,
ValueType result_type,
ValueType op1_type,
ValueType op2_type,
ValueType mask_type,
int sew,
LMULInfo vlmul,
VReg result_reg_class,
VReg op1_reg_class,
VReg op2_reg_class> {
def : Pat<(result_type (vop
(op1_type op1_reg_class:$rs1),
(op2_type op2_reg_class:$rs2),
(result_type result_reg_class:$merge),
(mask_type V0),
(XLenVT timm:$roundmode),
VLOpFrag)),
(!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
result_reg_class:$merge,
op1_reg_class:$rs1,
op2_reg_class:$rs2,
(mask_type V0),
(XLenVT timm:$roundmode),
GPR:$vl, sew, TAIL_AGNOSTIC)>;
}

class VPatBinaryVL_V_RM<SDPatternOperator vop,
string instruction_name,
string suffix,
Expand Down Expand Up @@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
xop_kind:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;

// Emits a single selection pattern for a masked, rounding-mode-carrying
// binary VL node whose second source is a splat of a scalar or immediate.
//
// The matched node `vop` has the operand order
//   (rs1, splat(rs2), merge, mask, roundmode, vl)
// where the second operand must match `SplatPatKind` applied to an XLenVT
// `xop_kind` operand, the mask is pinned to V0, roundmode is an XLenVT target
// immediate, and vl is matched through VLOpFrag.
//
// The output is the "_MASK" pseudo whose name is pasted together from
// instruction_name, suffix and vlmul.MX. It receives
// (merge, rs1, rs2, mask, roundmode, vl, sew, TAIL_AGNOSTIC); rs2 is passed
// as the un-splatted `xop_kind` operand rather than as a vector.
//
// NOTE(review): the name paste below uses a bare `_` between `#` operators
// instead of the quoted "_" seen in the sibling VPatBinaryRM_VL_V; presumably
// both produce the same separator - confirm against TableGen paste rules.
//
// Parameters:
//   vop               - SDNode to match.
//   instruction_name  - base name of the pseudo instruction.
//   suffix            - operand-form suffix inserted into the pseudo name.
//   result_type       - value type of the result and of the merge operand.
//   vop1_type         - value type of the vector source (rs1).
//   vop2_type         - value type of the splatted second operand.
//   mask_type         - value type of the V0 mask operand.
//   sew               - forwarded unchanged to the pseudo.
//   vlmul             - LMULInfo; only its MX string is used, for the name.
//   result_reg_class  - register class of the result/merge operand.
//   vop_reg_class     - register class of rs1.
//   SplatPatKind      - ComplexPattern used to recognise the splat.
//   xop_kind          - operand kind of the splatted value.
multiclass VPatBinaryRM_VL_XI<SDNode vop,
string instruction_name,
string suffix,
ValueType result_type,
ValueType vop1_type,
ValueType vop2_type,
ValueType mask_type,
int sew,
LMULInfo vlmul,
VReg result_reg_class,
VReg vop_reg_class,
ComplexPattern SplatPatKind,
DAGOperand xop_kind> {
def : Pat<(result_type (vop
(vop1_type vop_reg_class:$rs1),
(vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
(result_type result_reg_class:$merge),
(mask_type V0),
(XLenVT timm:$roundmode),
VLOpFrag)),
(!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
result_reg_class:$merge,
vop_reg_class:$rs1,
xop_kind:$rs2,
(mask_type V0),
(XLenVT timm:$roundmode),
GPR:$vl, sew, TAIL_AGNOSTIC)>;
}

multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
Expand Down Expand Up @@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
}
}

// Instantiates the rounding-mode narrowing patterns for `vop` over every
// (narrow, wide) type pair in AllWidenableIntVectors.
//
// For each pair, the result and second operand use the narrow type (vti) and
// the first source uses the wide type (wti). Three patterns are emitted:
//   "WV" - second operand is a narrow vector register (VPatBinaryRM_VL_V).
//   "WX" - second operand is a SplatPat of a GPR (VPatBinaryRM_VL_XI).
//   "WI" - second operand is a splat of a uimm5 immediate, matched by the
//          ComplexPattern named by pasting SplatPat, `_` and uimm5
//          (VPatBinaryRM_VL_XI).
// All three pass the narrow type's mask, Log2SEW, LMul and register class.
//
// Parameters:
//   vop              - SDNode to match.
//   instruction_name - base pseudo name, e.g. "PseudoVNCLIP".
multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
foreach VtiToWti = AllWidenableIntVectors in {
defvar vti = VtiToWti.Vti;
defvar wti = VtiToWti.Wti;
defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
vti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
!cast<ComplexPattern>(SplatPat#_#uimm5),
uimm5>;
}
}

class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
Expand Down Expand Up @@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;

// 12.5. Vector Narrowing Fixed-Point Clip Instructions
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
// Select the saturating-truncate VL nodes to the narrowing clip pseudos.
//
// For every (narrow vti, wide wti) pair in AllWidenableIntVectors, two
// patterns are emitted under the combined predicates of both types:
//   riscv_trunc_vector_vl_ssat -> PseudoVNCLIP_WI_<MX>_MASK
//   riscv_trunc_vector_vl_usat -> PseudoVNCLIPU_WI_<MX>_MASK
// where <MX> is the narrow type's LMul.MX string.
//
// Each node carries (wide source, mask in V0, VL). The pseudo receives:
//   - an IMPLICIT_DEF of the narrow type as its first operand (presumably the
//     merge/passthru slot - no meaningful passthru value is supplied),
//   - the wide source register,
//   - the immediate 0 as the shift amount, so no bits are shifted out,
//   - the V0 mask,
//   - 0 as the rounding mode (annotated RNU below),
//   - VL, the narrow type's Log2SEW, and the TA_MA policy operand.
foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
// Rounding mode here is arbitrary since we aren't shifting out any bits.
// Signed saturating truncate -> vnclip with a zero shift immediate.
def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
(vti.Mask V0),
VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
// Unsigned saturating truncate -> vnclipu with a zero shift immediate.
def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
(vti.Mask V0),
VLOpFrag)),
(!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}

// 13. Vector Floating-Point Instructions

Expand Down
Loading