@@ -2529,6 +2529,43 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, SelectionDAG &DAG) {
2529
2529
return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
2530
2530
}
2531
2531
2532
+ // Attempt to form avgceilu(A, B) from sub(or(A, B), lshr(xor(A, B), 1))
2533
+ static SDValue combineFixedwidthToAVGCEILU(SDNode *N, SelectionDAG &DAG) {
2534
+ assert(N->getOpcode() == ISD::SUB and "SUB node is required here");
2535
+ SDValue Or = N->getOperand(0);
2536
+ SDValue Lshr = N->getOperand(1);
2537
+ if (Or.getOpcode() != ISD::OR or Lshr.getOpcode() != ISD::SRL)
2538
+ return SDValue();
2539
+ SDValue Xor = Lshr.getOperand(0);
2540
+ if (Xor.getOpcode() != ISD::XOR)
2541
+ return SDValue();
2542
+ SDValue Or1 = Or.getOperand(0);
2543
+ SDValue Or2 = Or.getOperand(1);
2544
+ SDValue Xor1 = Xor.getOperand(0);
2545
+ SDValue Xor2 = Xor.getOperand(1);
2546
+ if (Or1 == Xor2 and Or2 == Xor1) {
2547
+ SDValue temp = Or1;
2548
+ Or1 = Or2;
2549
+ Or2 = temp;
2550
+ } else if (Or1 != Xor1 or Or2 != Xor2)
2551
+ return SDValue();
2552
+ // Is the right shift using an immediate value of 1?
2553
+ ConstantSDNode *N1C = isConstOrConstSplat(Lshr.getOperand(1));
2554
+ if (!N1C or N1C->getAPIntValue() != 1)
2555
+ return SDValue();
2556
+ EVT VT = Or1.getValueType();
2557
+ EVT NVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits());
2558
+ if (VT.isVector())
2559
+ VT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
2560
+ else
2561
+ VT = NVT;
2562
+ SDLoc DL(N);
2563
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2564
+ if (!TLI.isOperationLegalOrCustom(ISD::AVGCEILU, VT))
2565
+ return SDValue();
2566
+ return DAG.getNode(ISD::AVGCEILU, DL, VT, Or1, Or2);
2567
+ }
2568
+
2532
2569
/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
2533
2570
/// a shift and add with a different constant.
2534
2571
static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
@@ -3859,6 +3896,10 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
3859
3896
if (SDValue V = foldAddSubOfSignBit(N, DAG))
3860
3897
return V;
3861
3898
3899
+ // Try to match AVGCEILU fixedwidth pattern
3900
+ if (SDValue V = combineFixedwidthToAVGCEILU(N, DAG))
3901
+ return V;
3902
+
3862
3903
if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
3863
3904
return V;
3864
3905
0 commit comments