@@ -3009,9 +3009,9 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
3009
3009
return SDValue ();
3010
3010
}
3011
3011
3012
- const APInt *
3013
- SelectionDAG::getValidShiftAmountConstant (SDValue V,
3014
- const APInt &DemandedElts ) const {
3012
+ std::optional< uint64_t >
3013
+ SelectionDAG::getValidShiftAmount (SDValue V, const APInt &DemandedElts ,
3014
+ unsigned Depth ) const {
3015
3015
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3016
3016
V.getOpcode () == ISD::SRA) &&
3017
3017
" Unknown shift node" );
@@ -3020,91 +3020,113 @@ SelectionDAG::getValidShiftAmountConstant(SDValue V,
3020
3020
// Shifting more than the bitwidth is not valid.
3021
3021
const APInt &ShAmt = SA->getAPIntValue ();
3022
3022
if (ShAmt.ult (BitWidth))
3023
- return &ShAmt;
3023
+ return ShAmt.getZExtValue ();
3024
+ } else {
3025
+ // Use computeKnownBits to find a hidden constant (usually type legalized).
3026
+ // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3027
+ KnownBits KnownAmt =
3028
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3029
+ if (KnownAmt.isConstant () && KnownAmt.getConstant ().ult (BitWidth))
3030
+ return KnownAmt.getConstant ().getZExtValue ();
3024
3031
}
3025
- return nullptr ;
3032
+ return std::nullopt ;
3026
3033
}
3027
3034
3028
- const APInt *SelectionDAG::getValidShiftAmountConstant (SDValue V) const {
3035
+ std::optional<uint64_t >
3036
+ SelectionDAG::getValidShiftAmount (SDValue V, unsigned Depth) const {
3029
3037
EVT VT = V.getValueType ();
3030
3038
APInt DemandedElts = VT.isFixedLengthVector ()
3031
3039
? APInt::getAllOnes (VT.getVectorNumElements ())
3032
3040
: APInt (1 , 1 );
3033
- return getValidShiftAmountConstant (V, DemandedElts);
3041
+ return getValidShiftAmount (V, DemandedElts, Depth );
3034
3042
}
3035
3043
3036
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (
3037
- SDValue V, const APInt &DemandedElts) const {
3044
+ std::optional<uint64_t >
3045
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, const APInt &DemandedElts,
3046
+ unsigned Depth) const {
3038
3047
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3039
3048
V.getOpcode () == ISD::SRA) &&
3040
3049
" Unknown shift node" );
3041
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3042
- return ValidAmt;
3043
3050
unsigned BitWidth = V.getScalarValueSizeInBits ();
3044
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3045
- if (!BV)
3046
- return nullptr ;
3047
- const APInt *MinShAmt = nullptr ;
3048
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3049
- if (!DemandedElts[i])
3050
- continue ;
3051
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3052
- if (!SA)
3053
- return nullptr ;
3054
- // Shifting more than the bitwidth is not valid.
3055
- const APInt &ShAmt = SA->getAPIntValue ();
3056
- if (ShAmt.uge (BitWidth))
3057
- return nullptr ;
3058
- if (MinShAmt && MinShAmt->ule (ShAmt))
3059
- continue ;
3060
- MinShAmt = &ShAmt;
3051
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3052
+ const APInt *MinShAmt = nullptr ;
3053
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3054
+ if (!DemandedElts[i])
3055
+ continue ;
3056
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3057
+ if (!SA) {
3058
+ MinShAmt = nullptr ;
3059
+ break ;
3060
+ }
3061
+ // Shifting more than the bitwidth is not valid.
3062
+ const APInt &ShAmt = SA->getAPIntValue ();
3063
+ if (ShAmt.uge (BitWidth))
3064
+ return std::nullopt;
3065
+ if (MinShAmt && MinShAmt->ule (ShAmt))
3066
+ continue ;
3067
+ MinShAmt = &ShAmt;
3068
+ }
3069
+ if (MinShAmt)
3070
+ return MinShAmt->getZExtValue ();
3061
3071
}
3062
- return MinShAmt;
3072
+ KnownBits KnownAmt =
3073
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3074
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3075
+ return KnownAmt.getMinValue ().getZExtValue ();
3076
+ return std::nullopt;
3063
3077
}
3064
3078
3065
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (SDValue V) const {
3079
+ std::optional<uint64_t >
3080
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, unsigned Depth) const {
3066
3081
EVT VT = V.getValueType ();
3067
3082
APInt DemandedElts = VT.isFixedLengthVector ()
3068
3083
? APInt::getAllOnes (VT.getVectorNumElements ())
3069
3084
: APInt (1 , 1 );
3070
- return getValidMinimumShiftAmountConstant (V, DemandedElts);
3085
+ return getValidMinimumShiftAmount (V, DemandedElts, Depth );
3071
3086
}
3072
3087
3073
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (
3074
- SDValue V, const APInt &DemandedElts) const {
3088
+ std::optional<uint64_t >
3089
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, const APInt &DemandedElts,
3090
+ unsigned Depth) const {
3075
3091
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3076
3092
V.getOpcode () == ISD::SRA) &&
3077
3093
" Unknown shift node" );
3078
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3079
- return ValidAmt;
3080
3094
unsigned BitWidth = V.getScalarValueSizeInBits ();
3081
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3082
- if (!BV)
3083
- return nullptr ;
3084
- const APInt *MaxShAmt = nullptr ;
3085
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3086
- if (!DemandedElts[i])
3087
- continue ;
3088
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3089
- if (!SA)
3090
- return nullptr ;
3091
- // Shifting more than the bitwidth is not valid.
3092
- const APInt &ShAmt = SA->getAPIntValue ();
3093
- if (ShAmt.uge (BitWidth))
3094
- return nullptr ;
3095
- if (MaxShAmt && MaxShAmt->uge (ShAmt))
3096
- continue ;
3097
- MaxShAmt = &ShAmt;
3095
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3096
+ const APInt *MaxShAmt = nullptr ;
3097
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3098
+ if (!DemandedElts[i])
3099
+ continue ;
3100
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3101
+ if (!SA) {
3102
+ MaxShAmt = nullptr ;
3103
+ break ;
3104
+ }
3105
+ // Shifting more than the bitwidth is not valid.
3106
+ const APInt &ShAmt = SA->getAPIntValue ();
3107
+ if (ShAmt.uge (BitWidth))
3108
+ return std::nullopt;
3109
+ if (MaxShAmt && MaxShAmt->uge (ShAmt))
3110
+ continue ;
3111
+ MaxShAmt = &ShAmt;
3112
+ }
3113
+ if (MaxShAmt)
3114
+ return MaxShAmt->getZExtValue ();
3098
3115
}
3099
- return MaxShAmt;
3116
+ KnownBits KnownAmt =
3117
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3118
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3119
+ return KnownAmt.getMaxValue ().getZExtValue ();
3120
+ return std::nullopt;
3100
3121
}
3101
3122
3102
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (SDValue V) const {
3123
+ std::optional<uint64_t >
3124
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, unsigned Depth) const {
3103
3125
EVT VT = V.getValueType ();
3104
3126
APInt DemandedElts = VT.isFixedLengthVector ()
3105
3127
? APInt::getAllOnes (VT.getVectorNumElements ())
3106
3128
: APInt (1 , 1 );
3107
- return getValidMaximumShiftAmountConstant (V, DemandedElts);
3129
+ return getValidMaximumShiftAmount (V, DemandedElts, Depth );
3108
3130
}
3109
3131
3110
3132
// / Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3591,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3569
3591
Known = KnownBits::shl (Known, Known2, NUW, NSW, ShAmtNonZero);
3570
3592
3571
3593
// Minimum shift low bits are known zero.
3572
- if (const APInt * ShMinAmt =
3573
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3574
- Known.Zero .setLowBits (ShMinAmt-> getZExtValue () );
3594
+ if (std::optional< uint64_t > ShMinAmt =
3595
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
3596
+ Known.Zero .setLowBits (* ShMinAmt);
3575
3597
break ;
3576
3598
}
3577
3599
case ISD::SRL:
@@ -3581,9 +3603,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3581
3603
Op->getFlags ().hasExact ());
3582
3604
3583
3605
// Minimum shift high bits are known zero.
3584
- if (const APInt * ShMinAmt =
3585
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3586
- Known.Zero .setHighBits (ShMinAmt-> getZExtValue () );
3606
+ if (std::optional< uint64_t > ShMinAmt =
3607
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
3608
+ Known.Zero .setHighBits (* ShMinAmt);
3587
3609
break ;
3588
3610
case ISD::SRA:
3589
3611
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
@@ -4587,17 +4609,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
4587
4609
case ISD::SRA:
4588
4610
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4589
4611
// SRA X, C -> adds C sign bits.
4590
- if (const APInt * ShAmt =
4591
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
4592
- Tmp = std::min<uint64_t >(Tmp + ShAmt-> getZExtValue () , VTBits);
4612
+ if (std::optional< uint64_t > ShAmt =
4613
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth ))
4614
+ Tmp = std::min<uint64_t >(Tmp + * ShAmt, VTBits);
4593
4615
return Tmp;
4594
4616
case ISD::SHL:
4595
- if (const APInt * ShAmt =
4596
- getValidMaximumShiftAmountConstant (Op, DemandedElts)) {
4617
+ if (std::optional< uint64_t > ShAmt =
4618
+ getValidMaximumShiftAmount (Op, DemandedElts, Depth )) {
4597
4619
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
4598
4620
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4599
- if (ShAmt-> ult ( Tmp) )
4600
- return Tmp - ShAmt-> getZExtValue () ;
4621
+ if (* ShAmt < Tmp)
4622
+ return Tmp - * ShAmt;
4601
4623
}
4602
4624
break ;
4603
4625
case ISD::AND:
@@ -5270,7 +5292,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
5270
5292
case ISD::SRL:
5271
5293
case ISD::SRA:
5272
5294
// If the max shift amount isn't in range, then the shift can create poison.
5273
- return !getValidMaximumShiftAmountConstant (Op, DemandedElts);
5295
+ return !getValidMaximumShiftAmount (Op, DemandedElts, Depth );
5274
5296
5275
5297
case ISD::SCALAR_TO_VECTOR:
5276
5298
// Check if we demand any upper (undef) elements.
0 commit comments