Skip to content

Commit 3879271

Browse files
committed
[DAG] Add getValidShiftAmountRange to determine the range of valid shift amount values.
1 parent dc6b14f commit 3879271

File tree

2 files changed

+67
-67
lines changed

2 files changed

+67
-67
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/CodeGen/SelectionDAGNodes.h"
3333
#include "llvm/CodeGen/ValueTypes.h"
3434
#include "llvm/CodeGenTypes/MachineValueType.h"
35+
#include "llvm/IR/ConstantRange.h"
3536
#include "llvm/IR/DebugLoc.h"
3637
#include "llvm/IR/Metadata.h"
3738
#include "llvm/Support/Allocator.h"
@@ -2159,6 +2160,12 @@ class SelectionDAG {
21592160
/// splatted value it will return SDValue().
21602161
SDValue getSplatValue(SDValue V, bool LegalTypes = false);
21612162

2163+
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
2164+
/// element bit-width of the shift node, return the valid constant range.
2165+
std::optional<ConstantRange>
2166+
getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
2167+
unsigned Depth) const;
2168+
21622169
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
21632170
/// that is less than the element bit-width of the shift node, return it.
21642171
std::optional<uint64_t> getValidShiftAmount(SDValue V,

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 60 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
#include "llvm/CodeGen/ValueTypes.h"
4949
#include "llvm/CodeGenTypes/MachineValueType.h"
5050
#include "llvm/IR/Constant.h"
51-
#include "llvm/IR/ConstantRange.h"
5251
#include "llvm/IR/Constants.h"
5352
#include "llvm/IR/DataLayout.h"
5453
#include "llvm/IR/DebugInfoMetadata.h"
@@ -3009,26 +3008,66 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
30093008
return SDValue();
30103009
}
30113010

3011+
std::optional<ConstantRange>
3012+
SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
3013+
unsigned Depth) const {
3014+
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
3015+
V.getOpcode() == ISD::SRA) &&
3016+
"Unknown shift node");
3017+
// Shifting more than the bitwidth is not valid.
3018+
unsigned BitWidth = V.getScalarValueSizeInBits();
3019+
3020+
if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
3021+
const APInt &ShAmt = Cst->getAPIntValue();
3022+
if (ShAmt.uge(BitWidth))
3023+
return std::nullopt;
3024+
return ConstantRange(ShAmt);
3025+
}
3026+
3027+
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3028+
const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
3029+
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3030+
if (!DemandedElts[i])
3031+
continue;
3032+
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3033+
if (!SA) {
3034+
MinAmt = MaxAmt = nullptr;
3035+
break;
3036+
}
3037+
const APInt &ShAmt = SA->getAPIntValue();
3038+
if (ShAmt.uge(BitWidth))
3039+
return std::nullopt;
3040+
if (!MinAmt || MinAmt->ugt(ShAmt))
3041+
MinAmt = &ShAmt;
3042+
if (!MaxAmt || MaxAmt->ult(ShAmt))
3043+
MaxAmt = &ShAmt;
3044+
}
3045+
assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
3046+
"Failed to find matching min/max shift amounts");
3047+
if (MinAmt && MaxAmt)
3048+
return ConstantRange(*MinAmt, *MaxAmt);
3049+
}
3050+
3051+
// Use computeKnownBits to find a hidden constant/knownbits (usually type
3052+
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3053+
KnownBits KnownAmt =
3054+
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3055+
if (KnownAmt.getMaxValue().ult(BitWidth))
3056+
return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);
3057+
3058+
return std::nullopt;
3059+
}
3060+
30123061
std::optional<uint64_t>
30133062
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
30143063
unsigned Depth) const {
30153064
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30163065
V.getOpcode() == ISD::SRA) &&
30173066
"Unknown shift node");
3018-
unsigned BitWidth = V.getScalarValueSizeInBits();
3019-
if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
3020-
// Shifting more than the bitwidth is not valid.
3021-
const APInt &ShAmt = SA->getAPIntValue();
3022-
if (ShAmt.ult(BitWidth))
3023-
return ShAmt.getZExtValue();
3024-
} else {
3025-
// Use computeKnownBits to find a hidden constant (usually type legalized).
3026-
// e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
3027-
KnownBits KnownAmt =
3028-
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3029-
if (KnownAmt.isConstant() && KnownAmt.getConstant().ult(BitWidth))
3030-
return KnownAmt.getConstant().getZExtValue();
3031-
}
3067+
if (std::optional<ConstantRange> AmtRange =
3068+
getValidShiftAmountRange(V, DemandedElts, Depth))
3069+
if (const APInt *ShAmt = AmtRange->getSingleElement())
3070+
return ShAmt->getZExtValue();
30323071
return std::nullopt;
30333072
}
30343073

@@ -3047,32 +3086,9 @@ SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
30473086
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30483087
V.getOpcode() == ISD::SRA) &&
30493088
"Unknown shift node");
3050-
unsigned BitWidth = V.getScalarValueSizeInBits();
3051-
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3052-
const APInt *MinShAmt = nullptr;
3053-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3054-
if (!DemandedElts[i])
3055-
continue;
3056-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3057-
if (!SA) {
3058-
MinShAmt = nullptr;
3059-
break;
3060-
}
3061-
// Shifting more than the bitwidth is not valid.
3062-
const APInt &ShAmt = SA->getAPIntValue();
3063-
if (ShAmt.uge(BitWidth))
3064-
return std::nullopt;
3065-
if (MinShAmt && MinShAmt->ule(ShAmt))
3066-
continue;
3067-
MinShAmt = &ShAmt;
3068-
}
3069-
if (MinShAmt)
3070-
return MinShAmt->getZExtValue();
3071-
}
3072-
KnownBits KnownAmt =
3073-
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3074-
if (KnownAmt.getMaxValue().ult(BitWidth))
3075-
return KnownAmt.getMinValue().getZExtValue();
3089+
if (std::optional<ConstantRange> AmtRange =
3090+
getValidShiftAmountRange(V, DemandedElts, Depth))
3091+
return AmtRange->getUnsignedMin().getZExtValue();
30763092
return std::nullopt;
30773093
}
30783094

@@ -3091,32 +3107,9 @@ SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
30913107
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
30923108
V.getOpcode() == ISD::SRA) &&
30933109
"Unknown shift node");
3094-
unsigned BitWidth = V.getScalarValueSizeInBits();
3095-
if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
3096-
const APInt *MaxShAmt = nullptr;
3097-
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
3098-
if (!DemandedElts[i])
3099-
continue;
3100-
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
3101-
if (!SA) {
3102-
MaxShAmt = nullptr;
3103-
break;
3104-
}
3105-
// Shifting more than the bitwidth is not valid.
3106-
const APInt &ShAmt = SA->getAPIntValue();
3107-
if (ShAmt.uge(BitWidth))
3108-
return std::nullopt;
3109-
if (MaxShAmt && MaxShAmt->uge(ShAmt))
3110-
continue;
3111-
MaxShAmt = &ShAmt;
3112-
}
3113-
if (MaxShAmt)
3114-
return MaxShAmt->getZExtValue();
3115-
}
3116-
KnownBits KnownAmt =
3117-
computeKnownBits(V.getOperand(1), DemandedElts, Depth + 1);
3118-
if (KnownAmt.getMaxValue().ult(BitWidth))
3119-
return KnownAmt.getMaxValue().getZExtValue();
3110+
if (std::optional<ConstantRange> AmtRange =
3111+
getValidShiftAmountRange(V, DemandedElts, Depth))
3112+
return AmtRange->getUnsignedMax().getZExtValue();
31203113
return std::nullopt;
31213114
}
31223115

0 commit comments

Comments
 (0)