@@ -47690,12 +47690,47 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
47690
47690
return V;
47691
47691
47692
47692
if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) {
47693
- SmallVector<int, 64> Mask ;
47694
- if (createShuffleMaskFromVSELECT(Mask , Cond,
47693
+ SmallVector<int, 64> CondMask ;
47694
+ if (createShuffleMaskFromVSELECT(CondMask , Cond,
47695
47695
N->getOpcode() == X86ISD::BLENDV)) {
47696
47696
// Convert vselects with constant condition into shuffles.
47697
47697
if (DCI.isBeforeLegalizeOps())
47698
- return DAG.getVectorShuffle(VT, DL, LHS, RHS, Mask);
47698
+ return DAG.getVectorShuffle(VT, DL, LHS, RHS, CondMask);
47699
+
47700
+ // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
47701
+ // by forcing the unselected elements to zero.
47702
+ // TODO: Can we handle more shuffles with this?
47703
+ if (LHS.hasOneUse() && RHS.hasOneUse()) {
47704
+ SmallVector<SDValue, 1> LHSOps, RHSOps;
47705
+ SmallVector<int, 64> LHSMask, RHSMask, ByteMask;
47706
+ SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47707
+ SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47708
+ if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47709
+ RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47710
+ scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47711
+ getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47712
+ getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47713
+ assert(ByteMask.size() == LHSMask.size() &&
47714
+ ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47715
+ for (auto [I, M] : enumerate(ByteMask)) {
47716
+ // getConstVector sets negative shuffle mask values as undef, so
47717
+ // ensure we hardcode SM_SentinelZero values to zero (0x80).
47718
+ if (M < (int)ByteMask.size()) {
47719
+ LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
47720
+ RHSMask[I] = 0x80;
47721
+ } else {
47722
+ LHSMask[I] = 0x80;
47723
+ RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
47724
+ }
47725
+ }
47726
+ MVT ByteVT = LHSShuf.getSimpleValueType();
47727
+ LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47728
+ getConstVector(LHSMask, ByteVT, DAG, DL, true));
47729
+ RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47730
+ getConstVector(RHSMask, ByteVT, DAG, DL, true));
47731
+ return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
47732
+ }
47733
+ }
47699
47734
47700
47735
// Attempt to combine as shuffle.
47701
47736
SDValue Op(N, 0);
@@ -47704,43 +47739,6 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
47704
47739
}
47705
47740
}
47706
47741
47707
- // fold vselect(cond, pshufb(x), pshufb(y)) -> or (pshufb(x), pshufb(y))
47708
- // by forcing the unselected elements to zero.
47709
- // TODO: Can we handle more shuffles with this?
47710
- if (N->getOpcode() == ISD::VSELECT && CondVT.isVector() && LHS.hasOneUse() &&
47711
- RHS.hasOneUse()) {
47712
- SmallVector<SDValue, 1> LHSOps, RHSOps;
47713
- SmallVector<int, 64> LHSMask, RHSMask, CondMask, ByteMask;
47714
- SDValue LHSShuf = peekThroughOneUseBitcasts(LHS);
47715
- SDValue RHSShuf = peekThroughOneUseBitcasts(RHS);
47716
- if (LHSShuf.getOpcode() == X86ISD::PSHUFB &&
47717
- RHSShuf.getOpcode() == X86ISD::PSHUFB &&
47718
- createShuffleMaskFromVSELECT(CondMask, Cond) &&
47719
- scaleShuffleMaskElts(VT.getSizeInBits() / 8, CondMask, ByteMask) &&
47720
- getTargetShuffleMask(LHSShuf, true, LHSOps, LHSMask) &&
47721
- getTargetShuffleMask(RHSShuf, true, RHSOps, RHSMask)) {
47722
- assert(ByteMask.size() == LHSMask.size() &&
47723
- ByteMask.size() == RHSMask.size() && "Shuffle mask mismatch");
47724
- for (auto [I, M] : enumerate(ByteMask)) {
47725
- // getConstVector sets negative shuffle mask values as undef, so ensure
47726
- // we hardcode SM_SentinelZero values to zero (0x80).
47727
- if (M < (int)ByteMask.size()) {
47728
- LHSMask[I] = isUndefOrZero(LHSMask[I]) ? 0x80 : LHSMask[I];
47729
- RHSMask[I] = 0x80;
47730
- } else {
47731
- LHSMask[I] = 0x80;
47732
- RHSMask[I] = isUndefOrZero(RHSMask[I]) ? 0x80 : RHSMask[I];
47733
- }
47734
- }
47735
- MVT ByteVT = LHSShuf.getSimpleValueType();
47736
- LHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, LHSOps[0],
47737
- getConstVector(LHSMask, ByteVT, DAG, DL, true));
47738
- RHS = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, RHSOps[0],
47739
- getConstVector(RHSMask, ByteVT, DAG, DL, true));
47740
- return DAG.getBitcast(VT, DAG.getNode(ISD::OR, DL, ByteVT, LHS, RHS));
47741
- }
47742
- }
47743
-
47744
47742
// If we have SSE[12] support, try to form min/max nodes. SSE min/max
47745
47743
// instructions match the semantics of the common C idiom x<y?x:y but not
47746
47744
// x<=y?x:y, because of how they handle negative zero (which can be
0 commit comments