
[X86] Fold BLEND(PERMUTE(X),PERMUTE(Y)) -> PERMUTE(BLEND(X,Y)) #90219

Merged · 5 commits · May 6, 2024
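
This patch teaches the X86 DAG combiner to rewrite a blend whose two operands are single-source permutes into a blend of the original sources followed by a single permute, whenever the blend never demands the same source element index from both inputs. As a rough illustration of the shuffle algebra (not the compiler code itself), here is a hypothetical SSE4.1 intrinsics sketch; `fold_before` and `fold_after` are invented names, and the shuffle/blend immediates are chosen so that both functions compute [x1, y2, x0, y3]:

```cpp
#include <immintrin.h>

// BLEND(PERMUTE(X), PERMUTE(Y)): two shuffles feed one blend.
// PX = X[1,1,0,0], PY = Y[2,2,3,3], the blend takes dwords 1 and 3 from PY,
// so the result is [x1, y2, x0, y3].
__m128i fold_before(__m128i X, __m128i Y) {
  __m128i PX = _mm_shuffle_epi32(X, 0x05); // per-lane mask [1,1,0,0]
  __m128i PY = _mm_shuffle_epi32(Y, 0xFA); // per-lane mask [2,2,3,3]
  return _mm_blend_epi16(PX, PY, 0xCC);    // dwords 1 and 3 from PY
}

// PERMUTE(BLEND(X, Y)): X only contributes indices {0,1} and Y only {2,3},
// so the sources can be blended first and permuted once afterwards.
__m128i fold_after(__m128i X, __m128i Y) {
  __m128i B = _mm_blend_epi16(X, Y, 0xF0); // [x0, x1, y2, y3]
  return _mm_shuffle_epi32(B, 0xC9);       // per-lane mask [1,2,0,3]
}
```

Because the demanded indices from X and Y are disjoint, the two permutes collapse into one that runs after the blend, saving a shuffle. The new combineBlendOfPermutes helper below performs the equivalent rewrite on X86ISD::BLENDI nodes, driven by the demanded vector elements.
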
106 changes: 106 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3577,6 +3577,16 @@ static bool isUndefOrZeroOrInRange(ArrayRef<int> Mask, int Low, int Hi) {
      Mask, [Low, Hi](int M) { return isUndefOrZeroOrInRange(M, Low, Hi); });
}

/// Return true if every element in Mask is an in-place blend/select mask index
/// or is undef.
static bool isBlendOrUndef(ArrayRef<int> Mask) {
  unsigned NumElts = Mask.size();
  for (auto [I, M] : enumerate(Mask))
    if (!isUndefOrEqual(M, I) && !isUndefOrEqual(M, I + NumElts))
      return false;
  return true;
}

/// Return true if every element in Mask, beginning
/// from position Pos and ending in Pos + Size, falls within the specified
/// sequence (Low, Low + Step, ..., Low + (Size - 1) * Step) or is undef.
@@ -40019,6 +40029,93 @@ static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL,
  return SDValue();
}

// Attempt to fold BLEND(PERMUTE(X),PERMUTE(Y)) -> PERMUTE(BLEND(X,Y))
// iff we don't demand the same element index for both X and Y.
static SDValue combineBlendOfPermutes(MVT VT, SDValue N0, SDValue N1,
                                      ArrayRef<int> BlendMask,
                                      const APInt &DemandedElts,
                                      SelectionDAG &DAG, const SDLoc &DL) {
  assert(isBlendOrUndef(BlendMask) && "Blend shuffle expected");
  if (!N0.hasOneUse() || !N1.hasOneUse())
    return SDValue();

  unsigned NumElts = VT.getVectorNumElements();
  SDValue BC0 = peekThroughOneUseBitcasts(N0);
  SDValue BC1 = peekThroughOneUseBitcasts(N1);

  // See if both operands are shuffles, and that we can scale the shuffle masks
  // to the same width as the blend mask.
  // TODO: Support SM_SentinelZero?
  SmallVector<SDValue, 2> Ops0, Ops1;
  SmallVector<int, 32> Mask0, Mask1, ScaledMask0, ScaledMask1;
  if (!getTargetShuffleMask(BC0, /*AllowSentinelZero=*/false, Ops0, Mask0) ||
      !getTargetShuffleMask(BC1, /*AllowSentinelZero=*/false, Ops1, Mask1) ||
      !scaleShuffleElements(Mask0, NumElts, ScaledMask0) ||
      !scaleShuffleElements(Mask1, NumElts, ScaledMask1))
    return SDValue();

  // Determine the demanded elts from both permutes.
  APInt Demanded0, DemandedLHS0, DemandedRHS0;
  APInt Demanded1, DemandedLHS1, DemandedRHS1;
  if (!getShuffleDemandedElts(NumElts, BlendMask, DemandedElts, Demanded0,
                              Demanded1,
                              /*AllowUndefElts=*/true) ||
      !getShuffleDemandedElts(NumElts, ScaledMask0, Demanded0, DemandedLHS0,
                              DemandedRHS0, /*AllowUndefElts=*/true) ||
      !getShuffleDemandedElts(NumElts, ScaledMask1, Demanded1, DemandedLHS1,
                              DemandedRHS1, /*AllowUndefElts=*/true))
    return SDValue();

  // Confirm that we only use a single operand from both permutes and that we
  // don't demand the same index from both.
  if (!DemandedRHS0.isZero() || !DemandedRHS1.isZero() ||
      DemandedLHS0.intersects(DemandedLHS1))
    return SDValue();

  // Use the permute demanded elts masks as the new blend mask.
  // Create the new permute mask as a blend of the 2 original permute masks.
  SmallVector<int, 32> NewBlendMask(NumElts, SM_SentinelUndef);
  SmallVector<int, 32> NewPermuteMask(NumElts, SM_SentinelUndef);
  for (int I = 0; I != NumElts; ++I) {
    if (Demanded0[I]) {
      int M = ScaledMask0[I];
      if (0 <= M) {
        assert(isUndefOrEqual(NewBlendMask[M], M) &&
               "BlendMask demands LHS AND RHS");
        NewBlendMask[M] = M;
        NewPermuteMask[I] = M;
      }
    } else if (Demanded1[I]) {
      int M = ScaledMask1[I];
      if (0 <= M) {
        assert(isUndefOrEqual(NewBlendMask[M], M + NumElts) &&
               "BlendMask demands LHS AND RHS");
        NewBlendMask[M] = M + NumElts;
        NewPermuteMask[I] = M;
      }
    }
  }
  assert(isBlendOrUndef(NewBlendMask) && "Bad blend");
  assert(isUndefOrInRange(NewPermuteMask, 0, NumElts) && "Bad permute");

  // v16i16 shuffles can explode in complexity very easily, only accept them if
  // the blend mask is the same in the 128-bit subvectors (or can widen to
  // v8i32) and the permute can be widened as well.
  if (VT == MVT::v16i16) {
    if (!is128BitLaneRepeatedShuffleMask(VT, NewBlendMask) &&
        !canWidenShuffleElements(NewBlendMask))
      return SDValue();
    if (!canWidenShuffleElements(NewPermuteMask))
      return SDValue();
  }

  SDValue NewBlend =
      DAG.getVectorShuffle(VT, DL, DAG.getBitcast(VT, Ops0[0]),
                           DAG.getBitcast(VT, Ops1[0]), NewBlendMask);
  return DAG.getVectorShuffle(VT, DL, NewBlend, DAG.getUNDEF(VT),
                              NewPermuteMask);
}

// TODO - move this to TLI like isBinOp?
static bool isUnaryOp(unsigned Opcode) {
switch (Opcode) {
@@ -41771,6 +41868,15 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
    KnownUndef = SrcUndef.zextOrTrunc(NumElts);
    break;
  }
  case X86ISD::BLENDI: {
    SmallVector<int, 16> BlendMask;
    DecodeBLENDMask(NumElts, Op.getConstantOperandVal(2), BlendMask);
    if (SDValue R = combineBlendOfPermutes(VT.getSimpleVT(), Op.getOperand(0),
                                           Op.getOperand(1), BlendMask,
                                           DemandedElts, TLO.DAG, SDLoc(Op)))
      return TLO.CombineTo(Op, R);
    break;
  }
  case X86ISD::BLENDV: {
    APInt SelUndef, SelZero;
    if (SimplifyDemandedVectorElts(Op.getOperand(0), DemandedElts, SelUndef,
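
Between the lowering change and the test updates, it may help to see the mask bookkeeping in isolation. The following standalone sketch is simplified and not LLVM code: it uses plain std::vector instead of SmallVector, -1 instead of SM_SentinelUndef, and invented names (`foldBlendOfPermutes`, `TakeY`). It derives the new blend mask over the original sources plus the trailing permute mask, and rejects the case where X and Y both need the same source index:

```cpp
#include <cstddef>
#include <optional>
#include <set>
#include <vector>

// Hypothetical helper mirroring the NewBlendMask/NewPermuteMask loop above.
// Mask convention: -1 means undef; the returned blend mask element M is either
// M (take X) or M + NumElts (take Y), matching isBlendOrUndef().
struct BlendPermute {
  std::vector<int> NewBlendMask;   // applied to the original X and Y
  std::vector<int> NewPermuteMask; // applied to the new blend's result
};

std::optional<BlendPermute>
foldBlendOfPermutes(const std::vector<int> &PermX,  // single-source permute of X
                    const std::vector<int> &PermY,  // single-source permute of Y
                    const std::vector<bool> &TakeY, // blend lane comes from PermY?
                    const std::vector<bool> &Demanded) {
  size_t NumElts = PermX.size();
  BlendPermute R{std::vector<int>(NumElts, -1), std::vector<int>(NumElts, -1)};
  std::set<int> UsedX, UsedY;
  for (size_t I = 0; I != NumElts; ++I) {
    if (!Demanded[I])
      continue;
    int M = TakeY[I] ? PermY[I] : PermX[I];
    if (M < 0)
      continue; // undef permute element, nothing to track
    // Bail out if X and Y both need the same source index: a single blend of
    // the original sources could not provide both values in that lane.
    if (TakeY[I] ? UsedX.count(M) : UsedY.count(M))
      return std::nullopt;
    (TakeY[I] ? UsedY : UsedX).insert(M);
    R.NewBlendMask[M] = TakeY[I] ? int(M + NumElts) : M;
    R.NewPermuteMask[I] = M;
  }
  return R;
}
```

In the actual patch the same disjointness information comes from getShuffleDemandedElts, so the conflicting case is rejected up front and the loop only asserts consistency.
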
20 changes: 8 additions & 12 deletions llvm/test/CodeGen/X86/horizontal-sum.ll
@@ -679,9 +679,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
; AVX1-SLOW-NEXT: vphaddd %xmm1, %xmm0, %xmm4
; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
; AVX1-SLOW-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX1-SLOW-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5,6,7]
+; AVX1-SLOW-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
; AVX1-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm2[1,1,1,1]
; AVX1-SLOW-NEXT: vpaddd %xmm2, %xmm1, %xmm1
; AVX1-SLOW-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
@@ -704,9 +703,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
; AVX1-FAST-NEXT: vphaddd %xmm1, %xmm0, %xmm4
; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
; AVX1-FAST-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX1-FAST-NEXT: vpblendw {{.*#+}} xmm0 = xmm0[0,1],xmm1[2,3],xmm0[4,5,6,7]
+; AVX1-FAST-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
; AVX1-FAST-NEXT: vphaddd %xmm2, %xmm2, %xmm1
; AVX1-FAST-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
; AVX1-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm2[3,3,3,3]
@@ -727,9 +725,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
; AVX2-SLOW-NEXT: vphaddd %xmm1, %xmm0, %xmm4
; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
; AVX2-SLOW-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX2-SLOW-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2,3]
+; AVX2-SLOW-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
; AVX2-SLOW-NEXT: vpshufd {{.*#+}} xmm1 = xmm2[1,1,1,1]
; AVX2-SLOW-NEXT: vpaddd %xmm2, %xmm1, %xmm1
; AVX2-SLOW-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
@@ -752,9 +749,8 @@ define <4 x i32> @sequential_sum_v4i32_v4i32(<4 x i32> %0, <4 x i32> %1, <4 x i3
; AVX2-FAST-NEXT: vphaddd %xmm1, %xmm0, %xmm4
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm4[0,2,2,3]
; AVX2-FAST-NEXT: vpunpckhdq {{.*#+}} xmm5 = xmm0[2],xmm1[2],xmm0[3],xmm1[3]
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[3,3,3,3]
-; AVX2-FAST-NEXT: vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2,3]
+; AVX2-FAST-NEXT: vpunpckhqdq {{.*#+}} xmm0 = xmm0[1],xmm1[1]
+; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm0 = xmm0[1,3,2,3]
; AVX2-FAST-NEXT: vphaddd %xmm2, %xmm2, %xmm1
; AVX2-FAST-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm4[0],xmm1[0]
; AVX2-FAST-NEXT: vpshufd {{.*#+}} xmm4 = xmm2[3,3,3,3]
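
In the test diff above, each hunk replaces a pshufd + pshufd + blend triple with an unpack + pshufd pair, one instruction fewer per pattern. Below is a hypothetical intrinsics rendering of the two sequences (invented function names, immediates transcribed from the check lines): they agree in the low two 32-bit lanes, while the upper lanes may differ, which is acceptable when those lanes are not demanded by the surrounding code, the precondition the demanded-elements combine enforces.

```cpp
#include <immintrin.h>

// Old check lines: two permutes feeding a blend.
// Result dwords: [x3, y3, x3, x3].
__m128i old_sequence(__m128i x, __m128i y) {
  __m128i t1 = _mm_shuffle_epi32(y, 0xEE); // y[2,3,2,3]
  __m128i t0 = _mm_shuffle_epi32(x, 0xFF); // x[3,3,3,3]
  return _mm_blend_epi16(t0, t1, 0x0C);    // dword 1 taken from t1
}

// New check lines: one unpack plus one permute.
// Result dwords: [x3, y3, y2, y3], identical to the old result in lanes 0 and 1.
__m128i new_sequence(__m128i x, __m128i y) {
  __m128i t = _mm_unpackhi_epi64(x, y);    // [x2, x3, y2, y3]
  return _mm_shuffle_epi32(t, 0xED);       // t[1,3,2,3]
}
```
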