@@ -47237,32 +47237,37 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
47237
47237
// fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
47238
47238
// by forcing the unselected elements to zero.
47239
47239
// TODO: Can we handle more shuffles with this?
47240
- if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() &&
47241
- LHS.getOpcode() == X86ISD::PSHUFB && RHS.getOpcode() == X86ISD::PSHUFB &&
47242
- LHS.hasOneUse() && RHS.hasOneUse()) {
47243
- MVT SimpleVT = VT.getSimpleVT();
47240
+ if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
47241
+ RHS.hasOneUse()) {
47244
47242
SmallVector<SDValue, 1> LHSOps, RHSOps;
47245
- SmallVector<int, 64> LHSMask, RHSMask, CondMask;
47246
- if (createShuffleMaskFromVSELECT(CondMask, Cond) &&
47247
- getTargetShuffleMask(LHS, true, LHSOps, LHSMask) &&
47248
- getTargetShuffleMask(RHS, true, RHSOps, RHSMask)) {
47249
- int NumElts = VT.getVectorNumElements();
47250
- for (int i = 0; i != NumElts; ++i) {
47243
+ SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
47244
+ SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47245
+ SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47246
+ if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47247
+ RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47248
+ createShuffleMaskFromVSELECT(CondMask, Cond) &&
47249
+ scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47250
+ getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47251
+ getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47252
+ assert(ByteMask.size() == LHSMask.size() &&
47253
+ ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47254
+ for (auto [I, M] : enumerate(ByteMask)) {
47251
47255
// getConstVector sets negative shuffle mask values as undef, so ensure
47252
47256
// we hardcode SM_SentinelZero values to zero (0x80).
47253
- if (CondMask[i] < NumElts ) {
47254
- LHSMask[i ] = isUndefOrZero(LHSMask[i ]) ? 0x80 : LHSMask[i ];
47255
- RHSMask[i ] = 0x80;
47257
+ if (M < ByteMask.size() ) {
47258
+ LHSMask[I ] = isUndefOrZero(LHSMask[I ]) ? 0x80 : LHSMask[I ];
47259
+ RHSMask[I ] = 0x80;
47256
47260
} else {
47257
- LHSMask[i ] = 0x80;
47258
- RHSMask[i ] = isUndefOrZero(RHSMask[i ]) ? 0x80 : RHSMask[i ];
47261
+ LHSMask[I ] = 0x80;
47262
+ RHSMask[I ] = isUndefOrZero(RHSMask[I ]) ? 0x80 : RHSMask[I ];
47259
47263
}
47260
47264
}
47261
- LHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, LHS.getOperand(0),
47262
- getConstVector(LHSMask, SimpleVT, DAG, DL, true));
47263
- RHS = DAG.getNode(X86ISD::PSHUFB, DL, VT, RHS.getOperand(0),
47264
- getConstVector(RHSMask, SimpleVT, DAG, DL, true));
47265
- return DAG.getNode(ISD::OR, DL, VT, LHS, RHS);
47265
+ MVT ByteVT = LHSShuf.getSimpleValueType();
47266
+ LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47267
+ getConstVector(LHSMask, ByteVT, DAG, DL, true));
47268
+ RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47269
+ getConstVector(RHSMask, ByteVT, DAG, DL, true));
47270
+ return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
47266
47271
}
47267
47272
}
47268
47273
0 commit comments