@@ -2411,9 +2411,10 @@ SDValue SelectionDAG::getSplatValue(SDValue V) {
2411
2411
2412
2412
// / If a SHL/SRA/SRL node has a constant or splat constant shift amount that
2413
2413
// / is less than the element bit-width of the shift node, return it.
2414
- static const APInt *getValidShiftAmountConstant (SDValue V) {
2414
+ static const APInt *getValidShiftAmountConstant (SDValue V,
2415
+ const APInt &DemandedElts) {
2415
2416
unsigned BitWidth = V.getScalarValueSizeInBits ();
2416
- if (ConstantSDNode *SA = isConstOrConstSplat (V.getOperand (1 ))) {
2417
+ if (ConstantSDNode *SA = isConstOrConstSplat (V.getOperand (1 ), DemandedElts )) {
2417
2418
// Shifting more than the bitwidth is not valid.
2418
2419
const APInt &ShAmt = SA->getAPIntValue ();
2419
2420
if (ShAmt.ult (BitWidth))
@@ -2424,13 +2425,16 @@ static const APInt *getValidShiftAmountConstant(SDValue V) {
2424
2425
2425
2426
// / If a SHL/SRA/SRL node has constant vector shift amounts that are all less
2426
2427
// / than the element bit-width of the shift node, return the minimum value.
2427
- static const APInt *getValidMinimumShiftAmountConstant (SDValue V) {
2428
+ static const APInt *
2429
+ getValidMinimumShiftAmountConstant (SDValue V, const APInt &DemandedElts) {
2428
2430
unsigned BitWidth = V.getScalarValueSizeInBits ();
2429
2431
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
2430
2432
if (!BV)
2431
2433
return nullptr ;
2432
2434
const APInt *MinShAmt = nullptr ;
2433
2435
for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
2436
+ if (!DemandedElts[i])
2437
+ continue ;
2434
2438
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
2435
2439
if (!SA)
2436
2440
return nullptr ;
@@ -2827,14 +2831,15 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
2827
2831
break ;
2828
2832
}
2829
2833
case ISD::SHL:
2830
- if (const APInt *ShAmt = getValidShiftAmountConstant (Op)) {
2834
+ if (const APInt *ShAmt = getValidShiftAmountConstant (Op, DemandedElts )) {
2831
2835
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
2832
2836
unsigned Shift = ShAmt->getZExtValue ();
2833
2837
Known.Zero <<= Shift;
2834
2838
Known.One <<= Shift;
2835
2839
// Low bits are known zero.
2836
2840
Known.Zero .setLowBits (Shift);
2837
- } else if (const APInt *ShMinAmt = getValidMinimumShiftAmountConstant (Op)) {
2841
+ } else if (const APInt *ShMinAmt =
2842
+ getValidMinimumShiftAmountConstant (Op, DemandedElts)) {
2838
2843
// Minimum shift low bits are known zero.
2839
2844
Known.Zero .setLowBits (ShMinAmt->getZExtValue ());
2840
2845
} else {
@@ -2846,14 +2851,15 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
2846
2851
}
2847
2852
break ;
2848
2853
case ISD::SRL:
2849
- if (const APInt *ShAmt = getValidShiftAmountConstant (Op)) {
2854
+ if (const APInt *ShAmt = getValidShiftAmountConstant (Op, DemandedElts )) {
2850
2855
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
2851
2856
unsigned Shift = ShAmt->getZExtValue ();
2852
2857
Known.Zero .lshrInPlace (Shift);
2853
2858
Known.One .lshrInPlace (Shift);
2854
2859
// High bits are known zero.
2855
2860
Known.Zero .setHighBits (Shift);
2856
- } else if (const APInt *ShMinAmt = getValidMinimumShiftAmountConstant (Op)) {
2861
+ } else if (const APInt *ShMinAmt =
2862
+ getValidMinimumShiftAmountConstant (Op, DemandedElts)) {
2857
2863
// Minimum shift high bits are known zero.
2858
2864
Known.Zero .setHighBits (ShMinAmt->getZExtValue ());
2859
2865
} else {
@@ -2864,7 +2870,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
2864
2870
}
2865
2871
break ;
2866
2872
case ISD::SRA:
2867
- if (const APInt *ShAmt = getValidShiftAmountConstant (Op)) {
2873
+ if (const APInt *ShAmt = getValidShiftAmountConstant (Op, DemandedElts )) {
2868
2874
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
2869
2875
unsigned Shift = ShAmt->getZExtValue ();
2870
2876
// Sign extend known zero/one bit (else is unknown).
0 commit comments