48
48
#include " llvm/CodeGen/ValueTypes.h"
49
49
#include " llvm/CodeGenTypes/MachineValueType.h"
50
50
#include " llvm/IR/Constant.h"
51
- #include " llvm/IR/ConstantRange.h"
52
51
#include " llvm/IR/Constants.h"
53
52
#include " llvm/IR/DataLayout.h"
54
53
#include " llvm/IR/DebugInfoMetadata.h"
@@ -3009,26 +3008,66 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
3009
3008
return SDValue ();
3010
3009
}
3011
3010
3011
+ std::optional<ConstantRange>
3012
+ SelectionDAG::getValidShiftAmountRange (SDValue V, const APInt &DemandedElts,
3013
+ unsigned Depth) const {
3014
+ assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3015
+ V.getOpcode () == ISD::SRA) &&
3016
+ " Unknown shift node" );
3017
+ // Shifting more than the bitwidth is not valid.
3018
+ unsigned BitWidth = V.getScalarValueSizeInBits ();
3019
+
3020
+ if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand (1 ))) {
3021
+ const APInt &ShAmt = Cst->getAPIntValue ();
3022
+ if (ShAmt.uge (BitWidth))
3023
+ return std::nullopt;
3024
+ return ConstantRange (ShAmt);
3025
+ }
3026
+
3027
+ if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3028
+ const APInt *MinAmt = nullptr , *MaxAmt = nullptr ;
3029
+ for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3030
+ if (!DemandedElts[i])
3031
+ continue ;
3032
+ auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3033
+ if (!SA) {
3034
+ MinAmt = MaxAmt = nullptr ;
3035
+ break ;
3036
+ }
3037
+ const APInt &ShAmt = SA->getAPIntValue ();
3038
+ if (ShAmt.uge (BitWidth))
3039
+ return std::nullopt;
3040
+ if (!MinAmt || MinAmt->ugt (ShAmt))
3041
+ MinAmt = &ShAmt;
3042
+ if (!MaxAmt || MaxAmt->ult (ShAmt))
3043
+ MaxAmt = &ShAmt;
3044
+ }
3045
+ assert (((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
3046
+ " Failed to find matching min/max shift amounts" );
3047
+ if (MinAmt && MaxAmt)
3048
+ return ConstantRange (*MinAmt, *MaxAmt);
3049
+ }
3050
+
3051
+ // Use computeKnownBits to find a hidden constant/knownbits (usually type
3052
+ // legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053
+ KnownBits KnownAmt =
3054
+ computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3055
+ if (KnownAmt.getMaxValue ().ult (BitWidth))
3056
+ return ConstantRange::fromKnownBits (KnownAmt, /* IsSigned=*/ false );
3057
+
3058
+ return std::nullopt;
3059
+ }
3060
+
3012
3061
std::optional<uint64_t >
3013
3062
SelectionDAG::getValidShiftAmount (SDValue V, const APInt &DemandedElts,
3014
3063
unsigned Depth) const {
3015
3064
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3016
3065
V.getOpcode () == ISD::SRA) &&
3017
3066
" Unknown shift node" );
3018
- unsigned BitWidth = V.getScalarValueSizeInBits ();
3019
- if (ConstantSDNode *SA = isConstOrConstSplat (V.getOperand (1 ), DemandedElts)) {
3020
- // Shifting more than the bitwidth is not valid.
3021
- const APInt &ShAmt = SA->getAPIntValue ();
3022
- if (ShAmt.ult (BitWidth))
3023
- return ShAmt.getZExtValue ();
3024
- } else {
3025
- // Use computeKnownBits to find a hidden constant (usually type legalized).
3026
- // e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3027
- KnownBits KnownAmt =
3028
- computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3029
- if (KnownAmt.isConstant () && KnownAmt.getConstant ().ult (BitWidth))
3030
- return KnownAmt.getConstant ().getZExtValue ();
3031
- }
3067
+ if (std::optional<ConstantRange> AmtRange =
3068
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3069
+ if (const APInt *ShAmt = AmtRange->getSingleElement ())
3070
+ return ShAmt->getZExtValue ();
3032
3071
return std::nullopt;
3033
3072
}
3034
3073
@@ -3047,32 +3086,9 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
3047
3086
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3048
3087
V.getOpcode () == ISD::SRA) &&
3049
3088
" Unknown shift node" );
3050
- unsigned BitWidth = V.getScalarValueSizeInBits ();
3051
- if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3052
- const APInt *MinShAmt = nullptr ;
3053
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3054
- if (!DemandedElts[i])
3055
- continue ;
3056
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3057
- if (!SA) {
3058
- MinShAmt = nullptr ;
3059
- break ;
3060
- }
3061
- // Shifting more than the bitwidth is not valid.
3062
- const APInt &ShAmt = SA->getAPIntValue ();
3063
- if (ShAmt.uge (BitWidth))
3064
- return std::nullopt;
3065
- if (MinShAmt && MinShAmt->ule (ShAmt))
3066
- continue ;
3067
- MinShAmt = &ShAmt;
3068
- }
3069
- if (MinShAmt)
3070
- return MinShAmt->getZExtValue ();
3071
- }
3072
- KnownBits KnownAmt =
3073
- computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3074
- if (KnownAmt.getMaxValue ().ult (BitWidth))
3075
- return KnownAmt.getMinValue ().getZExtValue ();
3089
+ if (std::optional<ConstantRange> AmtRange =
3090
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3091
+ return AmtRange->getUnsignedMin ().getZExtValue ();
3076
3092
return std::nullopt;
3077
3093
}
3078
3094
@@ -3091,32 +3107,9 @@ SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
3091
3107
assert ((V.getOpcode () == ISD::SHL || V.getOpcode () == ISD::SRL ||
3092
3108
V.getOpcode () == ISD::SRA) &&
3093
3109
" Unknown shift node" );
3094
- unsigned BitWidth = V.getScalarValueSizeInBits ();
3095
- if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand (1 ))) {
3096
- const APInt *MaxShAmt = nullptr ;
3097
- for (unsigned i = 0 , e = BV->getNumOperands (); i != e; ++i) {
3098
- if (!DemandedElts[i])
3099
- continue ;
3100
- auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand (i));
3101
- if (!SA) {
3102
- MaxShAmt = nullptr ;
3103
- break ;
3104
- }
3105
- // Shifting more than the bitwidth is not valid.
3106
- const APInt &ShAmt = SA->getAPIntValue ();
3107
- if (ShAmt.uge (BitWidth))
3108
- return std::nullopt;
3109
- if (MaxShAmt && MaxShAmt->uge (ShAmt))
3110
- continue ;
3111
- MaxShAmt = &ShAmt;
3112
- }
3113
- if (MaxShAmt)
3114
- return MaxShAmt->getZExtValue ();
3115
- }
3116
- KnownBits KnownAmt =
3117
- computeKnownBits (V.getOperand (1 ), DemandedElts, Depth + 1 );
3118
- if (KnownAmt.getMaxValue ().ult (BitWidth))
3119
- return KnownAmt.getMaxValue ().getZExtValue ();
3110
+ if (std::optional<ConstantRange> AmtRange =
3111
+ getValidShiftAmountRange (V, DemandedElts, Depth))
3112
+ return AmtRange->getUnsignedMax ().getZExtValue ();
3120
3113
return std::nullopt;
3121
3114
}
3122
3115
0 commit comments