Skip to content

Commit d65febe

Browse files
committed
[AMDGPU] Make getDWordFromOffset robust against exotic types + handle vectors in CalcByteProvider
Change-Id: I7598f17ebd0bffb247c40944dea1f845bc16238b
1 parent cdd6cc0 commit d65febe

File tree

2 files changed

+338
-40
lines changed

2 files changed

+338
-40
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10946,6 +10946,9 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1094610946
if (Op.getValueSizeInBits() < 8)
1094710947
return std::nullopt;
1094810948

10949+
if (Op.getValueType().isVector())
10950+
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
10951+
1094910952
switch (Op->getOpcode()) {
1095010953
case ISD::TRUNCATE: {
1095110954
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
@@ -11031,8 +11034,12 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1103111034
if (Index > BitWidth / 8 - 1)
1103211035
return std::nullopt;
1103311036

11037+
bool IsVec = Op.getValueType().isVector();
1103411038
switch (Op.getOpcode()) {
1103511039
case ISD::OR: {
11040+
if (IsVec)
11041+
return std::nullopt;
11042+
1103611043
auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
1103711044
StartingIndex, IsSigned);
1103811045
if (!RHS)
@@ -11053,6 +11060,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1105311060
}
1105411061

1105511062
case ISD::AND: {
11063+
if (IsVec)
11064+
return std::nullopt;
11065+
1105611066
auto BitMaskOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1105711067
if (!BitMaskOp)
1105811068
return std::nullopt;
@@ -11073,6 +11083,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1107311083
}
1107411084

1107511085
case ISD::FSHR: {
11086+
if (IsVec)
11087+
return std::nullopt;
11088+
1107611089
// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
1107711090
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(2));
1107811091
if (!ShiftOp || Op.getValueType().isVector())
@@ -11098,6 +11111,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1109811111

1109911112
case ISD::SRA:
1110011113
case ISD::SRL: {
11114+
if (IsVec)
11115+
return std::nullopt;
11116+
1110111117
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1110211118
if (!ShiftOp)
1110311119
return std::nullopt;
@@ -11123,6 +11139,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1112311139
}
1112411140

1112511141
case ISD::SHL: {
11142+
if (IsVec)
11143+
return std::nullopt;
11144+
1112611145
auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
1112711146
if (!ShiftOp)
1112811147
return std::nullopt;
@@ -11147,6 +11166,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1114711166
case ISD::SIGN_EXTEND_INREG:
1114811167
case ISD::AssertZext:
1114911168
case ISD::AssertSext: {
11169+
if (IsVec)
11170+
return std::nullopt;
11171+
1115011172
SDValue NarrowOp = Op->getOperand(0);
1115111173
unsigned NarrowBitWidth = NarrowOp.getValueSizeInBits();
1115211174
if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
@@ -11177,6 +11199,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1117711199
}
1117811200

1117911201
case ISD::TRUNCATE: {
11202+
if (IsVec)
11203+
return std::nullopt;
11204+
1118011205
uint64_t NarrowByteWidth = BitWidth / 8;
1118111206

1118211207
if (NarrowByteWidth >= Index) {
@@ -11223,9 +11248,13 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1122311248
return std::nullopt;
1122411249
}
1122511250

11226-
case ISD::BSWAP:
11251+
case ISD::BSWAP: {
11252+
if (IsVec)
11253+
return std::nullopt;
11254+
1122711255
return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
1122811256
Depth + 1, StartingIndex, IsSigned);
11257+
}
1122911258

1123011259
case ISD::EXTRACT_VECTOR_ELT: {
1123111260
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11242,6 +11271,9 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
1124211271
}
1124311272

1124411273
case AMDGPUISD::PERM: {
11274+
if (IsVec)
11275+
return std::nullopt;
11276+
1124511277
auto PermMask = dyn_cast<ConstantSDNode>(Op->getOperand(2));
1124611278
if (!PermMask)
1124711279
return std::nullopt;
@@ -11339,25 +11371,55 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
1133911371
static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
1134011372
unsigned DWordOffset) {
1134111373
SDValue Ret;
11342-
if (Src.getValueSizeInBits() <= 32)
11343-
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
1134411374

11345-
if (Src.getValueSizeInBits() >= 256) {
11346-
assert(!(Src.getValueSizeInBits() % 32));
11347-
Ret = DAG.getBitcast(
11348-
MVT::getVectorVT(MVT::i32, Src.getValueSizeInBits() / 32), Src);
11349-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Ret,
11350-
DAG.getConstant(DWordOffset, SL, MVT::i32));
11351-
}
11375+
auto TypeSize = Src.getValueSizeInBits().getFixedValue();
11376+
// ByteProvider must be at least 8 bits
11377+
assert(Src.getValueSizeInBits().isKnownMultipleOf(8));
1135211378

11353-
Ret = DAG.getBitcastedAnyExtOrTrunc(
11354-
Src, SL, MVT::getIntegerVT(Src.getValueSizeInBits()));
11355-
if (DWordOffset) {
11356-
auto Shifted = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11357-
DAG.getConstant(DWordOffset * 32, SL, MVT::i32));
11358-
return DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, Shifted);
11359-
}
11379+
if (TypeSize <= 32)
11380+
return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
1136011381

11382+
if (Src.getValueType().isVector()) {
11383+
auto ScalarTySize = Src.getScalarValueSizeInBits();
11384+
auto ScalarTy = Src.getValueType().getScalarType();
11385+
if (ScalarTySize == 32) {
11386+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, Src,
11387+
DAG.getConstant(DWordOffset, SL, MVT::i32));
11388+
}
11389+
if (ScalarTySize > 32) {
11390+
Ret = DAG.getNode(
11391+
ISD::EXTRACT_VECTOR_ELT, SL, ScalarTy, Src,
11392+
DAG.getConstant(DWordOffset / (ScalarTySize / 32), SL, MVT::i32));
11393+
auto ShiftVal = 32 * (DWordOffset % (ScalarTySize / 32));
11394+
if (ShiftVal)
11395+
Ret = DAG.getNode(ISD::SRL, SL, Ret.getValueType(), Ret,
11396+
DAG.getConstant(ShiftVal, SL, MVT::i32));
11397+
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11398+
}
11399+
11400+
assert(ScalarTySize < 32);
11401+
auto NumElements = TypeSize / ScalarTySize;
11402+
auto Trunc32Elements = (ScalarTySize * NumElements) / 32;
11403+
auto NormalizedTrunc = Trunc32Elements * 32 / ScalarTySize;
11404+
auto NumElementsIn32 = 32 / ScalarTySize;
11405+
auto NumAvailElements = DWordOffset < Trunc32Elements
11406+
? NumElementsIn32
11407+
: NumElements - NormalizedTrunc;
11408+
11409+
SmallVector<SDValue, 4> VecSrcs;
11410+
DAG.ExtractVectorElements(Src, VecSrcs, DWordOffset * NumElementsIn32,
11411+
NumAvailElements);
11412+
11413+
Ret = DAG.getBuildVector(
11414+
MVT::getVectorVT(MVT::getIntegerVT(ScalarTySize), NumAvailElements), SL,
11415+
VecSrcs);
11416+
return Ret = DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
11417+
}
11418+
11419+
/// Scalar Type
11420+
auto ShiftVal = 32 * DWordOffset;
11421+
Ret = DAG.getNode(ISD::SRL, SL, Src.getValueType(), Src,
11422+
DAG.getConstant(ShiftVal, SL, MVT::i32));
1136111423
return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
1136211424
}
1136311425

@@ -11426,13 +11488,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1142611488
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
1142711489
}
1142811490

11429-
SDValue OtherOp =
11430-
SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
11491+
SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
1143111492

11432-
if (SecondSrc)
11493+
if (SecondSrc) {
1143311494
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
11434-
11435-
assert(Op.getValueSizeInBits() == 32);
11495+
assert(OtherOp.getValueSizeInBits() == 32);
11496+
}
1143611497

1143711498
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
1143811499

0 commit comments

Comments
 (0)