@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
3009
3009
return SDValue ();
3010
3010
}
3011
3011
3012
- const APInt *
3013
- SelectionDAG::getValidShiftAmountConstant (SDValue V,
3014
- const APInt &DemandedElts ) const {
3012
+ std::optional< uint64_t >
3013
+ SelectionDAG::getValidShiftAmount (SDValue V, const APInt &DemandedElts ,
3014
+ unsigned Depth ) const {
3015
3015
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3016
3016
V.getOpcode () == ISD::SRA) &&
3017
3017
" Unknown shift node" );
@@ -3020,91 +3020,111 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
3020
3020
// Shifting more than the bitwidth is not valid.
3021
3021
const APInt &ShAmt = SA->getAPIntValue ();
3022
3022
if (ShAmt.ult (BitWidth))
3023
- return &ShAmt;
3023
+ return ShAmt.getZExtValue ();
3024
+ } else {
3025
+ KnownBits KnownAmt =
3026
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3027
+ if (KnownAmt.isConstant () && KnownAmt.getConstant ().ult (BitWidth))
3028
+ return KnownAmt.getConstant ().getZExtValue ();
3024
3029
}
3025
- return nullptr ;
3030
+ return std::nullopt ;
3026
3031
}
3027
3032
3028
- const APInt *SelectionDAG::getValidShiftAmountConstant (SDValue V) const {
3033
+ std::optional<uint64_t >
3034
+ SelectionDAG::getValidShiftAmount (SDValue V, unsigned Depth) const {
3029
3035
EVT VT = V.getValueType ();
3030
3036
APInt DemandedElts = VT.isFixedLengthVector ()
3031
3037
? APInt::getAllOnes (VT.getVectorNumElements ())
3032
3038
: APInt (1 , 1 );
3033
- return getValidShiftAmountConstant (V, DemandedElts);
3039
+ return getValidShiftAmount (V, DemandedElts, Depth );
3034
3040
}
3035
3041
3036
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (
3037
- SDValue V, const APInt &DemandedElts) const {
3042
+ std::optional<uint64_t >
3043
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, const APInt &DemandedElts,
3044
+ unsigned Depth) const {
3038
3045
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3039
3046
V.getOpcode () == ISD::SRA) &&
3040
3047
" Unknown shift node" );
3041
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3042
- return ValidAmt;
3043
3048
unsigned BitWidth = V.getScalarValueSizeInBits ();
3044
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3045
- if (!BV)
3046
- return nullptr ;
3047
- const APInt *MinShAmt = nullptr ;
3048
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3049
- if (!DemandedElts[i])
3050
- continue ;
3051
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3052
- if (!SA)
3053
- return nullptr ;
3054
- // Shifting more than the bitwidth is not valid.
3055
- const APInt &ShAmt = SA->getAPIntValue ();
3056
- if (ShAmt.uge (BitWidth))
3057
- return nullptr ;
3058
- if (MinShAmt && MinShAmt->ule (ShAmt))
3059
- continue ;
3060
- MinShAmt = &ShAmt;
3049
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3050
+ const APInt *MinShAmt = nullptr ;
3051
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3052
+ if (!DemandedElts[i])
3053
+ continue ;
3054
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3055
+ if (!SA) {
3056
+ MinShAmt = nullptr ;
3057
+ break ;
3058
+ }
3059
+ // Shifting more than the bitwidth is not valid.
3060
+ const APInt &ShAmt = SA->getAPIntValue ();
3061
+ if (ShAmt.uge (BitWidth))
3062
+ return std::nullopt;
3063
+ if (MinShAmt && MinShAmt->ule (ShAmt))
3064
+ continue ;
3065
+ MinShAmt = &ShAmt;
3066
+ }
3067
+ if (MinShAmt)
3068
+ return MinShAmt->getZExtValue ();
3061
3069
}
3062
- return MinShAmt;
3070
+ KnownBits KnownAmt =
3071
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3072
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3073
+ return KnownAmt.getMinValue ().getZExtValue ();
3074
+ return std::nullopt;
3063
3075
}
3064
3076
3065
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (SDValue V) const {
3077
+ std::optional<uint64_t >
3078
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, unsigned Depth) const {
3066
3079
EVT VT = V.getValueType ();
3067
3080
APInt DemandedElts = VT.isFixedLengthVector ()
3068
3081
? APInt::getAllOnes (VT.getVectorNumElements ())
3069
3082
: APInt (1 , 1 );
3070
- return getValidMinimumShiftAmountConstant (V, DemandedElts);
3083
+ return getValidMinimumShiftAmount (V, DemandedElts, Depth );
3071
3084
}
3072
3085
3073
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (
3074
- SDValue V, const APInt &DemandedElts) const {
3086
+ std::optional<uint64_t >
3087
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, const APInt &DemandedElts,
3088
+ unsigned Depth) const {
3075
3089
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3076
3090
V.getOpcode () == ISD::SRA) &&
3077
3091
" Unknown shift node" );
3078
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3079
- return ValidAmt;
3080
3092
unsigned BitWidth = V.getScalarValueSizeInBits ();
3081
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3082
- if (!BV)
3083
- return nullptr ;
3084
- const APInt *MaxShAmt = nullptr ;
3085
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3086
- if (!DemandedElts[i])
3087
- continue ;
3088
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3089
- if (!SA)
3090
- return nullptr ;
3091
- // Shifting more than the bitwidth is not valid.
3092
- const APInt &ShAmt = SA->getAPIntValue ();
3093
- if (ShAmt.uge (BitWidth))
3094
- return nullptr ;
3095
- if (MaxShAmt && MaxShAmt->uge (ShAmt))
3096
- continue ;
3097
- MaxShAmt = &ShAmt;
3093
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3094
+ const APInt *MaxShAmt = nullptr ;
3095
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3096
+ if (!DemandedElts[i])
3097
+ continue ;
3098
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3099
+ if (!SA) {
3100
+ MaxShAmt = nullptr ;
3101
+ break ;
3102
+ }
3103
+ // Shifting more than the bitwidth is not valid.
3104
+ const APInt &ShAmt = SA->getAPIntValue ();
3105
+ if (ShAmt.uge (BitWidth))
3106
+ return std::nullopt;
3107
+ if (MaxShAmt && MaxShAmt->uge (ShAmt))
3108
+ continue ;
3109
+ MaxShAmt = &ShAmt;
3110
+ }
3111
+ if (MaxShAmt)
3112
+ return MaxShAmt->getZExtValue ();
3098
3113
}
3099
- return MaxShAmt;
3114
+ KnownBits KnownAmt =
3115
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3116
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3117
+ return KnownAmt.getMaxValue ().getZExtValue ();
3118
+ return std::nullopt;
3100
3119
}
3101
3120
3102
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (SDValue V) const {
3121
+ std::optional<uint64_t >
3122
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, unsigned Depth) const {
3103
3123
EVT VT = V.getValueType ();
3104
3124
APInt DemandedElts = VT.isFixedLengthVector ()
3105
3125
? APInt::getAllOnes (VT.getVectorNumElements ())
3106
3126
: APInt (1 , 1 );
3107
- return getValidMaximumShiftAmountConstant (V, DemandedElts);
3127
+ return getValidMaximumShiftAmount (V, DemandedElts, Depth );
3108
3128
}
3109
3129
3110
3130
// / Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3589,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3569
3589
Known = KnownBits::shl (Known, Known2, NUW, NSW, ShAmtNonZero);
3570
3590
3571
3591
// Minimum shift low bits are known zero.
3572
- if (const APInt * ShMinAmt =
3573
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3574
- Known.Zero .setLowBits (ShMinAmt-> getZExtValue () );
3592
+ if (std::optional< uint64_t > ShMinAmt =
3593
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
3594
+ Known.Zero .setLowBits (* ShMinAmt);
3575
3595
break ;
3576
3596
}
3577
3597
case ISD::SRL:
@@ -3581,9 +3601,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3581
3601
Op->getFlags ().hasExact ());
3582
3602
3583
3603
// Minimum shift high bits are known zero.
3584
- if (const APInt * ShMinAmt =
3585
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3586
- Known.Zero .setHighBits (ShMinAmt-> getZExtValue () );
3604
+ if (std::optional< uint64_t > ShMinAmt =
3605
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
3606
+ Known.Zero .setHighBits (* ShMinAmt);
3587
3607
break ;
3588
3608
case ISD::SRA:
3589
3609
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
@@ -4587,17 +4607,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
4587
4607
case ISD::SRA:
4588
4608
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4589
4609
// SRA X, C -> adds C sign bits.
4590
- if (const APInt * ShAmt =
4591
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
4592
- Tmp = std::min<uint64_t >(Tmp + ShAmt-> getZExtValue () , VTBits);
4610
+ if (std::optional< uint64_t > ShAmt =
4611
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
4612
+ Tmp = std::min<uint64_t >(Tmp + * ShAmt, VTBits);
4593
4613
return Tmp;
4594
4614
case ISD::SHL:
4595
- if (const APInt * ShAmt =
4596
- getValidMaximumShiftAmountConstant (Op, DemandedElts)) {
4615
+ if (std::optional< uint64_t > ShAmt =
4616
+ getValidMaximumShiftAmount (Op, DemandedElts, Depth )) {
4597
4617
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
4598
4618
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4599
- if (ShAmt-> ult ( Tmp) )
4600
- return Tmp - ShAmt-> getZExtValue () ;
4619
+ if (* ShAmt < Tmp)
4620
+ return Tmp - * ShAmt;
4601
4621
}
4602
4622
break ;
4603
4623
case ISD::AND:
@@ -5270,7 +5290,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
5270
5290
case ISD::SRL:
5271
5291
case ISD::SRA:
5272
5292
// If the max shift amount isn't in range, then the shift can create poison.
5273
- return !getValidMaximumShiftAmountConstant (Op, DemandedElts);
5293
+ return !getValidMaximumShiftAmount (Op, DemandedElts, Depth );
5274
5294
5275
5295
case ISD::SCALAR_TO_VECTOR:
5276
5296
// Check if we demand any upper (undef) elements.
0 commit comments