Skip to content

Commit 7468bfd

Browse files
committed
[AMDGPU] Make getDWordFromOffset robust against exotic types + handle vectors in CalcByteProvider
Change-Id: I88775857394ac698e25ca1b89d7092d1dee50c33
1 parent 85a17a4 commit 7468bfd

File tree

2 files changed

+337
-40
lines changed

2 files changed

+337
-40
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 83 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -11570,6 +11570,9 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1157011570
if (Op.getValueSizeInBits() < 8)
1157111571
return std::nullopt;
1157211572

11573+
if (Op.getValueType().isVector())
11574+
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11575+
1157311576
switch (Op->getOpcode()) {
1157411577
case ISD::TRUNCATE: {
1157511578
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
@@ -11636,8 +11639,12 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1163611639
if (Index > BitWidth / 8 - 1)
1163711640
return std::nullopt;
1163811641

11642+
bool IsVec = Op.getValueType().isVector();
1163911643
switch (Op.getOpcode()) {
1164011644
case ISD::OR: {
11645+
if (IsVec)
11646+
return std::nullopt;
11647+
1164111648
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
1164211649
StartingIndex);
1164311650
if (!RHS)
@@ -11658,6 +11665,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1165811665
}
1165911666

1166011667
case ISD::AND: {
11668+
if (IsVec)
11669+
return std::nullopt;
11670+
1166111671
auto BitMaskOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1166211672
if (!BitMaskOp)
1166311673
return std::nullopt;
@@ -11678,6 +11688,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1167811688
}
1167911689

1168011690
case ISD::FSHR: {
11691+
if (IsVec)
11692+
return std::nullopt;
11693+
1168111694
// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
1168211695
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(2));
1168311696
if (!ShiftOp || Op.getValueType().isVector())
@@ -11703,6 +11716,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1170311716

1170411717
case ISD::SRA:
1170511718
case ISD::SRL: {
11719+
if (IsVec)
11720+
return std::nullopt;
11721+
1170611722
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1170711723
if (!ShiftOp)
1170811724
return std::nullopt;
@@ -11728,6 +11744,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1172811744
}
1172911745

1173011746
case ISD::SHL: {
11747+
if (IsVec)
11748+
return std::nullopt;
11749+
1173111750
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1173211751
if (!ShiftOp)
1173311752
return std::nullopt;
@@ -11752,6 +11771,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1175211771
case ISD::SIGN_EXTEND_INREG:
1175311772
case ISD::AssertZext:
1175411773
case ISD::AssertSext: {
11774+
if (IsVec)
11775+
return std::nullopt;
11776+
1175511777
SDValue NarrowOp = Op->getOperand(0);
1175611778
unsigned NarrowBitWidth = NarrowOp.getValueSizeInBits();
1175711779
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
@@ -11773,6 +11795,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1177311795
}
1177411796

1177511797
case ISD::TRUNCATE: {
11798+
if (IsVec)
11799+
return std::nullopt;
11800+
1177611801
uint64_t NarrowByteWidth = BitWidth / 8;
1177711802

1177811803
if (NarrowByteWidth >= Index) {
@@ -11815,9 +11840,13 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1181511840
return std::nullopt;
1181611841
}
1181711842

11818-
case ISD::BSWAP:
11843+
case ISD::BSWAP: {
11844+
if (IsVec)
11845+
return std::nullopt;
11846+
1181911847
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
1182011848
Depth + 1, StartingIndex);
11849+
}
1182111850

1182211851
case ISD::EXTRACT_VECTOR_ELT: {
1182311852
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11834,6 +11863,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1183411863
}
1183511864

1183611865
case AMDGPUISD::PERM: {
11866+
if (IsVec)
11867+
return std::nullopt;
11868+
1183711869
auto PermMask = dyn_cast<ConstantSDNode>(Op->getOperand(2));
1183811870
if (!PermMask)
1183911871
return std::nullopt;
@@ -11930,25 +11962,55 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1193011962
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
1193111963
unsigned DWordOffset) {
1193211964
SDValue Ret;
11933-
if (Src.getValueSizeInBits() <= 32)
11934-
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
1193511965

11936-
if (Src.getValueSizeInBits() >= 256) {
11937-
assert(!(Src.getValueSizeInBits() % 32));
11938-
Ret = DAG.getBitcast(
11939-
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11940-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11941-
DAG.getConstant(DWordOffset, SL, MVT::i32));
11942-
}
11966+
auto TypeSize = Src.getValueSizeInBits().getFixedValue();
11967+
// ByteProvider must be at least 8 bits
11968+
assert(Src.getValueSizeInBits().isKnownMultipleOf(8));
1194311969

11944-
Ret = DAG.getBitcastedAnyExtOrTrunc(
11945-
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11946-
if (DWordOffset) {
11947-
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11948-
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11949-
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11950-
}
11970+
if (TypeSize <= 32)
11971+
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
1195111972

11973+
if (Src.getValueType().isVector()) {
11974+
auto ScalarTySize = Src.getScalarValueSizeInBits();
11975+
auto ScalarTy = Src.getValueType().getScalarType();
11976+
if (ScalarTySize == 32) {
11977+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Src,
11978+
DAG.getConstant(DWordOffset, SL, MVT::i32));
11979+
}
11980+
if (ScalarTySize > 32) {
11981+
Ret = DAG.getNode(
11982+
ISD::EXTRACT_VECTOR_ELT, SL, ScalarTy, Src,
11983+
DAG.getConstant(DWordOffset / (ScalarTySize / 32), SL, MVT::i32));
11984+
auto ShiftVal = 32 * (DWordOffset % (ScalarTySize / 32));
11985+
if (ShiftVal)
11986+
Ret = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11987+
DAG.getConstant(ShiftVal, SL, MVT::i32));
11988+
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11989+
}
11990+
11991+
assert(ScalarTySize < 32);
11992+
auto NumElements = TypeSize / ScalarTySize;
11993+
auto Trunc32Elements = (ScalarTySize * NumElements) / 32;
11994+
auto NormalizedTrunc = Trunc32Elements * 32 / ScalarTySize;
11995+
auto NumElementsIn32 = 32 / ScalarTySize;
11996+
auto NumAvailElements = DWordOffset < Trunc32Elements
11997+
? NumElementsIn32
11998+
: NumElements - NormalizedTrunc;
11999+
12000+
SmallVector<SDValue, 4> VecSrcs;
12001+
DAG.ExtractVectorElements(Src, VecSrcs, DWordOffset * NumElementsIn32,
12002+
NumAvailElements);
12003+
12004+
Ret = DAG.getBuildVector(
12005+
MVT::getVectorVT(MVT::getIntegerVT(ScalarTySize), NumAvailElements), SL,
12006+
VecSrcs);
12007+
return Ret = DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
12008+
}
12009+
12010+
/// Scalar Type
12011+
auto ShiftVal = 32 * DWordOffset;
12012+
Ret = DAG.getNode(ISD::SRL, SL, Src.getValueType(), Src,
12013+
DAG.getConstant(ShiftVal, SL, MVT::i32));
1195212014
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
1195312015
}
1195412016

@@ -12017,13 +12079,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1201712079
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1201812080
}
1201912081

12020-
SDValue OtherOp =
12021-
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
12082+
SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
1202212083

12023-
if (SecondSrc)
12084+
if (SecondSrc) {
1202412085
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12025-
12026-
assert(Op.getValueSizeInBits() == 32);
12086+
assert(OtherOp.getValueSizeInBits() == 32);
12087+
}
1202712088

1202812089
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
1202912090

0 commit comments

Comments
 (0)