Skip to content

[DAG] Replace getValid*ShiftAmountConstant helpers with getValid*ShiftAmount helpers to support KnownBits analysis #93182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Metadata.h"
#include "llvm/Support/Allocator.h"
Expand Down Expand Up @@ -2159,36 +2160,44 @@ class SelectionDAG {
/// splatted value it will return SDValue().
SDValue getSplatValue(SDValue V, bool LegalTypes = false);

/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the valid constant range.
std::optional<ConstantRange>
getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
unsigned Depth) const;

/// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
const APInt *getValidShiftAmountConstant(SDValue V,
const APInt &DemandedElts) const;
std::optional<uint64_t> getValidShiftAmount(SDValue V,
const APInt &DemandedElts,
unsigned Depth = 0) const;

/// If a SHL/SRA/SRL node \p V has a constant or splat constant shift amount
/// If a SHL/SRA/SRL node \p V has a uniform shift amount
/// that is less than the element bit-width of the shift node, return it.
const APInt *getValidShiftAmountConstant(SDValue V) const;

/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
/// than the element bit-width of the shift node, return the minimum value.
const APInt *
getValidMinimumShiftAmountConstant(SDValue V,
const APInt &DemandedElts) const;

/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
/// than the element bit-width of the shift node, return the minimum value.
const APInt *
getValidMinimumShiftAmountConstant(SDValue V) const;

/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
/// than the element bit-width of the shift node, return the maximum value.
const APInt *
getValidMaximumShiftAmountConstant(SDValue V,
const APInt &DemandedElts) const;

/// If a SHL/SRA/SRL node \p V has constant shift amounts that are all less
/// than the element bit-width of the shift node, return the maximum value.
const APInt *
getValidMaximumShiftAmountConstant(SDValue V) const;
std::optional<uint64_t> getValidShiftAmount(SDValue V,
unsigned Depth = 0) const;

/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the minimum possible value.
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
const APInt &DemandedElts,
unsigned Depth = 0) const;

/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the minimum possible value.
std::optional<uint64_t> getValidMinimumShiftAmount(SDValue V,
unsigned Depth = 0) const;

/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the maximum possible value.
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
const APInt &DemandedElts,
unsigned Depth = 0) const;

/// If a SHL/SRA/SRL node \p V has shift amounts that are all less than the
/// element bit-width of the shift node, return the maximum possible value.
std::optional<uint64_t> getValidMaximumShiftAmount(SDValue V,
unsigned Depth = 0) const;

/// Match a binop + shuffle pyramid that represents a horizontal reduction
/// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
Expand Down
170 changes: 92 additions & 78 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
Expand Down Expand Up @@ -3009,102 +3008,117 @@ SDValue SelectionDAG::getSplatValue(SDValue V, bool LegalTypes) {
return SDValue();
}

const APInt *
SelectionDAG::getValidShiftAmountConstant(SDValue V,
const APInt &DemandedElts) const {
std::optional<ConstantRange>
SelectionDAG::getValidShiftAmountRange(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
// Shifting more than the bitwidth is not valid.
unsigned BitWidth = V.getScalarValueSizeInBits();
if (ConstantSDNode *SA = isConstOrConstSplat(V.getOperand(1), DemandedElts)) {
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.ult(BitWidth))
return &ShAmt;

if (auto *Cst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
const APInt &ShAmt = Cst->getAPIntValue();
if (ShAmt.uge(BitWidth))
return std::nullopt;
return ConstantRange(ShAmt);
}
return nullptr;

if (auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1))) {
const APInt *MinAmt = nullptr, *MaxAmt = nullptr;
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
if (!SA) {
MinAmt = MaxAmt = nullptr;
break;
}
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.uge(BitWidth))
return std::nullopt;
if (!MinAmt || MinAmt->ugt(ShAmt))
MinAmt = &ShAmt;
if (!MaxAmt || MaxAmt->ult(ShAmt))
MaxAmt = &ShAmt;
}
assert(((!MinAmt && !MaxAmt) || (MinAmt && MaxAmt)) &&
"Failed to find matching min/max shift amounts");
if (MinAmt && MaxAmt)
return ConstantRange(*MinAmt, *MaxAmt + 1);
}

// Use computeKnownBits to find a hidden constant/knownbits (usually type
// legalized). e.g. Hidden behind multiple bitcasts/build_vector/casts etc.
KnownBits KnownAmt = computeKnownBits(V.getOperand(1), DemandedElts, Depth);
if (KnownAmt.getMaxValue().ult(BitWidth))
return ConstantRange::fromKnownBits(KnownAmt, /*IsSigned=*/false);

return std::nullopt;
}

const APInt *SelectionDAG::getValidShiftAmountConstant(SDValue V) const {
std::optional<uint64_t>
SelectionDAG::getValidShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
if (std::optional<ConstantRange> AmtRange =
getValidShiftAmountRange(V, DemandedElts, Depth))
if (const APInt *ShAmt = AmtRange->getSingleElement())
return ShAmt->getZExtValue();
return std::nullopt;
}

std::optional<uint64_t>
SelectionDAG::getValidShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
return getValidShiftAmountConstant(V, DemandedElts);
return getValidShiftAmount(V, DemandedElts, Depth);
}

const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(
SDValue V, const APInt &DemandedElts) const {
std::optional<uint64_t>
SelectionDAG::getValidMinimumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
if (!BV)
return nullptr;
const APInt *MinShAmt = nullptr;
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
if (!SA)
return nullptr;
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.uge(BitWidth))
return nullptr;
if (MinShAmt && MinShAmt->ule(ShAmt))
continue;
MinShAmt = &ShAmt;
}
return MinShAmt;
if (std::optional<ConstantRange> AmtRange =
getValidShiftAmountRange(V, DemandedElts, Depth))
return AmtRange->getUnsignedMin().getZExtValue();
return std::nullopt;
}

const APInt *SelectionDAG::getValidMinimumShiftAmountConstant(SDValue V) const {
std::optional<uint64_t>
SelectionDAG::getValidMinimumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
return getValidMinimumShiftAmountConstant(V, DemandedElts);
return getValidMinimumShiftAmount(V, DemandedElts, Depth);
}

const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(
SDValue V, const APInt &DemandedElts) const {
std::optional<uint64_t>
SelectionDAG::getValidMaximumShiftAmount(SDValue V, const APInt &DemandedElts,
unsigned Depth) const {
assert((V.getOpcode() == ISD::SHL || V.getOpcode() == ISD::SRL ||
V.getOpcode() == ISD::SRA) &&
"Unknown shift node");
if (const APInt *ValidAmt = getValidShiftAmountConstant(V, DemandedElts))
return ValidAmt;
unsigned BitWidth = V.getScalarValueSizeInBits();
auto *BV = dyn_cast<BuildVectorSDNode>(V.getOperand(1));
if (!BV)
return nullptr;
const APInt *MaxShAmt = nullptr;
for (unsigned i = 0, e = BV->getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
auto *SA = dyn_cast<ConstantSDNode>(BV->getOperand(i));
if (!SA)
return nullptr;
// Shifting more than the bitwidth is not valid.
const APInt &ShAmt = SA->getAPIntValue();
if (ShAmt.uge(BitWidth))
return nullptr;
if (MaxShAmt && MaxShAmt->uge(ShAmt))
continue;
MaxShAmt = &ShAmt;
}
return MaxShAmt;
if (std::optional<ConstantRange> AmtRange =
getValidShiftAmountRange(V, DemandedElts, Depth))
return AmtRange->getUnsignedMax().getZExtValue();
return std::nullopt;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there should a single helper for all of the getValid*ShiftAmount functions.
Think it could just be getValidShiftAmount impl but return a KnownBits (the constant case we can just use KnownBits::makeConstant()).

Then getValidShiftAmount returns if the result is constant + less than bitwidth, getValidMinimumShiftAmount returns the getMinValue(), and vice versa for maximum. Think that will save a lot of dup code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me take a look, I'm wondering if ConstantRange will help us here. KnownBits alone tends to make the min/max values 'fuzzy'.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a bit of a mixed bag. KnownBits impl is more complete but has to throw away range information that isn't tied to particular bits. ConstantRange has a less complete impl, but obv represents what we want here better.

}

const APInt *SelectionDAG::getValidMaximumShiftAmountConstant(SDValue V) const {
std::optional<uint64_t>
SelectionDAG::getValidMaximumShiftAmount(SDValue V, unsigned Depth) const {
EVT VT = V.getValueType();
APInt DemandedElts = VT.isFixedLengthVector()
? APInt::getAllOnes(VT.getVectorNumElements())
: APInt(1, 1);
return getValidMaximumShiftAmountConstant(V, DemandedElts);
return getValidMaximumShiftAmount(V, DemandedElts, Depth);
}

/// Determine which bits of Op are known to be either zero or one and return
Expand Down Expand Up @@ -3569,9 +3583,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Known = KnownBits::shl(Known, Known2, NUW, NSW, ShAmtNonZero);

// Minimum shift low bits are known zero.
if (const APInt *ShMinAmt =
getValidMinimumShiftAmountConstant(Op, DemandedElts))
Known.Zero.setLowBits(ShMinAmt->getZExtValue());
if (std::optional<uint64_t> ShMinAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
Known.Zero.setLowBits(*ShMinAmt);
break;
}
case ISD::SRL:
Expand All @@ -3581,9 +3595,9 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
Op->getFlags().hasExact());

// Minimum shift high bits are known zero.
if (const APInt *ShMinAmt =
getValidMinimumShiftAmountConstant(Op, DemandedElts))
Known.Zero.setHighBits(ShMinAmt->getZExtValue());
if (std::optional<uint64_t> ShMinAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
Known.Zero.setHighBits(*ShMinAmt);
break;
case ISD::SRA:
Known = computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Expand Down Expand Up @@ -4587,17 +4601,17 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::SRA:
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
// SRA X, C -> adds C sign bits.
if (const APInt *ShAmt =
getValidMinimumShiftAmountConstant(Op, DemandedElts))
Tmp = std::min<uint64_t>(Tmp + ShAmt->getZExtValue(), VTBits);
if (std::optional<uint64_t> ShAmt =
getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1))
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
return Tmp;
case ISD::SHL:
if (const APInt *ShAmt =
getValidMaximumShiftAmountConstant(Op, DemandedElts)) {
if (std::optional<uint64_t> ShAmt =
getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
if (ShAmt->ult(Tmp))
return Tmp - ShAmt->getZExtValue();
if (*ShAmt < Tmp)
return Tmp - *ShAmt;
}
break;
case ISD::AND:
Expand Down Expand Up @@ -5270,7 +5284,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts,
case ISD::SRL:
case ISD::SRA:
// If the max shift amount isn't in range, then the shift can create poison.
return !getValidMaximumShiftAmountConstant(Op, DemandedElts);
return !getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1);

case ISD::SCALAR_TO_VECTOR:
// Check if we demand any upper (undef) elements.
Expand Down
Loading
Loading