@@ -11612,6 +11612,29 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
11612
11612
return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
11613
11613
}
11614
11614
11615
+ case ISD::EXTRACT_VECTOR_ELT: {
11616
+ auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
11617
+ if (!IdxOp)
11618
+ return std::nullopt;
11619
+ auto VecIdx = IdxOp->getZExtValue();
11620
+ auto ScalarSize = Op.getScalarValueSizeInBits();
11621
+
11622
+ assert((ScalarSize >= 8) && !(ScalarSize % 8));
11623
+
11624
+ if (ScalarSize < 32) {
11625
+ // TODO: support greater than 32 bit sources
11626
+ if ((VecIdx + 1) * ScalarSize > 32)
11627
+ return std::nullopt;
11628
+
11629
+ SrcIndex = VecIdx * ScalarSize / 8 + SrcIndex;
11630
+ return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, Depth + 1);
11631
+ }
11632
+
11633
+ // The scalar is 32 bits, so just use the scalar
11634
+ // TODO: support greater than 32 bit sources
11635
+ return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11636
+ }
11637
+
11615
11638
default: {
11616
11639
return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex);
11617
11640
}
@@ -11922,6 +11945,9 @@ static bool addresses16Bits(int Mask) {
11922
11945
int Low8 = Mask & 0xff;
11923
11946
int Hi8 = (Mask & 0xff00) >> 8;
11924
11947
11948
+ if (Low8 == 0x0c || Hi8 == 0x0c)
11949
+ return false;
11950
+
11925
11951
assert(Low8 < 8 && Hi8 < 8);
11926
11952
// Are the bytes contiguous in the order of increasing addresses.
11927
11953
bool IsConsecutive = (Hi8 - Low8 == 1);
@@ -12016,17 +12042,16 @@ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
12016
12042
12017
12043
static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12018
12044
SelectionDAG &DAG = DCI.DAG;
12019
- [[maybe_unused]] EVT VT = N->getValueType(0);
12020
12045
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
12021
12046
12022
12047
// VT is known to be MVT::i32, so we need to provide 4 bytes.
12023
- assert(VT == MVT::i32);
12048
+ assert(N->getValueType(0) == MVT::i32);
12049
+
12024
12050
for (int i = 0; i < 4; i++) {
12025
12051
// Find the ByteProvider that provides the ith byte of the result of OR
12026
12052
std::optional<ByteProvider<SDValue>> P =
12027
12053
calculateByteProvider(SDValue(N, 0), i, 0, /*StartingIndex = */ i);
12028
- // TODO support constantZero
12029
- if (!P || P->isConstantZero())
12054
+ if (!P)
12030
12055
return SDValue();
12031
12056
12032
12057
PermNodes.push_back(*P);
@@ -12039,6 +12064,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12039
12064
uint64_t PermMask = 0x00000000;
12040
12065
for (size_t i = 0; i < PermNodes.size(); i++) {
12041
12066
auto PermOp = PermNodes[i];
12067
+ if (PermOp.isConstantZero()) {
12068
+ if (FirstSrc.first == i)
12069
+ ++FirstSrc.first;
12070
+ PermMask |= 0x0c << (i * 8);
12071
+ continue;
12072
+ }
12042
12073
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
12043
12074
// by sizeof(Src2) = 4
12044
12075
int SrcByteAdjust = 4;
@@ -12062,10 +12093,14 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12062
12093
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
12063
12094
}
12064
12095
SDLoc DL(N);
12096
+ if (PermMask == 0x0c0c0c0c)
12097
+ return DAG.getConstant(0, DL, MVT::i32);
12098
+
12065
12099
SDValue Op = *PermNodes[FirstSrc.first].Src;
12066
12100
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
12067
12101
assert(Op.getValueSizeInBits() == 32);
12068
12102
12103
+ SDValue OtherOp;
12069
12104
// Check that we are not just extracting the bytes in order from an op
12070
12105
if (!SecondSrc) {
12071
12106
int Low16 = PermMask & 0xffff;
@@ -12077,17 +12112,17 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
12077
12112
// The perm op would really just produce Op. So combine into Op
12078
12113
if (WellFormedLow && WellFormedHi)
12079
12114
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
12080
- }
12081
12115
12082
- SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;
12116
+ OtherOp = Op;
12117
+ }
12083
12118
12084
12119
if (SecondSrc) {
12085
- OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
12120
+ OtherOp = getDWordFromOffset(DAG, DL, *PermNodes[SecondSrc->first].Src,
12121
+ SecondSrc->second);
12086
12122
assert(OtherOp.getValueSizeInBits() == 32);
12087
12123
}
12088
12124
12089
12125
if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
12090
-
12091
12126
assert(Op.getValueType().isByteSized() &&
12092
12127
OtherOp.getValueType().isByteSized());
12093
12128
@@ -12159,12 +12194,33 @@ SDValue SITargetLowering::performOrCombine(SDNode *N,
12159
12194
// If all the uses of an or need to extract the individual elements, do not
12160
12195
// attempt to lower into v_perm
12161
12196
auto usesCombinedOperand = [](SDNode *OrUse) {
12197
+ // The combined bytes seem to be getting extracted
12198
+ if (OrUse->getOpcode() == ISD::SRL || OrUse->getOpcode() == ISD::TRUNCATE)
12199
+ return false;
12200
+
12201
+ if (OrUse->getOpcode() == ISD::AND) {
12202
+ auto SelectMask = dyn_cast<ConstantSDNode>(OrUse->getOperand(1));
12203
+ if (SelectMask && (SelectMask->getZExtValue() == 0xFF))
12204
+ return false;
12205
+ }
12206
+
12207
+ if (OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE0 ||
12208
+ OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE1 ||
12209
+ OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE2 ||
12210
+ OrUse->getOpcode() == AMDGPUISD::CVT_F32_UBYTE3) {
12211
+ return false;
12212
+ }
12213
+
12214
+ if (auto StoreUse = dyn_cast<StoreSDNode>(OrUse))
12215
+ if (StoreUse->isTruncatingStore() &&
12216
+ StoreUse->getMemoryVT().getSizeInBits() == 8)
12217
+ return false;
12218
+
12162
12219
// If we have any non-vectorized use, then it is a candidate for v_perm
12163
- if (OrUse->getOpcode() != ISD::BITCAST ||
12164
- ! OrUse->getValueType(0).isVector( ))
12220
+ if (!( OrUse->getValueType(0).isVector() &&
12221
+ OrUse->getOpcode() != ISD::BUILD_VECTOR ))
12165
12222
return true;
12166
12223
12167
- // If we have any non-vectorized use, then it is a candidate for v_perm
12168
12224
for (auto VUse : OrUse->uses()) {
12169
12225
if (!VUse->getValueType(0).isVector())
12170
12226
return true;
0 commit comments