Skip to content

Commit 1b17d1e

Browse files
authored
[X86] Allow select(cond,pshufb,pshufb) -> or(pshufb,pshufb) fold to peek through bitcasts (#128876)
Peek through one-use bitcasts and rescale the condition mask to a vXi8 type to allow more aggressive use of pshufb zeroing.
1 parent 7f33242 commit 1b17d1e

File tree

3 files changed

+127
-164
lines changed

3 files changed

+127
-164
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 25 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -47237,32 +47237,37 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
4723747237
// fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
4723847238
// by forcing the unselected elements to zero.
4723947239
// TODO: Can we handle more shuffles with this?
47240-
if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() &&
47241-
LHS.getOpcode() == X86ISD::PSHUFB && RHS.getOpcode() == X86ISD::PSHUFB &&
47242-
LHS.hasOneUse() && RHS.hasOneUse()) {
47243-
MVT SimpleVT = VT.getSimpleVT();
47240+
if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
47241+
RHS.hasOneUse()) {
4724447242
SmallVector<SDValue, 1> LHSOps, RHSOps;
47245-
SmallVector<int, 64> LHSMask, RHSMask, CondMask;
47246-
if (createShuffleMaskFromVSELECT(CondMask, Cond) &&
47247-
getTargetShuffleMask(LHS, true, LHSOps, LHSMask) &&
47248-
getTargetShuffleMask(RHS, true, RHSOps, RHSMask)) {
47249-
int NumElts = VT.getVectorNumElements();
47250-
for (int i = 0; i != NumElts; ++i) {
47243+
SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
47244+
SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47245+
SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47246+
if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47247+
RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47248+
createShuffleMaskFromVSELECT(CondMask, Cond) &&
47249+
scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47250+
getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47251+
getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47252+
assert(ByteMask.size() == LHSMask.size() &&
47253+
ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47254+
for (auto [I, M] : enumerate(ByteMask)) {
4725147255
// getConstVector sets negative shuffle mask values as undef, so ensure
4725247256
// we hardcode SM_SentinelZero values to zero (0x80).
47253-
if (CondMask[i] < NumElts) {
47254-
LHSMask[i] = isUndefOrZero(LHSMask[i]) ? 0x80 : LHSMask[i];
47255-
RHSMask[i] = 0x80;
47257+
if (M < ByteMask.size()) {
47258+
LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
47259+
RHSMask[I] = 0x80;
4725647260
} else {
47257-
LHSMask[i] = 0x80;
47258-
RHSMask[i] = isUndefOrZero(RHSMask[i]) ? 0x80 : RHSMask[i];
47261+
LHSMask[I] = 0x80;
47262+
RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
4725947263
}
4726047264
}
47261-
LHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, LHS.getOperand(0),
47262-
getConstVector(LHSMask, SimpleVT, DAG, DL, true));
47263-
RHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, RHS.getOperand(0),
47264-
getConstVector(RHSMask, SimpleVT, DAG, DL, true));
47265-
return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
47265+
MVT ByteVT = LHSShuf.getSimpleValueType();
47266+
LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47267+
getConstVector(LHSMask, ByteVT, DAG, DL, true));
47268+
RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47269+
getConstVector(RHSMask, ByteVT, DAG, DL, true));
47270+
return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
4726647271
}
4726747272
}
4726847273

0 commit comments

Comments
 (0)