Commit 8e7618a

[X86] Fold BLEND(PERMUTE(X),PERMUTE(Y)) -> PERMUTE(BLEND(X,Y)) (#90219)
If we don't demand the same element from both single-source shuffles (permutes), then attempt to blend the sources together first and then perform a merged permute. For vXi16 blends we have to be careful, as these are much more likely to involve byte/word vector shuffles that will result in the creation of additional shuffle instructions. This fold might also be worth it for VSELECT with constant masks on AVX512 targets; I haven't investigated that yet, but I've tried to write combineBlendOfPermutes so that it's prepared for it. The PR34592 -O0 regression is an unfortunate failure to clean up with a later pass that calls SimplifyDemandedElts like -O3 does - I'm not sure how worried we should be, tbh.
1 parent 0a0cac6
18 files changed: +12710 -14157 lines
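
To make the fold concrete, here is a minimal standalone C++ sketch (not LLVM code; the vectors, masks, and values are invented for illustration) showing that blending two single-source permutes, and permuting a blend of the original sources, give the same result when the two permutes demand different element indices. It uses the usual shuffle-mask convention: an index < NumElts selects from the first source, an index >= NumElts from the second.

#include <array>
#include <cassert>
#include <cstdio>

// A minimal sketch (not LLVM code) of the fold on a hypothetical 4 x i32 case.
// Shuffle-mask convention: an index < 4 selects from the first source, an
// index >= 4 selects from the second source.
int main() {
  std::array<int, 4> X = {10, 11, 12, 13};
  std::array<int, 4> Y = {20, 21, 22, 23};

  // Before: two single-source permutes feeding a blend.
  std::array<int, 4> PermX = {3, 3, 3, 3}; // PERMUTE(X) = {13,13,13,13}
  std::array<int, 4> PermY = {2, 2, 2, 2}; // PERMUTE(Y) = {22,22,22,22}
  std::array<int, 4> Blend = {0, 5, 2, 7}; // lanes 0,2 from PERMUTE(X), lanes 1,3 from PERMUTE(Y)

  std::array<int, 4> Before;
  for (int I = 0; I != 4; ++I)
    Before[I] = Blend[I] < 4 ? X[PermX[Blend[I]]] : Y[PermY[Blend[I] - 4]];

  // After: only X[3] and Y[2] are demanded and they land in different lanes,
  // so blend X and Y directly, then apply one merged permute.
  std::array<int, 4> NewBlend = {0, 1, 6, 3};   // lanes 0,1 are undef in the real combine
  std::array<int, 4> NewPermute = {3, 2, 3, 2};

  std::array<int, 4> Blended, After;
  for (int I = 0; I != 4; ++I)
    Blended[I] = NewBlend[I] < 4 ? X[NewBlend[I]] : Y[NewBlend[I] - 4];
  for (int I = 0; I != 4; ++I)
    After[I] = Blended[NewPermute[I]];

  assert(Before == After); // both forms produce {13, 22, 13, 22}
  for (int V : After)
    std::printf("%d ", V);
  std::printf("\n");
}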

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 106 additions & 0 deletions
@@ -3577,6 +3577,16 @@ static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) {
       Mask, [Low, Hi](int M) { return isUndefOrZeroOrInRange(M, Low, Hi); });
 }
 
+/// Return true if every element in Mask, is an in-place blend/select mask or is
+/// undef.
+static bool isBlendOrUndef(ArrayRef<int> Mask) {
+  unsigned NumElts = Mask.size();
+  for (auto [I, M] : enumerate(Mask))
+    if (!isUndefOrEqual(M, I) && !isUndefOrEqual(M, I + NumElts))
+      return false;
+  return true;
+}
+
 /// Return true if every element in Mask, beginning
 /// from position Pos and ending in Pos + Size, falls within the specified
 /// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step) or is undef.
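
For reference, a minimal standalone model of what this predicate accepts (assumptions: -1 plays the role of SM_SentinelUndef, and the helper name isBlendOrUndefModel is invented for this sketch):

#include <cassert>
#include <vector>

// Standalone model: element I may only be undef (-1), I, or I + NumElts,
// i.e. each lane keeps its position and merely picks one of the two sources.
static bool isBlendOrUndefModel(const std::vector<int> &Mask) {
  int NumElts = static_cast<int>(Mask.size());
  for (int I = 0; I != NumElts; ++I)
    if (Mask[I] != -1 && Mask[I] != I && Mask[I] != I + NumElts)
      return false;
  return true;
}

int main() {
  assert(isBlendOrUndefModel({0, 5, 2, 7}));  // in-place blend of two v4 sources
  assert(isBlendOrUndefModel({0, -1, 6, 3})); // undef lanes are fine
  assert(!isBlendOrUndefModel({1, 5, 2, 7})); // lane 0 reads element 1: a shuffle, not a blend
}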
@@ -40021,6 +40031,93 @@ static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL,
   return SDValue();
 }
 
+// Attempt to fold BLEND(PERMUTE(X),PERMUTE(Y)) -> PERMUTE(BLEND(X,Y))
+// iff we don't demand the same element index for both X and Y.
+static SDValue combineBlendOfPermutes(MVT VT, SDValue N0, SDValue N1,
+                                      ArrayRef<int> BlendMask,
+                                      const APInt &DemandedElts,
+                                      SelectionDAG &DAG, const SDLoc &DL) {
+  assert(isBlendOrUndef(BlendMask) && "Blend shuffle expected");
+  if (!N0.hasOneUse() || !N1.hasOneUse())
+    return SDValue();
+
+  unsigned NumElts = VT.getVectorNumElements();
+  SDValue BC0 = peekThroughOneUseBitcasts(N0);
+  SDValue BC1 = peekThroughOneUseBitcasts(N1);
+
+  // See if both operands are shuffles, and that we can scale the shuffle masks
+  // to the same width as the blend mask.
+  // TODO: Support SM_SentinelZero?
+  SmallVector<SDValue, 2> Ops0, Ops1;
+  SmallVector<int, 32> Mask0, Mask1, ScaledMask0, ScaledMask1;
+  if (!getTargetShuffleMask(BC0, /*AllowSentinelZero=*/false, Ops0, Mask0) ||
+      !getTargetShuffleMask(BC1, /*AllowSentinelZero=*/false, Ops1, Mask1) ||
+      !scaleShuffleElements(Mask0, NumElts, ScaledMask0) ||
+      !scaleShuffleElements(Mask1, NumElts, ScaledMask1))
+    return SDValue();
+
+  // Determine the demanded elts from both permutes.
+  APInt Demanded0, DemandedLHS0, DemandedRHS0;
+  APInt Demanded1, DemandedLHS1, DemandedRHS1;
+  if (!getShuffleDemandedElts(NumElts, BlendMask, DemandedElts, Demanded0,
+                              Demanded1,
+                              /*AllowUndefElts=*/true) ||
+      !getShuffleDemandedElts(NumElts, ScaledMask0, Demanded0, DemandedLHS0,
+                              DemandedRHS0, /*AllowUndefElts=*/true) ||
+      !getShuffleDemandedElts(NumElts, ScaledMask1, Demanded1, DemandedLHS1,
+                              DemandedRHS1, /*AllowUndefElts=*/true))
+    return SDValue();
+
+  // Confirm that we only use a single operand from both permutes and that we
+  // don't demand the same index from both.
+  if (!DemandedRHS0.isZero() || !DemandedRHS1.isZero() ||
+      DemandedLHS0.intersects(DemandedLHS1))
+    return SDValue();
+
+  // Use the permute demanded elts masks as the new blend mask.
+  // Create the new permute mask as a blend of the 2 original permute masks.
+  SmallVector<int, 32> NewBlendMask(NumElts, SM_SentinelUndef);
+  SmallVector<int, 32> NewPermuteMask(NumElts, SM_SentinelUndef);
+  for (int I = 0; I != NumElts; ++I) {
+    if (Demanded0[I]) {
+      int M = ScaledMask0[I];
+      if (0 <= M) {
+        assert(isUndefOrEqual(NewBlendMask[M], M) &&
+               "BlendMask demands LHS AND RHS");
+        NewBlendMask[M] = M;
+        NewPermuteMask[I] = M;
+      }
+    } else if (Demanded1[I]) {
+      int M = ScaledMask1[I];
+      if (0 <= M) {
+        assert(isUndefOrEqual(NewBlendMask[M], M + NumElts) &&
+               "BlendMask demands LHS AND RHS");
+        NewBlendMask[M] = M + NumElts;
+        NewPermuteMask[I] = M;
+      }
+    }
+  }
+  assert(isBlendOrUndef(NewBlendMask) && "Bad blend");
+  assert(isUndefOrInRange(NewPermuteMask, 0, NumElts) && "Bad permute");
+
+  // v16i16 shuffles can explode in complexity very easily, only accept them if
+  // the blend mask is the same in the 128-bit subvectors (or can widen to
+  // v8i32) and the permute can be widened as well.
+  if (VT == MVT::v16i16) {
+    if (!is128BitLaneRepeatedShuffleMask(VT, NewBlendMask) &&
+        !canWidenShuffleElements(NewBlendMask))
+      return SDValue();
+    if (!canWidenShuffleElements(NewPermuteMask))
+      return SDValue();
+  }
+
+  SDValue NewBlend =
+      DAG.getVectorShuffle(VT, DL, DAG.getBitcast(VT, Ops0[0]),
+                           DAG.getBitcast(VT, Ops1[0]), NewBlendMask);
+  return DAG.getVectorShuffle(VT, DL, NewBlend, DAG.getUNDEF(VT),
+                              NewPermuteMask);
+}
+
 // TODO - move this to TLI like isBinOp?
 static bool isUnaryOp(unsigned Opcode) {
   switch (Opcode) {
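
A rough standalone model of the mask rewrite above, ignoring bitcasts, the caller-supplied DemandedElts narrowing, and the v16i16 guard (-1 stands in for SM_SentinelUndef; the function and type names are invented for this sketch, and the permute masks are assumed to be single-source):

#include <cassert>
#include <optional>
#include <vector>

struct FoldedMasks {
  std::vector<int> NewBlendMask;   // applied directly to X and Y
  std::vector<int> NewPermuteMask; // applied to the new blend
};

// Simplified model of the rewrite: given a blend mask over PERMUTE(X) and
// PERMUTE(Y) plus the two single-source permute masks, produce the masks for
// PERMUTE(BLEND(X, Y)), or nullopt if some element index is demanded from both
// X and Y (the DemandedLHS0.intersects(DemandedLHS1) case). -1 means undef.
static std::optional<FoldedMasks>
foldBlendOfPermutes(const std::vector<int> &BlendMask,
                    const std::vector<int> &PermMask0,
                    const std::vector<int> &PermMask1) {
  int NumElts = static_cast<int>(BlendMask.size());
  FoldedMasks R{std::vector<int>(NumElts, -1), std::vector<int>(NumElts, -1)};

  for (int I = 0; I != NumElts; ++I) {
    int B = BlendMask[I];
    if (B < 0)
      continue; // undef blend lane
    bool FromPerm0 = B < NumElts;
    int M = FromPerm0 ? PermMask0[I] : PermMask1[I];
    if (M < 0)
      continue; // undef permute lane
    int NewB = FromPerm0 ? M : M + NumElts; // X[M] or Y[M] in the new blend
    if (R.NewBlendMask[M] != -1 && R.NewBlendMask[M] != NewB)
      return std::nullopt; // blend lane M would need both X[M] and Y[M]
    R.NewBlendMask[M] = NewB;
    R.NewPermuteMask[I] = M;
  }
  return R;
}

int main() {
  // Lanes 0,2 want X[3], lanes 1,3 want Y[2]: the fold applies.
  auto R = foldBlendOfPermutes({0, 5, 2, 7}, {3, 3, 3, 3}, {2, 2, 2, 2});
  assert(R);
  assert((R->NewBlendMask == std::vector<int>{-1, -1, 6, 3}));
  assert((R->NewPermuteMask == std::vector<int>{3, 2, 3, 2}));

  // Both permutes want element 3 of their source: X[3] and Y[3] would have to
  // share blend lane 3, so no fold.
  assert(!foldBlendOfPermutes({0, 5, 2, 7}, {3, 3, 3, 3}, {3, 3, 3, 3}));
}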
@@ -41773,6 +41870,15 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
     KnownUndef = SrcUndef.zextOrTrunc(NumElts);
     break;
   }
+  case X86ISD::BLENDI: {
+    SmallVector<int, 16> BlendMask;
+    DecodeBLENDMask(NumElts, Op.getConstantOperandVal(2), BlendMask);
+    if (SDValue R = combineBlendOfPermutes(VT.getSimpleVT(), Op.getOperand(0),
+                                           Op.getOperand(1), BlendMask,
+                                           DemandedElts, TLO.DAG, SDLoc(Op)))
+      return TLO.CombineTo(Op, R);
+    break;
+  }
   case X86ISD::BLENDV: {
     APInt SelUndef, SelZero;
     if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, SelUndef,
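
For context, the BLENDI immediate decoded above is a per-lane bit mask. A standalone sketch of the usual x86 blend convention (not the actual X86ShuffleDecode implementation; decodeBlendImm is an invented name):

#include <cassert>
#include <vector>

// Standalone sketch of expanding an x86 BLENDI immediate into a shuffle mask:
// if bit I of the immediate is set, lane I comes from the second source
// (index I + NumElts), otherwise from the first source (index I).
static std::vector<int> decodeBlendImm(unsigned NumElts, unsigned Imm) {
  std::vector<int> Mask(NumElts);
  for (unsigned I = 0; I != NumElts; ++I)
    Mask[I] = ((Imm >> I) & 1) ? static_cast<int>(I + NumElts) : static_cast<int>(I);
  return Mask;
}

int main() {
  // e.g. a 4-lane blend with immediate 0b1010: lanes 1 and 3 from the second source.
  assert((decodeBlendImm(4, 0b1010) == std::vector<int>{0, 5, 2, 7}));
}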

llvm/test/CodeGen/X86/horizontal-sum.ll

Lines changed: 8 additions & 12 deletions
@@ -679,9 +679,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
 ; AVX1-SLOW-NEXT: vphaddd %xmm1, %xmm0, %xmm4
 ; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
 ; AVX1-SLOW-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX1-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5,6,7]
+; AVX1-SLOW-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
 ; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm2[1,1,1,1]
 ; AVX1-SLOW-NEXT: vpaddd %xmm2, %xmm1, %xmm1
 ; AVX1-SLOW-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
@@ -704,9 +703,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
 ; AVX1-FAST-NEXT: vphaddd %xmm1, %xmm0, %xmm4
 ; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
 ; AVX1-FAST-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX1-FAST-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5,6,7]
+; AVX1-FAST-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
 ; AVX1-FAST-NEXT: vphaddd %xmm2, %xmm2, %xmm1
 ; AVX1-FAST-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
 ; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm2[3,3,3,3]
@@ -727,9 +725,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
 ; AVX2-SLOW-NEXT: vphaddd %xmm1, %xmm0, %xmm4
 ; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
 ; AVX2-SLOW-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX2-SLOW-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2,3]
+; AVX2-SLOW-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
 ; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm2[1,1,1,1]
 ; AVX2-SLOW-NEXT: vpaddd %xmm2, %xmm1, %xmm1
 ; AVX2-SLOW-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
@@ -752,9 +749,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
 ; AVX2-FAST-NEXT: vphaddd %xmm1, %xmm0, %xmm4
 ; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
 ; AVX2-FAST-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX2-FAST-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2,3]
+; AVX2-FAST-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
 ; AVX2-FAST-NEXT: vphaddd %xmm2, %xmm2, %xmm1
 ; AVX2-FAST-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
 ; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm2[3,3,3,3]
