Skip to content

Commit a79d13f

Browse files
authored
[RISCV][ISel] Use vaaddu with rounding mode rnu for ISD::AVGCEILU. (#77473)
Similar to #76550, but for `ISD::AVGCEILU`. Specifically, this patch uses `vaaddu` with rounding mode rnu (round-to-nearest-up, i.e., `vxrm[1:0] = 0b00`) to lower `ISD::AVGCEILU`.

### Source code

```
define <vscale x 8 x i8> @vaaddu_vv_nxv8i8_ceil(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y) {
  %xzv = zext <vscale x 8 x i8> %x to <vscale x 8 x i16>
  %yzv = zext <vscale x 8 x i8> %y to <vscale x 8 x i16>
  %add = add nuw nsw <vscale x 8 x i16> %xzv, %yzv
  %one = insertelement <vscale x 8 x i16> poison, i16 1, i32 0
  %splat = shufflevector <vscale x 8 x i16> %one, <vscale x 8 x i16> poison, <vscale x 8 x i32> zeroinitializer
  %add1 = add nuw nsw <vscale x 8 x i16> %add, %splat
  %div = lshr <vscale x 8 x i16> %add1, %splat
  %ret = trunc <vscale x 8 x i16> %div to <vscale x 8 x i8>
  ret <vscale x 8 x i8> %ret
}
```

### Before this patch

```
vaaddu_vv_nxv8i8_ceil:
  vsetvli a0, zero, e8, m1, ta, ma
  vwaddu.vv v10, v8, v9
  vsetvli zero, zero, e16, m2, ta, ma
  vadd.vi v10, v10, 1
  vsetvli zero, zero, e8, m1, ta, ma
  vnsrl.wi v8, v10, 1
  ret
```

### After this patch

```
vaaddu_vv_nxv8i8_ceil:
  vsetvli a0, zero, e8, m1, ta, ma
  csrwi vxrm, 0
  vaaddu.vv v8, v8, v9
  ret
```
1 parent 3593ade commit a79d13f

File tree

6 files changed

+604
-84
lines changed

6 files changed

+604
-84
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
814814
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
815815
Custom);
816816
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
817-
setOperationAction({ISD::AVGFLOORU, ISD::SADDSAT, ISD::UADDSAT,
818-
ISD::SSUBSAT, ISD::USUBSAT},
817+
setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
818+
ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
819819
VT, Legal);
820820

821821
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL"
@@ -1185,8 +1185,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11851185
if (VT.getVectorElementType() != MVT::i64 || Subtarget.hasStdExtV())
11861186
setOperationAction({ISD::MULHS, ISD::MULHU}, VT, Custom);
11871187

1188-
setOperationAction({ISD::AVGFLOORU, ISD::SADDSAT, ISD::UADDSAT,
1189-
ISD::SSUBSAT, ISD::USUBSAT},
1188+
setOperationAction({ISD::AVGFLOORU, ISD::AVGCEILU, ISD::SADDSAT,
1189+
ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT},
11901190
VT, Custom);
11911191

11921192
setOperationAction(ISD::VSELECT, VT, Custom);
@@ -5466,6 +5466,7 @@ static unsigned getRISCVVLOp(SDValue Op) {
54665466
OP_CASE(SSUBSAT)
54675467
OP_CASE(USUBSAT)
54685468
OP_CASE(AVGFLOORU)
5469+
OP_CASE(AVGCEILU)
54695470
OP_CASE(FADD)
54705471
OP_CASE(FSUB)
54715472
OP_CASE(FMUL)
@@ -5570,7 +5571,7 @@ static bool hasMergeOp(unsigned Opcode) {
55705571
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
55715572
"not a RISC-V target specific op");
55725573
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5573-
125 &&
5574+
126 &&
55745575
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
55755576
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
55765577
21 &&
@@ -5596,7 +5597,7 @@ static bool hasMaskOp(unsigned Opcode) {
55965597
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
55975598
"not a RISC-V target specific op");
55985599
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
5599-
125 &&
5600+
126 &&
56005601
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
56015602
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
56025603
21 &&
@@ -6461,6 +6462,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
64616462
return SplitVectorOp(Op, DAG);
64626463
[[fallthrough]];
64636464
case ISD::AVGFLOORU:
6465+
case ISD::AVGCEILU:
64646466
case ISD::SADDSAT:
64656467
case ISD::UADDSAT:
64666468
case ISD::SSUBSAT:
@@ -18595,6 +18597,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1859518597
NODE_NAME_CASE(UREM_VL)
1859618598
NODE_NAME_CASE(XOR_VL)
1859718599
NODE_NAME_CASE(AVGFLOORU_VL)
18600+
NODE_NAME_CASE(AVGCEILU_VL)
1859818601
NODE_NAME_CASE(SADDSAT_VL)
1859918602
NODE_NAME_CASE(UADDSAT_VL)
1860018603
NODE_NAME_CASE(SSUBSAT_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ enum NodeType : unsigned {
255255

256256
// Averaging adds of unsigned integers.
257257
AVGFLOORU_VL,
258+
// Rounding averaging adds of unsigned integers.
259+
AVGCEILU_VL,
258260

259261
MULHS_VL,
260262
MULHU_VL,

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,23 @@ multiclass VPatMultiplyAddSDNode_VV_VX<SDNode op, string instruction_name> {
877877
}
878878
}
879879

880+
multiclass VPatAVGADD_VV_VX_RM<SDNode vop, int vxrm> {
881+
foreach vti = AllIntegerVectors in {
882+
let Predicates = GetVTypePredicates<vti>.Predicates in {
883+
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
884+
(vti.Vector vti.RegClass:$rs2)),
885+
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX)
886+
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
887+
vxrm, vti.AVL, vti.Log2SEW, TA_MA)>;
888+
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
889+
(vti.Vector (SplatPat (XLenVT GPR:$rs2)))),
890+
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX)
891+
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
892+
vxrm, vti.AVL, vti.Log2SEW, TA_MA)>;
893+
}
894+
}
895+
}
896+
880897
//===----------------------------------------------------------------------===//
881898
// Patterns.
882899
//===----------------------------------------------------------------------===//
@@ -1132,20 +1149,8 @@ defm : VPatBinarySDNode_VV_VX<ssubsat, "PseudoVSSUB">;
11321149
defm : VPatBinarySDNode_VV_VX<usubsat, "PseudoVSSUBU">;
11331150

11341151
// 12.2. Vector Single-Width Averaging Add and Subtract
1135-
foreach vti = AllIntegerVectors in {
1136-
let Predicates = GetVTypePredicates<vti>.Predicates in {
1137-
def : Pat<(avgflooru (vti.Vector vti.RegClass:$rs1),
1138-
(vti.Vector vti.RegClass:$rs2)),
1139-
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX)
1140-
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs2,
1141-
0b10, vti.AVL, vti.Log2SEW, TA_MA)>;
1142-
def : Pat<(avgflooru (vti.Vector vti.RegClass:$rs1),
1143-
(vti.Vector (SplatPat (XLenVT GPR:$rs2)))),
1144-
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX)
1145-
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
1146-
0b10, vti.AVL, vti.Log2SEW, TA_MA)>;
1147-
}
1148-
}
1152+
defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10>;
1153+
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00>;
11491154

11501155
// 15. Vector Mask Instructions
11511156

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def riscv_cttz_vl : SDNode<"RISCVISD::CTTZ_VL", SDT_RISCVIntUnOp_VL>
112112
def riscv_ctpop_vl : SDNode<"RISCVISD::CTPOP_VL", SDT_RISCVIntUnOp_VL>;
113113

114114
def riscv_avgflooru_vl : SDNode<"RISCVISD::AVGFLOORU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
115+
def riscv_avgceilu_vl : SDNode<"RISCVISD::AVGCEILU_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
115116
def riscv_saddsat_vl : SDNode<"RISCVISD::SADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
116117
def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [SDNPCommutative]>;
117118
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
@@ -2031,6 +2032,25 @@ multiclass VPatSlide1VL_VF<SDNode vop, string instruction_name> {
20312032
}
20322033
}
20332034

2035+
multiclass VPatAVGADDVL_VV_VX_RM<SDNode vop, int vxrm> {
2036+
foreach vti = AllIntegerVectors in {
2037+
let Predicates = GetVTypePredicates<vti>.Predicates in {
2038+
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
2039+
(vti.Vector vti.RegClass:$rs2),
2040+
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2041+
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX#"_MASK")
2042+
vti.RegClass:$merge, vti.RegClass:$rs1, vti.RegClass:$rs2,
2043+
(vti.Mask V0), vxrm, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
2044+
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
2045+
(vti.Vector (SplatPat (XLenVT GPR:$rs2))),
2046+
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2047+
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX#"_MASK")
2048+
vti.RegClass:$merge, vti.RegClass:$rs1, GPR:$rs2,
2049+
(vti.Mask V0), vxrm, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
2050+
}
2051+
}
2052+
}
2053+
20342054
//===----------------------------------------------------------------------===//
20352055
// Patterns.
20362056
//===----------------------------------------------------------------------===//
@@ -2308,22 +2328,8 @@ defm : VPatBinaryVL_VV_VX<riscv_ssubsat_vl, "PseudoVSSUB">;
23082328
defm : VPatBinaryVL_VV_VX<riscv_usubsat_vl, "PseudoVSSUBU">;
23092329

23102330
// 12.2. Vector Single-Width Averaging Add and Subtract
2311-
foreach vti = AllIntegerVectors in {
2312-
let Predicates = GetVTypePredicates<vti>.Predicates in {
2313-
def : Pat<(riscv_avgflooru_vl (vti.Vector vti.RegClass:$rs1),
2314-
(vti.Vector vti.RegClass:$rs2),
2315-
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2316-
(!cast<Instruction>("PseudoVAADDU_VV_"#vti.LMul.MX#"_MASK")
2317-
vti.RegClass:$merge, vti.RegClass:$rs1, vti.RegClass:$rs2,
2318-
(vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
2319-
def : Pat<(riscv_avgflooru_vl (vti.Vector vti.RegClass:$rs1),
2320-
(vti.Vector (SplatPat (XLenVT GPR:$rs2))),
2321-
vti.RegClass:$merge, (vti.Mask V0), VLOpFrag),
2322-
(!cast<Instruction>("PseudoVAADDU_VX_"#vti.LMul.MX#"_MASK")
2323-
vti.RegClass:$merge, vti.RegClass:$rs1, GPR:$rs2,
2324-
(vti.Mask V0), 0b10, GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
2325-
}
2326-
}
2331+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10>;
2332+
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00>;
23272333

23282334
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
23292335
class VPatTruncSatClipMaxMinBase<string inst,

0 commit comments

Comments
 (0)