@@ -47406,10 +47406,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
47406
47406
return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
47407
47407
}
47408
47408
47409
- // fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
47410
- // into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
47411
- // into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
47412
- // depending on sign of (SarConst - [56,48,32,24,16])
47409
+ // fold (SRA (SHL X, ShlConst), SraConst)
47410
+ // into (SHL (sext_in_reg X), ShlConst - SraConst)
47411
+ // or (sext_in_reg X)
47412
+ // or (SRA (sext_in_reg X), SraConst - ShlConst)
47413
+ // depending on relation between SraConst and ShlConst.
47414
+ // We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
47415
+ // us to do the sext_in_reg from corresponding bit.
47413
47416
47414
47417
// sexts in X86 are MOVs. The MOVs have the same code size
47415
47418
// as above SHIFTs (only SHIFT on 1 has lower code size).
@@ -47425,29 +47428,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
47425
47428
SDValue N00 = N0.getOperand(0);
47426
47429
SDValue N01 = N0.getOperand(1);
47427
47430
APInt ShlConst = N01->getAsAPIntVal();
47428
- APInt SarConst = N1->getAsAPIntVal();
47431
+ APInt SraConst = N1->getAsAPIntVal();
47429
47432
EVT CVT = N1.getValueType();
47430
47433
47431
- if (SarConst.isNegative())
47434
+ if (CVT != N01.getValueType())
47435
+ return SDValue();
47436
+ if (SraConst.isNegative())
47432
47437
return SDValue();
47433
47438
47434
47439
for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
47435
47440
unsigned ShiftSize = SVT.getSizeInBits();
47436
- // skipping types without corresponding sext/zext and
47437
- // ShlConst that is not one of [56,48,32,24,16]
47441
+ // Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
47438
47442
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
47439
47443
continue;
47440
47444
SDLoc DL(N);
47441
47445
SDValue NN =
47442
47446
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
47443
- SarConst = SarConst - (Size - ShiftSize);
47444
- if (SarConst == 0)
47447
+ if (SraConst.eq(ShlConst))
47445
47448
return NN;
47446
- if (SarConst.isNegative( ))
47449
+ if (SraConst.ult(ShlConst ))
47447
47450
return DAG.getNode(ISD::SHL, DL, VT, NN,
47448
- DAG.getConstant(-SarConst , DL, CVT));
47451
+ DAG.getConstant(ShlConst - SraConst , DL, CVT));
47449
47452
return DAG.getNode(ISD::SRA, DL, VT, NN,
47450
- DAG.getConstant(SarConst , DL, CVT));
47453
+ DAG.getConstant(SraConst - ShlConst , DL, CVT));
47451
47454
}
47452
47455
return SDValue();
47453
47456
}
0 commit comments