48
48
#include " llvm/CodeGen/ValueTypes.h"
49
49
#include " llvm/CodeGenTypes/MachineValueType.h"
50
50
#include " llvm/IR/Constant.h"
51
- #include " llvm/IR/ConstantRange.h"
52
51
#include " llvm/IR/Constants.h"
53
52
#include " llvm/IR/DataLayout.h"
54
53
#include " llvm/IR/DebugInfoMetadata.h"
@@ -3009,102 +3008,117 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
3009
3008
return SDValue ();
3010
3009
}
3011
3010
3012
- const APInt *
3013
- SelectionDAG::getValidShiftAmountConstant (SDValue V,
3014
- const APInt &DemandedElts ) const {
3011
+ std::optional<ConstantRange>
3012
+ SelectionDAG::getValidShiftAmountRange (SDValue V, const APInt &DemandedElts ,
3013
+ unsigned Depth ) const {
3015
3014
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3016
3015
V.getOpcode () == ISD::SRA) &&
3017
3016
" Unknown shift node" );
3017
+ // Shifting more than the bitwidth is not valid.
3018
3018
unsigned BitWidth = V.getScalarValueSizeInBits ();
3019
- if (ConstantSDNode *SA = isConstOrConstSplat (V.getOperand (1 ), DemandedElts)) {
3020
- // Shifting more than the bitwidth is not valid.
3021
- const APInt &ShAmt = SA->getAPIntValue ();
3022
- if (ShAmt.ult (BitWidth))
3023
- return &ShAmt;
3019
+
3020
+ if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand (1 ))) {
3021
+ const APInt &ShAmt = Cst->getAPIntValue ();
3022
+ if (ShAmt.uge (BitWidth))
3023
+ return std::nullopt;
3024
+ return ConstantRange (ShAmt);
3024
3025
}
3025
- return nullptr ;
3026
+
3027
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3028
+ const APInt *MinAmt = nullptr , *MaxAmt = nullptr ;
3029
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3030
+ if (!DemandedElts[i])
3031
+ continue ;
3032
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3033
+ if (!SA) {
3034
+ MinAmt = MaxAmt = nullptr ;
3035
+ break ;
3036
+ }
3037
+ const APInt &ShAmt = SA->getAPIntValue ();
3038
+ if (ShAmt.uge (BitWidth))
3039
+ return std::nullopt;
3040
+ if (!MinAmt || MinAmt->ugt (ShAmt))
3041
+ MinAmt = &ShAmt;
3042
+ if (!MaxAmt || MaxAmt->ult (ShAmt))
3043
+ MaxAmt = &ShAmt;
3044
+ }
3045
+ assert (((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
3046
+ " Failed to find matching min/max shift amounts" );
3047
+ if (MinAmt && MaxAmt)
3048
+ return ConstantRange (*MinAmt, *MaxAmt + 1 );
3049
+ }
3050
+
3051
+ // Use computeKnownBits to find a hidden constant/knownbits (usually type
3052
+ // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053
+ KnownBits KnownAmt = computeKnownBits (V.getOperand (1 ), DemandedElts, Depth);
3054
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3055
+ return ConstantRange::fromKnownBits (KnownAmt, /* IsSigned=*/ false );
3056
+
3057
+ return std::nullopt;
3026
3058
}
3027
3059
3028
- const APInt *SelectionDAG::getValidShiftAmountConstant (SDValue V) const {
3060
+ std::optional<uint64_t >
3061
+ SelectionDAG::getValidShiftAmount (SDValue V, const APInt &DemandedElts,
3062
+ unsigned Depth) const {
3063
+ assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3064
+ V.getOpcode () == ISD::SRA) &&
3065
+ " Unknown shift node" );
3066
+ if (std::optional<ConstantRange> AmtRange =
3067
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3068
+ if (const APInt *ShAmt = AmtRange->getSingleElement ())
3069
+ return ShAmt->getZExtValue ();
3070
+ return std::nullopt;
3071
+ }
3072
+
3073
+ std::optional<uint64_t >
3074
+ SelectionDAG::getValidShiftAmount (SDValue V, unsigned Depth) const {
3029
3075
EVT VT = V.getValueType ();
3030
3076
APInt DemandedElts = VT.isFixedLengthVector ()
3031
3077
? APInt::getAllOnes (VT.getVectorNumElements ())
3032
3078
: APInt (1 , 1 );
3033
- return getValidShiftAmountConstant (V, DemandedElts);
3079
+ return getValidShiftAmount (V, DemandedElts, Depth );
3034
3080
}
3035
3081
3036
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (
3037
- SDValue V, const APInt &DemandedElts) const {
3082
+ std::optional<uint64_t >
3083
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, const APInt &DemandedElts,
3084
+ unsigned Depth) const {
3038
3085
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3039
3086
V.getOpcode () == ISD::SRA) &&
3040
3087
" Unknown shift node" );
3041
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3042
- return ValidAmt;
3043
- unsigned BitWidth = V.getScalarValueSizeInBits ();
3044
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3045
- if (!BV)
3046
- return nullptr ;
3047
- const APInt *MinShAmt = nullptr ;
3048
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3049
- if (!DemandedElts[i])
3050
- continue ;
3051
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3052
- if (!SA)
3053
- return nullptr ;
3054
- // Shifting more than the bitwidth is not valid.
3055
- const APInt &ShAmt = SA->getAPIntValue ();
3056
- if (ShAmt.uge (BitWidth))
3057
- return nullptr ;
3058
- if (MinShAmt && MinShAmt->ule (ShAmt))
3059
- continue ;
3060
- MinShAmt = &ShAmt;
3061
- }
3062
- return MinShAmt;
3088
+ if (std::optional<ConstantRange> AmtRange =
3089
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3090
+ return AmtRange->getUnsignedMin ().getZExtValue ();
3091
+ return std::nullopt;
3063
3092
}
3064
3093
3065
- const APInt *SelectionDAG::getValidMinimumShiftAmountConstant (SDValue V) const {
3094
+ std::optional<uint64_t >
3095
+ SelectionDAG::getValidMinimumShiftAmount (SDValue V, unsigned Depth) const {
3066
3096
EVT VT = V.getValueType ();
3067
3097
APInt DemandedElts = VT.isFixedLengthVector ()
3068
3098
? APInt::getAllOnes (VT.getVectorNumElements ())
3069
3099
: APInt (1 , 1 );
3070
- return getValidMinimumShiftAmountConstant (V, DemandedElts);
3100
+ return getValidMinimumShiftAmount (V, DemandedElts, Depth );
3071
3101
}
3072
3102
3073
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (
3074
- SDValue V, const APInt &DemandedElts) const {
3103
+ std::optional<uint64_t >
3104
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, const APInt &DemandedElts,
3105
+ unsigned Depth) const {
3075
3106
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3076
3107
V.getOpcode () == ISD::SRA) &&
3077
3108
" Unknown shift node" );
3078
- if (const APInt *ValidAmt = getValidShiftAmountConstant (V, DemandedElts))
3079
- return ValidAmt;
3080
- unsigned BitWidth = V.getScalarValueSizeInBits ();
3081
- auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ));
3082
- if (!BV)
3083
- return nullptr ;
3084
- const APInt *MaxShAmt = nullptr ;
3085
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3086
- if (!DemandedElts[i])
3087
- continue ;
3088
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3089
- if (!SA)
3090
- return nullptr ;
3091
- // Shifting more than the bitwidth is not valid.
3092
- const APInt &ShAmt = SA->getAPIntValue ();
3093
- if (ShAmt.uge (BitWidth))
3094
- return nullptr ;
3095
- if (MaxShAmt && MaxShAmt->uge (ShAmt))
3096
- continue ;
3097
- MaxShAmt = &ShAmt;
3098
- }
3099
- return MaxShAmt;
3109
+ if (std::optional<ConstantRange> AmtRange =
3110
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3111
+ return AmtRange->getUnsignedMax ().getZExtValue ();
3112
+ return std::nullopt;
3100
3113
}
3101
3114
3102
- const APInt *SelectionDAG::getValidMaximumShiftAmountConstant (SDValue V) const {
3115
+ std::optional<uint64_t >
3116
+ SelectionDAG::getValidMaximumShiftAmount (SDValue V, unsigned Depth) const {
3103
3117
EVT VT = V.getValueType ();
3104
3118
APInt DemandedElts = VT.isFixedLengthVector ()
3105
3119
? APInt::getAllOnes (VT.getVectorNumElements ())
3106
3120
: APInt (1 , 1 );
3107
- return getValidMaximumShiftAmountConstant (V, DemandedElts);
3121
+ return getValidMaximumShiftAmount (V, DemandedElts, Depth );
3108
3122
}
3109
3123
3110
3124
// / Determine which bits of Op are known to be either zero or one and return
@@ -3569,9 +3583,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3569
3583
Known = KnownBits::shl (Known, Known2, NUW, NSW, ShAmtNonZero);
3570
3584
3571
3585
// Minimum shift low bits are known zero.
3572
- if (const APInt * ShMinAmt =
3573
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3574
- Known.Zero .setLowBits (ShMinAmt-> getZExtValue () );
3586
+ if (std::optional< uint64_t > ShMinAmt =
3587
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
3588
+ Known.Zero .setLowBits (* ShMinAmt);
3575
3589
break ;
3576
3590
}
3577
3591
case ISD::SRL:
@@ -3581,9 +3595,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
3581
3595
Op->getFlags ().hasExact ());
3582
3596
3583
3597
// Minimum shift high bits are known zero.
3584
- if (const APInt * ShMinAmt =
3585
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
3586
- Known.Zero .setHighBits (ShMinAmt-> getZExtValue () );
3598
+ if (std::optional< uint64_t > ShMinAmt =
3599
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
3600
+ Known.Zero .setHighBits (* ShMinAmt);
3587
3601
break ;
3588
3602
case ISD::SRA:
3589
3603
Known = computeKnownBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
@@ -4587,17 +4601,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
4587
4601
case ISD::SRA:
4588
4602
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4589
4603
// SRA X, C -> adds C sign bits.
4590
- if (const APInt * ShAmt =
4591
- getValidMinimumShiftAmountConstant (Op, DemandedElts))
4592
- Tmp = std::min<uint64_t >(Tmp + ShAmt-> getZExtValue () , VTBits);
4604
+ if (std::optional< uint64_t > ShAmt =
4605
+ getValidMinimumShiftAmount (Op, DemandedElts, Depth + 1 ))
4606
+ Tmp = std::min<uint64_t >(Tmp + * ShAmt, VTBits);
4593
4607
return Tmp;
4594
4608
case ISD::SHL:
4595
- if (const APInt * ShAmt =
4596
- getValidMaximumShiftAmountConstant (Op, DemandedElts)) {
4609
+ if (std::optional< uint64_t > ShAmt =
4610
+ getValidMaximumShiftAmount (Op, DemandedElts, Depth + 1 )) {
4597
4611
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
4598
4612
Tmp = ComputeNumSignBits (Op.getOperand (0 ), DemandedElts, Depth + 1 );
4599
- if (ShAmt-> ult ( Tmp) )
4600
- return Tmp - ShAmt-> getZExtValue () ;
4613
+ if (* ShAmt < Tmp)
4614
+ return Tmp - * ShAmt;
4601
4615
}
4602
4616
break ;
4603
4617
case ISD::AND:
@@ -5270,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
5270
5284
case ISD::SRL:
5271
5285
case ISD::SRA:
5272
5286
// If the max shift amount isn't in range, then the shift can create poison.
5273
- return !getValidMaximumShiftAmountConstant (Op, DemandedElts);
5287
+ return !getValidMaximumShiftAmount (Op, DemandedElts, Depth + 1 );
5274
5288
5275
5289
case ISD::SCALAR_TO_VECTOR:
5276
5290
// Check if we demand any upper (undef) elements.
0 commit comments