Skip to content

Commit a7b5122

Browse files
committed
[AMDGPU]: Accept constant zero bytes in v_perm OrCombine
Change-Id: I5925a3ab10031bf6adcb08ad97d99193b21b11ec
1 parent c95693c commit a7b5122

File tree

10 files changed

+319
-273
lines changed

10 files changed

+319
-273
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11612,6 +11612,29 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
1161211612
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
1161311613
}
1161411614

11615+
case ISD::EXTRACT_VECTOR_ELT: {
11616+
auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11617+
if (!IdxOp)
11618+
return std::nullopt;
11619+
auto VecIdx = IdxOp->getZExtValue();
11620+
auto ScalarSize = Op.getScalarValueSizeInBits();
11621+
11622+
assert((ScalarSize >= 8) && !(ScalarSize % 8));
11623+
11624+
if (ScalarSize < 32) {
11625+
// TODO: support greater than 32 bit sources
11626+
if ((VecIdx + 1) * ScalarSize > 32)
11627+
return std::nullopt;
11628+
11629+
SrcIndex = VecIdx * ScalarSize / 8 + SrcIndex;
11630+
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
11631+
}
11632+
11633+
// The scalar is 32 bits, so just use the scalar
11634+
// TODO: support greater than 32 bit sources
11635+
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11636+
}
11637+
1161511638
default: {
1161611639
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
1161711640
}
@@ -11922,6 +11945,9 @@ static bool addresses16Bits(int Mask) {
1192211945
int Low8 = Mask & 0xff;
1192311946
int Hi8 = (Mask & 0xff00) >> 8;
1192411947

11948+
if (Low8 == 0x0c || Hi8 == 0x0c)
11949+
return false;
11950+
1192511951
assert(Low8 < 8 && Hi8 < 8);
1192611952
// Are the bytes contiguous in the order of increasing addresses.
1192711953
bool IsConsecutive = (Hi8 - Low8 == 1);
@@ -12016,17 +12042,16 @@ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
1201612042

1201712043
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1201812044
SelectionDAG &DAG = DCI.DAG;
12019-
[[maybe_unused]] EVT VT = N->getValueType(0);
1202012045
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
1202112046

1202212047
// VT is known to be MVT::i32, so we need to provide 4 bytes.
12023-
assert(VT == MVT::i32);
12048+
assert(N->getValueType(0) == MVT::i32);
12049+
1202412050
for (int i = 0; i < 4; i++) {
1202512051
// Find the ByteProvider that provides the ith byte of the result of OR
1202612052
std::optional<ByteProvider<SDValue>> P =
1202712053
calculateByteProvider(SDValue(N, 0), i, 0, /*StartingIndex = */ i);
12028-
// TODO support constantZero
12029-
if (!P || P->isConstantZero())
12054+
if (!P)
1203012055
return SDValue();
1203112056

1203212057
PermNodes.push_back(*P);
@@ -12039,6 +12064,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1203912064
uint64_t PermMask = 0x00000000;
1204012065
for (size_t i = 0; i < PermNodes.size(); i++) {
1204112066
auto PermOp = PermNodes[i];
12067+
if (PermOp.isConstantZero()) {
12068+
if (FirstSrc.first == i)
12069+
++FirstSrc.first;
12070+
PermMask |= 0x0c << (i * 8);
12071+
continue;
12072+
}
1204212073
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
1204312074
// by sizeof(Src2) = 4
1204412075
int SrcByteAdjust = 4;
@@ -12062,10 +12093,14 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1206212093
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
1206312094
}
1206412095
SDLoc DL(N);
12096+
if (PermMask == 0x0c0c0c0c)
12097+
return DAG.getConstant(0, DL, MVT::i32);
12098+
1206512099
SDValue Op = *PermNodes[FirstSrc.first].Src;
1206612100
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
1206712101
assert(Op.getValueSizeInBits() == 32);
1206812102

12103+
SDValue OtherOp;
1206912104
// Check that we are not just extracting the bytes in order from an op
1207012105
if (!SecondSrc) {
1207112106
int Low16 = PermMask & 0xffff;
@@ -12077,17 +12112,17 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
1207712112
// The perm op would really just produce Op. So combine into Op
1207812113
if (WellFormedLow && WellFormedHi)
1207912114
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
12080-
}
1208112115

12082-
SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
12116+
OtherOp = Op;
12117+
}
1208312118

1208412119
if (SecondSrc) {
12085-
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12120+
OtherOp = getDWordFromOffset(DAG, DL, *PermNodes[SecondSrc->first].Src,
12121+
SecondSrc->second);
1208612122
assert(OtherOp.getValueSizeInBits() == 32);
1208712123
}
1208812124

1208912125
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
12090-
1209112126
assert(Op.getValueType().isByteSized() &&
1209212127
OtherOp.getValueType().isByteSized());
1209312128

@@ -12159,12 +12194,33 @@ SDValue SITargetLowering::performOrCombine(SDNode *N,
1215912194
// If all the uses of an or need to extract the individual elements, do not
1216012195
// attempt to lower into v_perm
1216112196
auto usesCombinedOperand = [](SDNode *OrUse) {
12197+
// The combined bytes seem to be getting extracted
12198+
if (OrUse->getOpcode() == ISD::SRL || OrUse->getOpcode() == ISD::TRUNCATE)
12199+
return false;
12200+
12201+
if (OrUse->getOpcode() == ISD::AND) {
12202+
auto SelectMask = dyn_cast<ConstantSDNode>(OrUse->getOperand(1));
12203+
if (SelectMask && (SelectMask->getZExtValue() == 0xFF))
12204+
return false;
12205+
}
12206+
12207+
if (OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE0 ||
12208+
OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE1 ||
12209+
OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE2 ||
12210+
OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE3) {
12211+
return false;
12212+
}
12213+
12214+
if (auto StoreUse = dyn_cast<StoreSDNode>(OrUse))
12215+
if (StoreUse->isTruncatingStore() &&
12216+
StoreUse->getMemoryVT().getSizeInBits() == 8)
12217+
return false;
12218+
1216212219
// If we have any non-vectorized use, then it is a candidate for v_perm
12163-
if (OrUse->getOpcode() != ISD::BITCAST ||
12164-
!OrUse->getValueType(0).isVector())
12220+
if (!(OrUse->getValueType(0).isVector() &&
12221+
OrUse->getOpcode() != ISD::BUILD_VECTOR))
1216512222
return true;
1216612223

12167-
// If we have any non-vectorized use, then it is a candidate for v_perm
1216812224
for (auto VUse : OrUse->uses()) {
1216912225
if (!VUse->getValueType(0).isVector())
1217012226
return true;

0 commit comments

Comments
 (0)