[X86] matchBinaryPermuteShuffle - match AVX512 "cross lane" SHLDQ/SRLDQ style patterns using VALIGN #140538

Merged (2 commits, May 20, 2025)

44 changes: 39 additions & 5 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -10096,7 +10096,10 @@ static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask,
   if (Size != (int)ExpectedMask.size())
     return false;
   assert(llvm::all_of(ExpectedMask,
-                      [Size](int M) { return isInRange(M, 0, 2 * Size); }) &&
+                      [Size](int M) {
+                        return M == SM_SentinelZero ||
+                               isInRange(M, 0, 2 * Size);
+                      }) &&
          "Illegal target shuffle mask");
 
   // Check for out-of-range target shuffle mask indices.
@@ -10119,6 +10122,9 @@ static bool isTargetShuffleEquivalent(MVT VT, ArrayRef<int> Mask,
     int ExpectedIdx = ExpectedMask[i];
     if (MaskIdx == SM_SentinelUndef || MaskIdx == ExpectedIdx)
       continue;
+    // If we failed to match an expected SM_SentinelZero then early out.
+    if (ExpectedIdx < 0)
+      return false;
     if (MaskIdx == SM_SentinelZero) {
       // If we need this expected index to be a zero element, then update the
       // relevant zero mask and perform the known bits at the end to minimize
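
A minimal standalone sketch of what the relaxed comparison now accepts (plain C++ with a hypothetical `toyMasksEquivalent` helper and `kSentinelZero`/`kSentinelUndef` stand-ins for `SM_SentinelZero`/`SM_SentinelUndef`; it models only the sentinel handling shown above, not the full known-bits logic):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Simplified stand-ins for LLVM's shuffle-mask sentinels.
constexpr int kSentinelUndef = -1; // SM_SentinelUndef
constexpr int kSentinelZero = -2;  // SM_SentinelZero

// Toy model: the expected mask may now carry explicit zero lanes. An expected
// zero lane is satisfied only by an identical sentinel (or undef); any other
// candidate entry takes the new early-out.
bool toyMasksEquivalent(const std::vector<int> &Mask,
                        const std::vector<int> &Expected) {
  if (Mask.size() != Expected.size())
    return false;
  for (std::size_t I = 0, E = Mask.size(); I != E; ++I) {
    if (Mask[I] == kSentinelUndef || Mask[I] == Expected[I])
      continue; // undef matches anything; identical entries always match
    if (Expected[I] < 0)
      return false; // the new early-out: expected a zero lane, got a real index
    // The real function would now try to prove the candidate element is zero
    // via known-bits; the toy model simply rejects any other mismatch.
    return false;
  }
  return true;
}

int main() {
  // An "8 x i64 shift up by one, zero-filling lane 0" expectation.
  std::vector<int> Shift = {kSentinelZero, 0, 1, 2, 3, 4, 5, 6};
  assert(toyMasksEquivalent({kSentinelZero, 0, 1, 2, 3, 4, 5, 6}, Shift));
  assert(toyMasksEquivalent({kSentinelUndef, 0, 1, 2, 3, 4, 5, 6}, Shift));
  assert(!toyMasksEquivalent({7, 0, 1, 2, 3, 4, 5, 6}, Shift));
  return 0;
}
```

The point of the change is that an expected mask may now carry explicit zero lanes, which the VALIGN shift matching below relies on.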
@@ -39594,18 +39600,46 @@ static bool matchBinaryPermuteShuffle(
       ((MaskVT.is128BitVector() && Subtarget.hasVLX()) ||
        (MaskVT.is256BitVector() && Subtarget.hasVLX()) ||
        (MaskVT.is512BitVector() && Subtarget.hasAVX512()))) {
+    MVT AlignVT = MVT::getVectorVT(MVT::getIntegerVT(EltSizeInBits),
+                                   MaskVT.getSizeInBits() / EltSizeInBits);
     if (!isAnyZero(Mask)) {
       int Rotation = matchShuffleAsElementRotate(V1, V2, Mask);
       if (0 < Rotation) {
         Shuffle = X86ISD::VALIGN;
-        if (EltSizeInBits == 64)
-          ShuffleVT = MVT::getVectorVT(MVT::i64, MaskVT.getSizeInBits() / 64);
-        else
-          ShuffleVT = MVT::getVectorVT(MVT::i32, MaskVT.getSizeInBits() / 32);
+        ShuffleVT = AlignVT;
         PermuteImm = Rotation;
         return true;
       }
     }
+    // See if we can use VALIGN as a cross-lane version of VSHLDQ/VSRLDQ.
+    unsigned ZeroLo = Zeroable.countr_one();
+    unsigned ZeroHi = Zeroable.countl_one();
+    assert((ZeroLo + ZeroHi) < NumMaskElts && "Zeroable shuffle detected");
+    if (ZeroLo) {
+      SmallVector<int, 16> ShiftMask(NumMaskElts, SM_SentinelZero);
+      std::iota(ShiftMask.begin() + ZeroLo, ShiftMask.end(), 0);
+      if (isTargetShuffleEquivalent(MaskVT, Mask, ShiftMask, DAG, V1)) {
+        V1 = V1;
+        V2 = getZeroVector(AlignVT, Subtarget, DAG, DL);
+        Shuffle = X86ISD::VALIGN;
+        ShuffleVT = AlignVT;
+        PermuteImm = NumMaskElts - ZeroLo;
+        return true;
+      }
+    }
+    if (ZeroHi) {
+      SmallVector<int, 16> ShiftMask(NumMaskElts, SM_SentinelZero);
+      std::iota(ShiftMask.begin(), ShiftMask.begin() + NumMaskElts - ZeroHi,
+                ZeroHi);
+      if (isTargetShuffleEquivalent(MaskVT, Mask, ShiftMask, DAG, V1)) {
+        V2 = V1;
+        V1 = getZeroVector(AlignVT, Subtarget, DAG, DL);
+        Shuffle = X86ISD::VALIGN;
+        ShuffleVT = AlignVT;
+        PermuteImm = ZeroHi;
+        return true;
+      }
+    }
   }
 
   // Attempt to match against PALIGNR byte rotate.
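
To illustrate the new ZeroLo path, here is a sketch under assumed values: the 8 x i64 width, ZeroLo == 2, and the simplified `result[i] = Concat(lo = V2, hi = V1)[i + Imm]` reading of VALIGN are illustrative choices of mine, not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

constexpr int kSentinelZero = -2; // stand-in for SM_SentinelZero

int main() {
  // Assumed example: an 8 x i64 shuffle whose two lowest result lanes are
  // zeroable, i.e. ZeroLo == 2.
  unsigned NumMaskElts = 8, ZeroLo = 2;

  // Mirror the ZeroLo path: fill with zero sentinels, then iota the upper
  // lanes starting from element 0 of V1.
  std::vector<int> ShiftMask(NumMaskElts, kSentinelZero);
  std::iota(ShiftMask.begin() + ZeroLo, ShiftMask.end(), 0);
  // ShiftMask == {Z, Z, 0, 1, 2, 3, 4, 5}

  // Simplified VALIGN reading: result[i] = Concat(lo = V2, hi = V1)[i + Imm].
  // With V2 == zero vector and Imm == NumMaskElts - ZeroLo == 6, the two low
  // result lanes read zeros and the rest read V1[0..5] -- exactly ShiftMask.
  unsigned Imm = NumMaskElts - ZeroLo;
  std::vector<std::int64_t> V1 = {10, 11, 12, 13, 14, 15, 16, 17};
  std::vector<std::int64_t> Concat(2 * NumMaskElts, 0); // low half = V2 = zeros
  std::copy(V1.begin(), V1.end(), Concat.begin() + NumMaskElts); // high = V1

  for (unsigned I = 0; I != NumMaskElts; ++I) {
    std::int64_t Expected = ShiftMask[I] == kSentinelZero ? 0 : V1[ShiftMask[I]];
    assert(Concat[I + Imm] == Expected);
  }
  return 0;
}
```

The ZeroHi path is the mirror image: the surviving elements sit in the low result lanes, the sources are swapped so the zero vector supplies the high half of the concatenation, and the immediate is simply ZeroHi.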
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/X86/vector-shuffle-combining-avx512f.ll
@@ -812,10 +812,8 @@ define <8 x i64> @combine_vpermt2var_8i64_as_valignq(<8 x i64> %x0, <8 x i64> %x
 define <8 x i64> @combine_vpermt2var_8i64_as_valignq_zero(<8 x i64> %x0) {
 ; CHECK-LABEL: combine_vpermt2var_8i64_as_valignq_zero:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vpmovsxbq {{.*#+}} zmm2 = [15,0,1,2,3,4,5,6]
 ; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; CHECK-NEXT:    vpermt2q %zmm0, %zmm2, %zmm1
-; CHECK-NEXT:    vmovdqa64 %zmm1, %zmm0
+; CHECK-NEXT:    valignq {{.*#+}} zmm0 = zmm0[7],zmm1[0,1,2,3,4,5,6]
 ; CHECK-NEXT:    ret{{[l|q]}}
   %res0 = call <8 x i64> @llvm.x86.avx512.maskz.vpermt2var.q.512(<8 x i64> <i64 15, i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6>, <8 x i64> zeroinitializer, <8 x i64> %x0, i8 -1)
   ret <8 x i64> %res0
@@ -825,8 +823,7 @@ define <8 x i64> @combine_vpermt2var_8i64_as_zero_valignq(<8 x i64> %x0) {
 ; CHECK-LABEL: combine_vpermt2var_8i64_as_zero_valignq:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; CHECK-NEXT:    vpmovsxbq {{.*#+}} zmm2 = [15,0,1,2,3,4,5,6]
-; CHECK-NEXT:    vpermt2q %zmm1, %zmm2, %zmm0
+; CHECK-NEXT:    valignq {{.*#+}} zmm0 = zmm1[7],zmm0[0,1,2,3,4,5,6]
 ; CHECK-NEXT:    ret{{[l|q]}}
   %res0 = call <8 x i64> @llvm.x86.avx512.maskz.vpermt2var.q.512(<8 x i64> <i64 15, i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6>, <8 x i64> %x0, <8 x i64> zeroinitializer, i8 -1)
   ret <8 x i64> %res0
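
A small standalone check of why these tests now fold to a single valignq. This is plain C++, not LLVM code; the `permt2var` helper paraphrases the element-wise behaviour of the maskz.vpermt2var.q.512 intrinsic with an all-ones mask, as implied by the CHECK lines above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Paraphrased element-wise model of maskz.vpermt2var.q.512 with an all-ones
// mask: an index < 8 selects from A, an index >= 8 selects from B (index - 8).
std::vector<std::int64_t> permt2var(const std::vector<std::int64_t> &Idx,
                                    const std::vector<std::int64_t> &A,
                                    const std::vector<std::int64_t> &B) {
  std::vector<std::int64_t> R(8);
  for (int I = 0; I != 8; ++I)
    R[I] = Idx[I] < 8 ? A[Idx[I]] : B[Idx[I] - 8];
  return R;
}

int main() {
  std::vector<std::int64_t> X0 = {100, 101, 102, 103, 104, 105, 106, 107};
  std::vector<std::int64_t> Zero(8, 0);

  // combine_vpermt2var_8i64_as_valignq_zero: idx = [15,0,1,2,3,4,5,6],
  // A = zeroinitializer, B = x0. Index 15 picks x0[7]; indices 0..6 pick zeros.
  std::vector<std::int64_t> Res = permt2var({15, 0, 1, 2, 3, 4, 5, 6}, Zero, X0);
  // The CHECK line "zmm0 = zmm0[7],zmm1[0,1,2,3,4,5,6]" with zmm1 zeroed:
  // lane 0 is x0[7], lanes 1..7 are zeros.
  std::vector<std::int64_t> Valign = {X0[7], 0, 0, 0, 0, 0, 0, 0};
  assert(Res == Valign);

  // combine_vpermt2var_8i64_as_zero_valignq: same index vector with the data
  // operands swapped; the result is [0, x0[0..6]], matching
  // "zmm0 = zmm1[7],zmm0[0,1,2,3,4,5,6]" with zmm1 zeroed.
  std::vector<std::int64_t> Res2 = permt2var({15, 0, 1, 2, 3, 4, 5, 6}, X0, Zero);
  std::vector<std::int64_t> Valign2 = {0,     X0[0], X0[1], X0[2],
                                       X0[3], X0[4], X0[5], X0[6]};
  assert(Res2 == Valign2);
  return 0;
}
```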