Skip to content

Commit 3890616

Browse files
dnsampaio authored and smallp-o-p committed
[DAGCombiner] Add combine avg from shifts (llvm#113909)
This teaches dagcombiner to fold: `(asr (add nsw x, y), 1) -> (avgfloors x, y)` `(lsr (add nuw x, y), 1) -> (avgflooru x, y)` as well the combine them to a ceil variant: `(avgfloors (add nsw x, y), 1) -> (avgceils x, y)` `(avgflooru (add nuw x, y), 1) -> (avgceilu x, y)` iff valid for the target. Removes some of the ARM MVE patterns that are now dead code. It adds the avg opcodes to `IsQRMVEInstruction` as to preserve the immediate splatting as before.
1 parent f564be7 commit 3890616

File tree

5 files changed

+329
-76
lines changed

5 files changed

+329
-76
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ namespace {
401401
SDValue PromoteExtend(SDValue Op);
402402
bool PromoteLoad(SDValue Op);
403403

404+
SDValue foldShiftToAvg(SDNode *N);
405+
404406
SDValue combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
405407
SDValue RHS, SDValue True, SDValue False,
406408
ISD::CondCode CC);
@@ -5351,6 +5353,27 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
53515353
DAG.getNode(ISD::ADD, DL, VT, N0, DAG.getAllOnesConstant(DL, VT)));
53525354
}
53535355

5356+
// Fold avgfloor((add nw x,y), 1) -> avgceil(x,y)
5357+
// Fold avgfloor((add nw x,1), y) -> avgceil(x,y)
5358+
if ((Opcode == ISD::AVGFLOORU && hasOperation(ISD::AVGCEILU, VT)) ||
5359+
(Opcode == ISD::AVGFLOORS && hasOperation(ISD::AVGCEILS, VT))) {
5360+
SDValue Add;
5361+
if (sd_match(N,
5362+
m_c_BinOp(Opcode,
5363+
m_AllOf(m_Value(Add), m_Add(m_Value(X), m_Value(Y))),
5364+
m_One())) ||
5365+
sd_match(N, m_c_BinOp(Opcode,
5366+
m_AllOf(m_Value(Add), m_Add(m_Value(X), m_One())),
5367+
m_Value(Y)))) {
5368+
5369+
if (IsSigned && Add->getFlags().hasNoSignedWrap())
5370+
return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
5371+
5372+
if (!IsSigned && Add->getFlags().hasNoUnsignedWrap())
5373+
return DAG.getNode(ISD::AVGCEILU, DL, VT, X, Y);
5374+
}
5375+
}
5376+
53545377
return SDValue();
53555378
}
53565379

@@ -10629,6 +10652,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
1062910652
if (SDValue NarrowLoad = reduceLoadWidth(N))
1063010653
return NarrowLoad;
1063110654

10655+
if (SDValue AVG = foldShiftToAvg(N))
10656+
return AVG;
10657+
1063210658
return SDValue();
1063310659
}
1063410660

@@ -10883,6 +10909,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
1088310909
if (SDValue MULH = combineShiftToMULH(N, DL, DAG, TLI))
1088410910
return MULH;
1088510911

10912+
if (SDValue AVG = foldShiftToAvg(N))
10913+
return AVG;
10914+
1088610915
return SDValue();
1088710916
}
1088810917

@@ -11396,6 +11425,53 @@ static SDValue combineMinNumMaxNumImpl(const SDLoc &DL, EVT VT, SDValue LHS,
1139611425
}
1139711426
}
1139811427

11428+
SDValue DAGCombiner::foldShiftToAvg(SDNode *N) {
11429+
const unsigned Opcode = N->getOpcode();
11430+
11431+
// Convert (sr[al] (add n[su]w x, y)) -> (avgfloor[su] x, y)
11432+
if (Opcode != ISD::SRA && Opcode != ISD::SRL)
11433+
return SDValue();
11434+
11435+
unsigned FloorISD = 0;
11436+
auto VT = N->getValueType(0);
11437+
bool IsUnsigned = false;
11438+
11439+
// Decide wether signed or unsigned.
11440+
switch (Opcode) {
11441+
case ISD::SRA:
11442+
if (!hasOperation(ISD::AVGFLOORS, VT))
11443+
return SDValue();
11444+
FloorISD = ISD::AVGFLOORS;
11445+
break;
11446+
case ISD::SRL:
11447+
IsUnsigned = true;
11448+
if (!hasOperation(ISD::AVGFLOORU, VT))
11449+
return SDValue();
11450+
FloorISD = ISD::AVGFLOORU;
11451+
break;
11452+
default:
11453+
return SDValue();
11454+
}
11455+
11456+
// Captured values.
11457+
SDValue A, B, Add;
11458+
11459+
// Match floor average as it is common to both floor/ceil avgs.
11460+
if (!sd_match(N, m_BinOp(Opcode,
11461+
m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
11462+
m_One())))
11463+
return SDValue();
11464+
11465+
// Can't optimize adds that may wrap.
11466+
if (IsUnsigned && !Add->getFlags().hasNoUnsignedWrap())
11467+
return SDValue();
11468+
11469+
if (!IsUnsigned && !Add->getFlags().hasNoSignedWrap())
11470+
return SDValue();
11471+
11472+
return DAG.getNode(FloorISD, SDLoc(N), N->getValueType(0), {A, B});
11473+
}
11474+
1139911475
/// Generate Min/Max node
1140011476
SDValue DAGCombiner::combineMinNumMaxNum(const SDLoc &DL, EVT VT, SDValue LHS,
1140111477
SDValue RHS, SDValue True,

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7951,6 +7951,8 @@ static bool IsQRMVEInstruction(const SDNode *N, const SDNode *Op) {
79517951
case ISD::MUL:
79527952
case ISD::SADDSAT:
79537953
case ISD::UADDSAT:
7954+
case ISD::AVGFLOORS:
7955+
case ISD::AVGFLOORU:
79547956
return true;
79557957
case ISD::SUB:
79567958
case ISD::SSUBSAT:

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2222,64 +2222,6 @@ defm MVE_VRHADDu8 : MVE_VRHADD<MVE_v16u8, avgceilu>;
22222222
defm MVE_VRHADDu16 : MVE_VRHADD<MVE_v8u16, avgceilu>;
22232223
defm MVE_VRHADDu32 : MVE_VRHADD<MVE_v4u32, avgceilu>;
22242224

2225-
// Rounding Halving Add perform the arithemtic operation with an extra bit of
2226-
// precision, before performing the shift, to void clipping errors. We're not
2227-
// modelling that here with these patterns, but we're using no wrap forms of
2228-
// add to ensure that the extra bit of information is not needed for the
2229-
// arithmetic or the rounding.
2230-
let Predicates = [HasMVEInt] in {
2231-
def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2232-
(v16i8 (ARMvmovImm (i32 3585)))),
2233-
(i32 1))),
2234-
(MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
2235-
def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2236-
(v8i16 (ARMvmovImm (i32 2049)))),
2237-
(i32 1))),
2238-
(MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
2239-
def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2240-
(v4i32 (ARMvmovImm (i32 1)))),
2241-
(i32 1))),
2242-
(MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
2243-
def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2244-
(v16i8 (ARMvmovImm (i32 3585)))),
2245-
(i32 1))),
2246-
(MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
2247-
def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2248-
(v8i16 (ARMvmovImm (i32 2049)))),
2249-
(i32 1))),
2250-
(MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
2251-
def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2252-
(v4i32 (ARMvmovImm (i32 1)))),
2253-
(i32 1))),
2254-
(MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
2255-
2256-
def : Pat<(v16i8 (ARMvshrsImm (addnsw (addnsw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2257-
(v16i8 (ARMvdup (i32 1)))),
2258-
(i32 1))),
2259-
(MVE_VRHADDs8 MQPR:$Qm, MQPR:$Qn)>;
2260-
def : Pat<(v8i16 (ARMvshrsImm (addnsw (addnsw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2261-
(v8i16 (ARMvdup (i32 1)))),
2262-
(i32 1))),
2263-
(MVE_VRHADDs16 MQPR:$Qm, MQPR:$Qn)>;
2264-
def : Pat<(v4i32 (ARMvshrsImm (addnsw (addnsw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2265-
(v4i32 (ARMvdup (i32 1)))),
2266-
(i32 1))),
2267-
(MVE_VRHADDs32 MQPR:$Qm, MQPR:$Qn)>;
2268-
def : Pat<(v16i8 (ARMvshruImm (addnuw (addnuw (v16i8 MQPR:$Qm), (v16i8 MQPR:$Qn)),
2269-
(v16i8 (ARMvdup (i32 1)))),
2270-
(i32 1))),
2271-
(MVE_VRHADDu8 MQPR:$Qm, MQPR:$Qn)>;
2272-
def : Pat<(v8i16 (ARMvshruImm (addnuw (addnuw (v8i16 MQPR:$Qm), (v8i16 MQPR:$Qn)),
2273-
(v8i16 (ARMvdup (i32 1)))),
2274-
(i32 1))),
2275-
(MVE_VRHADDu16 MQPR:$Qm, MQPR:$Qn)>;
2276-
def : Pat<(v4i32 (ARMvshruImm (addnuw (addnuw (v4i32 MQPR:$Qm), (v4i32 MQPR:$Qn)),
2277-
(v4i32 (ARMvdup (i32 1)))),
2278-
(i32 1))),
2279-
(MVE_VRHADDu32 MQPR:$Qm, MQPR:$Qn)>;
2280-
}
2281-
2282-
22832225
class MVE_VHADDSUB<string iname, string suffix, bit U, bit subtract,
22842226
bits<2> size, list<dag> pattern=[]>
22852227
: MVE_int<iname, suffix, size, pattern> {
@@ -2303,8 +2245,7 @@ class MVE_VHSUB_<string suffix, bit U, bits<2> size,
23032245
: MVE_VHADDSUB<"vhsub", suffix, U, 0b1, size, pattern>;
23042246

23052247
multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
2306-
SDPatternOperator unpred_op, Intrinsic PredInt, PatFrag add_op,
2307-
SDNode shift_op> {
2248+
SDPatternOperator unpred_op, Intrinsic PredInt> {
23082249
def "" : MVE_VHADD_<VTI.Suffix, VTI.Unsigned, VTI.Size>;
23092250
defvar Inst = !cast<Instruction>(NAME);
23102251
defm : MVE_TwoOpPattern<VTI, Op, PredInt, (? (i32 VTI.Unsigned)), !cast<Instruction>(NAME)>;
@@ -2313,26 +2254,18 @@ multiclass MVE_VHADD_m<MVEVectorVTInfo VTI, SDNode Op,
23132254
// Unpredicated add-and-divide-by-two
23142255
def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn), (i32 VTI.Unsigned))),
23152256
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
2316-
2317-
def : Pat<(VTI.Vec (shift_op (add_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)), (i32 1))),
2318-
(Inst MQPR:$Qm, MQPR:$Qn)>;
23192257
}
23202258
}
23212259

2322-
multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op, PatFrag add_op, SDNode shift_op>
2323-
: MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated, add_op,
2324-
shift_op>;
2260+
multiclass MVE_VHADD<MVEVectorVTInfo VTI, SDNode Op>
2261+
: MVE_VHADD_m<VTI, Op, int_arm_mve_vhadd, int_arm_mve_hadd_predicated>;
23252262

2326-
// Halving add/sub perform the arithemtic operation with an extra bit of
2327-
// precision, before performing the shift, to void clipping errors. We're not
2328-
// modelling that here with these patterns, but we're using no wrap forms of
2329-
// add/sub to ensure that the extra bit of information is not needed.
2330-
defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors, addnsw, ARMvshrsImm>;
2331-
defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors, addnsw, ARMvshrsImm>;
2332-
defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors, addnsw, ARMvshrsImm>;
2333-
defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru, addnuw, ARMvshruImm>;
2334-
defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru, addnuw, ARMvshruImm>;
2335-
defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru, addnuw, ARMvshruImm>;
2263+
defm MVE_VHADDs8 : MVE_VHADD<MVE_v16s8, avgfloors>;
2264+
defm MVE_VHADDs16 : MVE_VHADD<MVE_v8s16, avgfloors>;
2265+
defm MVE_VHADDs32 : MVE_VHADD<MVE_v4s32, avgfloors>;
2266+
defm MVE_VHADDu8 : MVE_VHADD<MVE_v16u8, avgflooru>;
2267+
defm MVE_VHADDu16 : MVE_VHADD<MVE_v8u16, avgflooru>;
2268+
defm MVE_VHADDu32 : MVE_VHADD<MVE_v4u32, avgflooru>;
23362269

23372270
multiclass MVE_VHSUB_m<MVEVectorVTInfo VTI,
23382271
SDPatternOperator unpred_op, Intrinsic pred_int, PatFrag sub_op,

0 commit comments

Comments
 (0)