Skip to content

[AMDGPU] Rework dot4 signedness checks #68757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions llvm/include/llvm/CodeGen/ByteProvider.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@ template <typename ISelOp> class ByteProvider {
ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset)
: Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}

ByteProvider(std::optional<ISelOp> Src, int64_t DestOffset, int64_t SrcOffset,
std::optional<bool> IsSigned)
: Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset),
IsSigned(IsSigned) {}

// TODO -- use constraint in c++20
// Does this type correspond with an operation in selection DAG
template <typename T> class is_op {
Expand Down Expand Up @@ -66,9 +61,6 @@ template <typename ISelOp> class ByteProvider {
// DestOffset
int64_t SrcOffset = 0;

// Whether or not the path to this Src involved signed extensions
std::optional<bool> IsSigned;

ByteProvider() = default;

static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
Expand All @@ -78,14 +70,6 @@ template <typename ISelOp> class ByteProvider {
return ByteProvider(Val, ByteOffset, VectorOffset);
}

static ByteProvider getSrc(std::optional<ISelOp> Val, int64_t ByteOffset,
int64_t VectorOffset,
std::optional<bool> IsSigned) {
static_assert(is_op<ISelOp>().value,
"ByteProviders must contain an operation in selection DAG.");
return ByteProvider(Val, ByteOffset, VectorOffset, IsSigned);
}

static ByteProvider getConstantZero() {
return ByteProvider<ISelOp>(std::nullopt, 0, 0);
}
Expand Down
171 changes: 88 additions & 83 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10940,7 +10940,6 @@ SDValue SITargetLowering::performAndCombine(SDNode *N,
// performed.
static const std::optional<ByteProvider<SDValue>>
calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
std::optional<bool> IsSigned = std::nullopt,
unsigned Depth = 0) {
// We may need to recursively traverse a series of SRLs
if (Depth >= 6)
Expand All @@ -10952,16 +10951,12 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,

switch (Op->getOpcode()) {
case ISD::TRUNCATE: {
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
Depth + 1);
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
}

case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND_INREG: {
IsSigned = IsSigned.value_or(false) ||
Op->getOpcode() == ISD::SIGN_EXTEND ||
Op->getOpcode() == ISD::SIGN_EXTEND_INREG;
SDValue NarrowOp = Op->getOperand(0);
auto NarrowVT = NarrowOp.getValueType();
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG) {
Expand All @@ -10974,8 +10969,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,

if (SrcIndex >= NarrowByteWidth)
return std::nullopt;
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
Depth + 1);
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
}

case ISD::SRA:
Expand All @@ -10991,24 +10985,11 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,

SrcIndex += BitShift / 8;

return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
Depth + 1);
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
}

default: {
if (isa<AtomicSDNode>(Op) || Op->isMemIntrinsic()) {
// If this causes us to throw away signedness info, then fail.
if (IsSigned)
return std::nullopt;
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
}

if (auto L = dyn_cast<LoadSDNode>(Op))
if (L->getExtensionType() != ISD::NON_EXTLOAD)
IsSigned =
IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;

return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
}
}
llvm_unreachable("fully handled switch");
Expand All @@ -11022,8 +11003,7 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
// performed. \p StartingIndex is the originally requested byte of the Or
static const std::optional<ByteProvider<SDValue>>
calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
unsigned StartingIndex = 0,
std::optional<bool> IsSigned = std::nullopt) {
unsigned StartingIndex = 0) {
// Finding Src tree of RHS of or typically requires at least 1 additional
// depth
if (Depth > 6)
Expand All @@ -11038,11 +11018,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
switch (Op.getOpcode()) {
case ISD::OR: {
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
StartingIndex, IsSigned);
StartingIndex);
if (!RHS)
return std::nullopt;
auto LHS = calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
StartingIndex, IsSigned);
StartingIndex);
if (!LHS)
return std::nullopt;
// A well formed Or will have two ByteProviders for each byte, one of which
Expand Down Expand Up @@ -11073,7 +11053,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
return ByteProvider<SDValue>::getConstantZero();
}

return calculateSrcByte(Op->getOperand(0), StartingIndex, Index, IsSigned);
return calculateSrcByte(Op->getOperand(0), StartingIndex, Index);
}

case ISD::FSHR: {
Expand Down Expand Up @@ -11122,7 +11102,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
// the SRL is Index + ByteShift
return BytesProvided - ByteShift > Index
? calculateSrcByte(Op->getOperand(0), StartingIndex,
Index + ByteShift, IsSigned)
Index + ByteShift)
: ByteProvider<SDValue>::getConstantZero();
}

Expand All @@ -11143,7 +11123,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
return Index < ByteShift
? ByteProvider<SDValue>::getConstantZero()
: calculateByteProvider(Op.getOperand(0), Index - ByteShift,
Depth + 1, StartingIndex, IsSigned);
Depth + 1, StartingIndex);
}
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
Expand All @@ -11163,48 +11143,35 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
return std::nullopt;
uint64_t NarrowByteWidth = NarrowBitWidth / 8;

IsSigned =
Op->getOpcode() != ISD::ANY_EXTEND
? std::optional<bool>(IsSigned.value_or(false) ||
Op->getOpcode() == ISD::SIGN_EXTEND ||
Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
Op->getOpcode() == ISD::AssertSext)
: IsSigned;

if (Index >= NarrowByteWidth)
return Op.getOpcode() == ISD::ZERO_EXTEND
? std::optional<ByteProvider<SDValue>>(
ByteProvider<SDValue>::getConstantZero())
: std::nullopt;
return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex,
IsSigned);
return calculateByteProvider(NarrowOp, Index, Depth + 1, StartingIndex);
}

case ISD::TRUNCATE: {
uint64_t NarrowByteWidth = BitWidth / 8;

if (NarrowByteWidth >= Index) {
return calculateByteProvider(Op.getOperand(0), Index, Depth + 1,
StartingIndex, IsSigned);
StartingIndex);
}

return std::nullopt;
}

case ISD::CopyFromReg: {
if (BitWidth / 8 > Index)
return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
return calculateSrcByte(Op, StartingIndex, Index);

return std::nullopt;
}

case ISD::LOAD: {
auto L = cast<LoadSDNode>(Op.getNode());

// Only set IsSigned if the load is extended.
if (L->getExtensionType() != ISD::NON_EXTLOAD)
IsSigned =
IsSigned.value_or(false) || L->getExtensionType() == ISD::SEXTLOAD;
unsigned NarrowBitWidth = L->getMemoryVT().getSizeInBits();
if (NarrowBitWidth % 8 != 0)
return std::nullopt;
Expand All @@ -11221,15 +11188,15 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
}

if (NarrowByteWidth > Index) {
return calculateSrcByte(Op, StartingIndex, Index, IsSigned);
return calculateSrcByte(Op, StartingIndex, Index);
}

return std::nullopt;
}

case ISD::BSWAP:
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
Depth + 1, StartingIndex, IsSigned);
Depth + 1, StartingIndex);

case ISD::EXTRACT_VECTOR_ELT: {
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
Expand All @@ -11244,7 +11211,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
}

return calculateSrcByte(ScalarSize == 32 ? Op : Op.getOperand(0),
StartingIndex, Index, IsSigned);
StartingIndex, Index);
}

case AMDGPUISD::PERM: {
Expand All @@ -11260,10 +11227,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
auto NextOp = Op.getOperand(IdxMask > 0x03 ? 0 : 1);
auto NextIndex = IdxMask > 0x03 ? IdxMask % 4 : IdxMask;

return IdxMask != 0x0c
? calculateSrcByte(NextOp, StartingIndex, NextIndex, IsSigned)
: ByteProvider<SDValue>(
ByteProvider<SDValue>::getConstantZero());
return IdxMask != 0x0c ? calculateSrcByte(NextOp, StartingIndex, NextIndex)
: ByteProvider<SDValue>(
ByteProvider<SDValue>::getConstantZero());
}

default: {
Expand Down Expand Up @@ -13064,32 +13030,67 @@ static bool isMul(const SDValue Op) {
Opcode == AMDGPUISD::MUL_I24);
}

static std::optional<bool> checkSignedness(const SDValue &N,
ByteProvider<SDValue> &Src0,
ByteProvider<SDValue> &Src1) {
auto MulOpcode = N.getOpcode();
std::optional<bool> IterIsSigned;
// Both sides of the tree must have the same signedness semantics.
if ((Src0.IsSigned != Src1.IsSigned) ||
(Src0.IsSigned.value_or(false) != Src1.IsSigned.value_or(false)))
return IterIsSigned;
// If we have a MUL_U24 op with signed semantics, then fail.
if (Src0.IsSigned.value_or(false) && MulOpcode == AMDGPUISD::MUL_U24)
return IterIsSigned;
// If we have a MUL_I24 op with unsigned semantics, then fail.
if (!Src0.IsSigned.value_or(true) && MulOpcode == AMDGPUISD::MUL_I24)
return IterIsSigned;

bool TopLevelSignedness =
MulOpcode == AMDGPUISD::MUL_I24 ||
(MulOpcode == ISD::MUL && N.getNode()->getFlags().hasNoSignedWrap() &&
!N.getNode()->getFlags().hasNoUnsignedWrap());

// In cases where we are accumulating into an i8 (for v_dot4), the
// ByteProvider will not have signedness info since the MSBs are dont-cares.
// In this case, we simply use the TopLevelSignedness of the instruction.
IterIsSigned = Src0.IsSigned.value_or(TopLevelSignedness);
return IterIsSigned;
// Decide which signedness a dot4 must use for one multiply, based on what is
// known about the most significant bit of each multiply operand.
//
// \p S0Op and \p S1Op are the two operands being inspected; \p DAG is used to
// query their known bits. \p N, \p Src0 and \p Src1 are accepted but are not
// read anywhere in this function's body.
//
// Returns:
//   - false        : use the unsigned form (also returned when both operands
//                    are 8 bits wide, where signedness does not matter),
//   - true         : use the signed form,
//   - std::nullopt : the operands cannot be matched with a single signedness.
static std::optional<bool>
checkDot4MulSignedness(const SDValue &N, ByteProvider<SDValue> &Src0,
ByteProvider<SDValue> &Src1, const SDValue &S0Op,
const SDValue &S1Op, const SelectionDAG &DAG) {
// If both ops are i8s (pre legalize-dag), then the signedness semantics
// of the dot4 are irrelevant.
if (S0Op.getValueSizeInBits() == 8 && S1Op.getValueSizeInBits() == 8)
return false;

// For each operand, "IsUnsigned" means at least one leading bit is known to
// be zero, and "IsSigned" means at least one leading bit is known to be one.
// Both may be false when nothing is known about the top bit.
auto Known0 = DAG.computeKnownBits(S0Op, 0);
bool S0IsUnsigned = Known0.countMinLeadingZeros() > 0;
bool S0IsSigned = Known0.countMinLeadingOnes() > 0;
auto Known1 = DAG.computeKnownBits(S1Op, 0);
bool S1IsUnsigned = Known1.countMinLeadingZeros() > 0;
bool S1IsSigned = Known1.countMinLeadingOnes() > 0;

// The top bit cannot be known-zero and known-one at the same time.
assert(!(S0IsUnsigned && S0IsSigned));
assert(!(S1IsUnsigned && S1IsSigned));

// There are 9 possible permutations of
// {S0IsUnsigned, S0IsSigned, S1IsUnsigned, S1IsSigned}
// (each operand is one of: known unsigned, known signed, unknown).

// In two permutations, the sign bits are known to be the same for both Ops,
// so simply return Signed / Unsigned corresponding to the MSB.

if ((S0IsUnsigned && S1IsUnsigned) || (S0IsSigned && S1IsSigned))
return S0IsSigned;

// In another two permutations, the sign bits are known to be opposite. In
// this case return std::nullopt to indicate a bad match.

if ((S0IsUnsigned && S1IsSigned) || (S0IsSigned && S1IsUnsigned))
return std::nullopt;

// In the remaining five permutations, we don't know the value of the sign
// bit for at least one Op. Since we have a valid ByteProvider, we know that
// the upper bits must be extension bits. Thus, the only ways for the sign
// bit to be unknown are if it was sign extended from an unknown value, or if
// it was any extended. In either case, it is correct to use the signed
// version of the signedness semantics of dot4 for that Op; the checks below
// additionally reject the cases where the other Op is known unsigned.

// In two such permutations, we know the sign bit is set for
// one op, and the other is unknown. It is okay to use the signed version of
// dot4.
if ((S0IsSigned && !(S1IsSigned || S1IsUnsigned)) ||
((S1IsSigned && !(S0IsSigned || S0IsUnsigned))))
return true;

// In one such permutation, we don't know either of the sign bits. It is okay
// to use the signed version of dot4.
if ((!(S1IsSigned || S1IsUnsigned) && !(S0IsSigned || S0IsUnsigned)))
return true;

// In two such permutations, we know the sign bit is unset for
// one op, and the other is unknown. Return std::nullopt to indicate a
// bad match.
if ((S0IsUnsigned && !(S1IsSigned || S1IsUnsigned)) ||
((S1IsUnsigned && !(S0IsSigned || S0IsUnsigned))))
return std::nullopt;

// All nine permutations are handled above, so control never reaches here.
llvm_unreachable("Fully covered condition");
}

SDValue SITargetLowering::performAddCombine(SDNode *N,
Expand Down Expand Up @@ -13132,8 +13133,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
if (!Src1)
break;

auto IterIsSigned =
checkSignedness(TempNode->getOperand(MulIdx), *Src0, *Src1);
auto IterIsSigned = checkDot4MulSignedness(
TempNode->getOperand(MulIdx), *Src0, *Src1,
TempNode->getOperand(MulIdx)->getOperand(0),
TempNode->getOperand(MulIdx)->getOperand(1), DAG);
if (!IterIsSigned)
break;
if (!IsSigned)
Expand All @@ -13154,8 +13157,10 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
if (!Src1)
break;
auto IterIsSigned =
checkSignedness(TempNode->getOperand(AddIdx), *Src0, *Src1);
auto IterIsSigned = checkDot4MulSignedness(
TempNode->getOperand(AddIdx), *Src0, *Src1,
TempNode->getOperand(AddIdx)->getOperand(0),
TempNode->getOperand(AddIdx)->getOperand(1), DAG);
if (!IterIsSigned)
break;
assert(IsSigned);
Expand Down
Loading