[RISCV] Move vnclipu patterns into DAGCombiner. #93596
Conversation
@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

I plan to add support for multiple layers of vnclipu. For example, i32->i8 using 2 vnclipu instructions. First clipping to 65535, then clipping to 255. Similar for signed vnclip.

This scales poorly if we need to add patterns with 2 or 3 truncates. Instead, move the code to DAGCombiner with new ISD opcodes to represent VNCLIP(U).

This patch just moves the existing patterns into DAG combine. Support for multiple truncates will come as a follow up.

Full diff: https://github.com/llvm/llvm-project/pull/93596.diff

5 Files Affected:
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 08f056f78979a..550904516ac8e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
}
} // namespace RISCVFPRndMode
+namespace RISCVVXRndMode {
+enum RoundingMode {
+ RNU = 0,
+ RNE = 1,
+ RDN = 2,
+ ROD = 3,
+};
+} // namespace RISCVVXRndMode
+
//===----------------------------------------------------------------------===//
// Floating-point Immediates
//
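The new RISCVVXRndMode enum mirrors the four fixed-point rounding modes of the RVV vxrm CSR. Below is a minimal scalar model of the spec's unsigned round-off under each encoding (a sketch of our own for illustration, not code from this patch; the helper name is ours):

#include <cstdint>

// Shift v right by d bits, rounding the discarded bits per the vxrm encoding
// (RNU=0, RNE=1, RDN=2, ROD=3), following the RVV fixed-point definition.
static uint64_t roundoffUnsigned(uint64_t v, unsigned d, unsigned VXRM) {
  if (d == 0)
    return v;
  uint64_t Vd = (v >> d) & 1;                   // bit d of v
  uint64_t Vd1 = (v >> (d - 1)) & 1;            // bit d-1 (the "guard" bit)
  uint64_t Low = v & ((1ULL << (d - 1)) - 1);   // bits d-2..0 (the "sticky" bits)
  uint64_t R;
  switch (VXRM) {
  case 0: R = Vd1; break;                       // RNU: round-to-nearest-up
  case 1: R = Vd1 & ((Low != 0) | Vd); break;   // RNE: round-to-nearest-even
  case 2: R = 0; break;                         // RDN: round-down (truncate)
  default: R = !Vd & ((Low != 0) | Vd1); break; // ROD: round-to-odd
  }
  return (v >> d) + R;
}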
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..1862dec7b1afd 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 130 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 130 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -16087,6 +16087,117 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+ // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+ // This benefits cases where X and Y are both low-precision vectors of the
+ // same value type. Since the truncate is lowered into n levels of
+ // TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate restriction,
+ // such a pattern will have been expanded into a series of "vsetvli" and
+ // "vnsrl" instructions by the time it reaches this point.
+ auto IsTruncNode = [](SDValue V) {
+ if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+ return false;
+ SDValue VL = V.getOperand(2);
+ auto *C = dyn_cast<ConstantSDNode>(VL);
+ // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+ bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+ (isa<RegisterSDNode>(VL) &&
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+ return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+ };
+
+ SDValue Op = N->getOperand(0);
+
+ // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+ // to distinguish such pattern.
+ while (IsTruncNode(Op)) {
+ if (!Op.hasOneUse())
+ return SDValue();
+ Op = Op.getOperand(0);
+ }
+
+ if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+ return SDValue();
+
+ SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
+ if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+ N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N10 = N1.getOperand(0);
+ if (!N00.getValueType().isVector() ||
+ N00.getValueType() != N10.getValueType() ||
+ N->getValueType(0) != N10.getValueType())
+ return SDValue();
+
+ unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+ SDValue SMin =
+ DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+ DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+ return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
+
+// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is the
+// maximum value for the truncated type.
+static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+
+ MVT VT = N->getSimpleValueType(0);
+
+ SDValue Mask = N->getOperand(1);
+ SDValue VL = N->getOperand(2);
+
+ SDValue Src = N->getOperand(0);
+
+ // Src must be a UMIN or UMIN_VL.
+ if (Src.getOpcode() != ISD::UMIN &&
+ !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
+ Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
+ return SDValue();
+
+ auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
+ // Peek through conversion between fixed and scalable vectors.
+ if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
+ isNullConstant(Op.getOperand(2)) &&
+ Op.getOperand(1).getValueType().isFixedLengthVector() &&
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
+ isNullConstant(Op.getOperand(1).getOperand(1)))
+ Op = Op.getOperand(1).getOperand(0);
+
+ if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
+ return true;
+
+ if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
+ Op.getOperand(2) == VL) {
+ if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+ SplatVal =
+ Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
+ return true;
+ }
+ }
+
+ return false;
+ };
+
+ APInt C;
+ if (!IsSplat(Src.getOperand(1), C))
+ return SDValue();
+
+ if (!C.isMask(VT.getScalarSizeInBits()))
+ return SDValue();
+
+ SDLoc DL(N);
+ // Rounding mode here is arbitrary since we aren't shifting out any bits.
+ return DAG.getNode(
+ RISCVISD::VNCLIPU_VL, DL, VT,
+ {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+ DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
+ VL});
+}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16415,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
return SDValue();
- case RISCVISD::TRUNCATE_VECTOR_VL: {
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
- // This would be benefit for the cases where X and Y are both the same value
- // type of low precision vectors. Since the truncate would be lowered into
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
- // restriction, such pattern would be expanded into a series of "vsetvli"
- // and "vnsrl" instructions later to reach this point.
- auto IsTruncNode = [](SDValue V) {
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
- return false;
- SDValue VL = V.getOperand(2);
- auto *C = dyn_cast<ConstantSDNode>(VL);
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
- (isa<RegisterSDNode>(VL) &&
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
- IsVLMAXForVMSET;
- };
-
- SDValue Op = N->getOperand(0);
-
- // We need to first find the inner level of TRUNCATE_VECTOR_VL node
- // to distinguish such pattern.
- while (IsTruncNode(Op)) {
- if (!Op.hasOneUse())
- return SDValue();
- Op = Op.getOperand(0);
- }
-
- if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
- SDValue N0 = Op.getOperand(0);
- SDValue N1 = Op.getOperand(1);
- if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
- N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
- SDValue N00 = N0.getOperand(0);
- SDValue N10 = N1.getOperand(0);
- if (N00.getValueType().isVector() &&
- N00.getValueType() == N10.getValueType() &&
- N->getValueType(0) == N10.getValueType()) {
- unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
- SDValue SMin = DAG.getNode(
- ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
- DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
- return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
- }
- }
- }
- break;
- }
+ case RISCVISD::TRUNCATE_VECTOR_VL:
+ if (SDValue V = combineTruncOfSraSext(N, DAG))
+ return V;
+ return combineTruncToVnclipu(N, DAG, Subtarget);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
@@ -19972,6 +20037,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UADDSAT_VL)
NODE_NAME_CASE(SSUBSAT_VL)
NODE_NAME_CASE(USUBSAT_VL)
+ NODE_NAME_CASE(VNCLIP_VL)
+ NODE_NAME_CASE(VNCLIPU_VL)
NODE_NAME_CASE(FADD_VL)
NODE_NAME_CASE(FSUB_VL)
NODE_NAME_CASE(FMUL_VL)
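In scalar terms, the equivalence combineTruncToVnclipu relies on is that truncating a umin against the narrow type's all-ones value (the C.isMask check above) is exactly an unsigned saturating narrow, which vnclipu computes with a shift amount of 0. A minimal sketch of the i32->i16 case, our own illustration rather than code from this patch:

#include <algorithm>
#include <cstdint>

// trunc(umin(x, 0xFFFF)) keeps values <= 0xFFFF intact and clamps the rest,
// matching an unsigned saturating narrow (vnclipu with shift 0).
static uint16_t clipU32ToU16(uint32_t X) {
  return static_cast<uint16_t>(std::min<uint32_t>(X, 0xFFFF));
}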
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 856ce06ba1c4f..3b8eb3c88901a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
// Rounding averaging adds of unsigned integers.
AVGCEILU_VL,
+ // Operands are (source, shift, merge, mask, roundmode, vl)
+ VNCLIPU_VL,
+ VNCLIP_VL,
+
MULHS_VL,
MULHU_VL,
FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 66df24f2a458d..691f2052ab29d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 91f3abe22331e..610a72dd02b38 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
SDTCisSameNumEltsAs<0, 4>,
SDTCisVT<5, XLenVT>]>;
+// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
+def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
+ SDTCisSameAs<0, 3>,
+ SDTCisSameNumEltsAs<0, 1>,
+ SDTCisVec<1>,
+ SDTCisOpSmallerThanOp<2, 1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisVT<5, XLenVT>,
+ SDTCisVT<6, XLenVT>]>;
+
def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisVec<0>, SDTCisFP<0>,
SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
+def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
+
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
+multiclass VPatBinaryRM_VL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op1_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op1_reg_class,
+ VReg op2_reg_class> {
+ def : Pat<(result_type (vop
+ (op1_type op1_reg_class:$rs1),
+ (op2_type op2_reg_class:$rs2),
+ (result_type result_reg_class:$merge),
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ VLOpFrag)),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ op1_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
class VPatBinaryVL_V_RM<SDPatternOperator vop,
string instruction_name,
string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
xop_kind:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
+multiclass VPatBinaryRM_VL_XI<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType vop1_type,
+ ValueType vop2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg vop_reg_class,
+ ComplexPattern SplatPatKind,
+ DAGOperand xop_kind> {
+ def : Pat<(result_type (vop
+ (vop1_type vop_reg_class:$rs1),
+ (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+ (result_type result_reg_class:$merge),
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ VLOpFrag)),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ vop_reg_class:$rs1,
+ xop_kind:$rs2,
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
}
}
+multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
+ foreach VtiToWti = AllWidenableIntVectors in {
+ defvar vti = VtiToWti.Vti;
+ defvar wti = VtiToWti.Wti;
+ defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
+ defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
+ defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
+ !cast<ComplexPattern>(SplatPat#"_"#uimm5),
+ uimm5>;
+ }
+}
+
class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
+// 12.5. Vector Narrowing Fixed-Point Clip Instructions
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_umin_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}
LGTM w/minor comments.
auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
  // Peek through conversion between fixed and scalable vectors.
  if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
Isn't this just a nop? Why aren't we combining that out?
We do eventually, but we only revisit the UMIN after that since it's the immediate user. We don't revisit the truncate after that.
I think this is partially caused by the truncate and umin being legalized by LegalizeVectorOps and build_vector being legalized by LegalizeDAG.
      Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
    return SDValue();

  auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
This seems like a generally useful function; surely we have parts of this repeated elsewhere already?
    if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
      return true;

    if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
Should we add vmv_s_x when VL=1?
LGTM. Thanks!
Similar to #93596, this moves the signed vnclip patterns into DAG combine. This will allow us to support more than one level of truncate in a future patch.
I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions. First clipping to 65535, then
clipping to 255. Similar for signed vnclip.
This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VNCLIP(U).
This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will come as a follow up. A similar patch series
will be made for the signed vnclip.
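For illustration, the planned two-step narrowing corresponds to this scalar computation; a sketch of our own based on the description above, not code from the patch:

#include <algorithm>
#include <cstdint>

// i32 -> i8 with unsigned saturation as two clip steps, mirroring two
// vnclipu instructions: first clip to 65535 (i32 -> i16), then clip to
// 255 (i16 -> i8).
static uint8_t clipU32ToU8(uint32_t X) {
  uint16_t Mid = static_cast<uint16_t>(std::min<uint32_t>(X, 65535u));
  return static_cast<uint8_t>(std::min<uint16_t>(Mid, 255));
}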