[X86] Fold VPERMV3(X,M,Y) -> VPERMV(CONCAT(X,Y),WIDEN(M)) iff the CONCAT is free #122485

Merged
merged 5 commits on Jan 13, 2025
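This fold converts a two-source lane-crossing shuffle (VPERMV3) into a single wider one-source shuffle (VPERMV) whenever combineConcatVectorOps can build the concatenation of the two sources for free, e.g. when they are the low and high halves of the same wider vector or of one contiguous load. A rough source-level sketch of the shape this targets (the function name, variable names, and intrinsics framing are my own illustration, not taken from the patch):

```cpp
#include <immintrin.h>

// Illustrative only: both vpermi2w sources are the halves of one 512-bit
// value in memory, so CONCAT(lo,hi) is free. With this fold the backend can
// emit a single 512-bit vpermw from memory instead of vmovdqa + vpermi2w.
__m256i permute_halves(const __m512i *src, __m256i idx) {
  __m256i lo = _mm512_castsi512_si256(*src);        // extract_subvector(x, 0)
  __m256i hi = _mm512_extracti64x4_epi64(*src, 1);  // extract_subvector(x, 16)
  // Index elements 0..15 select from lo, 16..31 select from hi.
  return _mm256_permutex2var_epi16(lo, idx, hi);
}
```

The updated AVX512 test checks below show exactly this rewrite.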
51 changes: 30 additions & 21 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41701,6 +41701,11 @@ static SDValue canonicalizeLaneShuffleWithRepeatedOps(SDValue V,
   return SDValue();
 }
 
+static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
+                                      ArrayRef<SDValue> Ops, SelectionDAG &DAG,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      const X86Subtarget &Subtarget);
+
 /// Try to combine x86 target specific shuffles.
 static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
                                     SelectionDAG &DAG,
@@ -42401,32 +42406,27 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
     return SDValue();
   }
   case X86ISD::VPERMV3: {
-    SDValue V1 = peekThroughBitcasts(N.getOperand(0));
-    SDValue V2 = peekThroughBitcasts(N.getOperand(2));
-    MVT SVT = V1.getSimpleValueType();
-    // Combine VPERMV3 to widened VPERMV if the two source operands are split
-    // from the same vector.
-    if (V1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        V1.getConstantOperandVal(1) == 0 &&
-        V2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        V2.getConstantOperandVal(1) == SVT.getVectorNumElements() &&
-        V1.getOperand(0) == V2.getOperand(0)) {
-      EVT NVT = V1.getOperand(0).getValueType();
-      if (NVT.is256BitVector() ||
-          (NVT.is512BitVector() && Subtarget.hasEVEX512())) {
-        MVT WideVT = MVT::getVectorVT(
-            VT.getScalarType(), NVT.getSizeInBits() / VT.getScalarSizeInBits());
+    // Combine VPERMV3 to widened VPERMV if the two source operands can be
+    // freely concatenated.
+    if (VT.is128BitVector() ||
+        (VT.is256BitVector() && Subtarget.useAVX512Regs())) {
+      SDValue Ops[] = {N.getOperand(0), N.getOperand(2)};
+      MVT WideVT = VT.getDoubleNumVectorElementsVT();
+      if (SDValue ConcatSrc =
+              combineConcatVectorOps(DL, WideVT, Ops, DAG, DCI, Subtarget)) {
         SDValue Mask = widenSubVector(N.getOperand(1), false, Subtarget, DAG,
                                       DL, WideVT.getSizeInBits());
-        SDValue Perm = DAG.getNode(X86ISD::VPERMV, DL, WideVT, Mask,
-                                   DAG.getBitcast(WideVT, V1.getOperand(0)));
+        SDValue Perm = DAG.getNode(X86ISD::VPERMV, DL, WideVT, Mask, ConcatSrc);
         return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Perm,
                            DAG.getIntPtrConstant(0, DL));
       }
     }
     SmallVector<SDValue, 2> Ops;
     SmallVector<int, 32> Mask;
     if (getTargetShuffleMask(N, /*AllowSentinelZero=*/false, Ops, Mask)) {
+      assert(Mask.size() == NumElts && "Unexpected shuffle mask size");
+      SDValue V1 = peekThroughBitcasts(N.getOperand(0));
+      SDValue V2 = peekThroughBitcasts(N.getOperand(2));
       MVT MaskVT = N.getOperand(1).getSimpleValueType();
       // Canonicalize to VPERMV if both sources are the same.
       if (V1 == V2) {
@@ -57369,10 +57369,8 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
                        Op0.getOperand(1));
   }
 
-  // concat(extract_subvector(v0,c0), extract_subvector(v1,c1)) -> vperm2x128.
-  // Only concat of subvector high halves which vperm2x128 is best at.
   // TODO: This should go in combineX86ShufflesRecursively eventually.
-  if (VT.is256BitVector() && NumOps == 2) {
+  if (NumOps == 2) {
     SDValue Src0 = peekThroughBitcasts(Ops[0]);
     SDValue Src1 = peekThroughBitcasts(Ops[1]);
     if (Src0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
@@ -57381,14 +57379,25 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT,
       EVT SrcVT1 = Src1.getOperand(0).getValueType();
       unsigned NumSrcElts0 = SrcVT0.getVectorNumElements();
       unsigned NumSrcElts1 = SrcVT1.getVectorNumElements();
-      if (SrcVT0.is256BitVector() && SrcVT1.is256BitVector() &&
+      // concat(extract_subvector(v0), extract_subvector(v1)) -> vperm2x128.
+      // Only concat of subvector high halves which vperm2x128 is best at.
+      if (VT.is256BitVector() && SrcVT0.is256BitVector() &&
+          SrcVT1.is256BitVector() &&
           Src0.getConstantOperandAPInt(1) == (NumSrcElts0 / 2) &&
           Src1.getConstantOperandAPInt(1) == (NumSrcElts1 / 2)) {
         return DAG.getNode(X86ISD::VPERM2X128, DL, VT,
                            DAG.getBitcast(VT, Src0.getOperand(0)),
                            DAG.getBitcast(VT, Src1.getOperand(0)),
                            DAG.getTargetConstant(0x31, DL, MVT::i8));
       }
+      // concat(extract_subvector(x,lo), extract_subvector(x,hi)) -> x.
+      if (Src0.getOperand(0) == Src1.getOperand(0) &&
+          Src0.getConstantOperandAPInt(1) == 0 &&
+          Src1.getConstantOperandAPInt(1) ==
+              Src0.getValueType().getVectorNumElements()) {
+        return DAG.getBitcast(VT, extractSubVector(Src0.getOperand(0), 0, DAG,
+                                                   DL, VT.getSizeInBits()));
+      }
     }
   }
 
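The last hunk above adds the concat(extract_subvector(x,lo), extract_subvector(x,hi)) -> x pattern to combineConcatVectorOps, which is what subsumes the old hand-rolled VPERMV3 check: reassembling the adjacent halves of one source is free. A minimal standalone sketch of that identity (my own illustration with AVX-512 intrinsics, not code from the patch):

```cpp
#include <immintrin.h>

// concat(extract_subvector(x,lo), extract_subvector(x,hi)) -> x:
// splitting a 512-bit value into halves and re-inserting them reproduces the
// original value, so the concatenation needs no instructions at all.
__m512i reassemble_halves(__m512i x) {
  __m256i lo = _mm512_castsi512_si256(x);        // low 256 bits
  __m256i hi = _mm512_extracti64x4_epi64(x, 1);  // high 256 bits
  __m512i wide = _mm512_castsi256_si512(lo);     // upper 256 bits undefined
  return _mm512_inserti64x4(wide, hi, 1);        // identical to x
}
```

The test diffs that follow show the payoff: a ymm load plus vpermi2w/vpermi2q collapses into one 512-bit vpermw/vpermq reading straight from memory.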
@@ -1337,10 +1337,9 @@ define void @vec256_i16_widen_to_i32_factor2_broadcast_to_v8i32_factor8(ptr %in.
 ;
 ; AVX512BW-LABEL: vec256_i16_widen_to_i32_factor2_broadcast_to_v8i32_factor8:
 ; AVX512BW:       # %bb.0:
-; AVX512BW-NEXT:    vmovdqa (%rdi), %ymm0
-; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} ymm1 = [0,17,0,19,0,21,0,23,0,25,0,27,0,29,0,31]
-; AVX512BW-NEXT:    vpermi2w 32(%rdi), %ymm0, %ymm1
-; AVX512BW-NEXT:    vpaddb (%rsi), %zmm1, %zmm0
+; AVX512BW-NEXT:    vpmovsxbw {{.*#+}} ymm0 = [0,17,0,19,0,21,0,23,0,25,0,27,0,29,0,31]
+; AVX512BW-NEXT:    vpermw (%rdi), %zmm0, %zmm0
+; AVX512BW-NEXT:    vpaddb (%rsi), %zmm0, %zmm0
 ; AVX512BW-NEXT:    vmovdqa64 %zmm0, (%rdx)
 ; AVX512BW-NEXT:    vzeroupper
 ; AVX512BW-NEXT:    retq
@@ -1789,10 +1788,9 @@ define void @vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2(ptr %i
 ;
 ; AVX512F-FAST-LABEL: vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2:
 ; AVX512F-FAST:       # %bb.0:
-; AVX512F-FAST-NEXT:    vmovdqa (%rdi), %ymm0
-; AVX512F-FAST-NEXT:    vpmovsxbq {{.*#+}} ymm1 = [0,5,0,7]
-; AVX512F-FAST-NEXT:    vpermi2q 32(%rdi), %ymm0, %ymm1
-; AVX512F-FAST-NEXT:    vpaddb (%rsi), %ymm1, %ymm0
+; AVX512F-FAST-NEXT:    vpmovsxbq {{.*#+}} ymm0 = [0,5,0,7]
+; AVX512F-FAST-NEXT:    vpermq (%rdi), %zmm0, %zmm0
+; AVX512F-FAST-NEXT:    vpaddb (%rsi), %ymm0, %ymm0
 ; AVX512F-FAST-NEXT:    vmovdqa %ymm0, (%rdx)
 ; AVX512F-FAST-NEXT:    vzeroupper
 ; AVX512F-FAST-NEXT:    retq
@@ -1808,10 +1806,9 @@ define void @vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2(ptr %i
 ;
 ; AVX512DQ-FAST-LABEL: vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2:
 ; AVX512DQ-FAST:       # %bb.0:
-; AVX512DQ-FAST-NEXT:    vmovdqa (%rdi), %ymm0
-; AVX512DQ-FAST-NEXT:    vpmovsxbq {{.*#+}} ymm1 = [0,5,0,7]
-; AVX512DQ-FAST-NEXT:    vpermi2q 32(%rdi), %ymm0, %ymm1
-; AVX512DQ-FAST-NEXT:    vpaddb (%rsi), %ymm1, %ymm0
+; AVX512DQ-FAST-NEXT:    vpmovsxbq {{.*#+}} ymm0 = [0,5,0,7]
+; AVX512DQ-FAST-NEXT:    vpermq (%rdi), %zmm0, %zmm0
+; AVX512DQ-FAST-NEXT:    vpaddb (%rsi), %ymm0, %ymm0
 ; AVX512DQ-FAST-NEXT:    vmovdqa %ymm0, (%rdx)
 ; AVX512DQ-FAST-NEXT:    vzeroupper
 ; AVX512DQ-FAST-NEXT:    retq
Expand All @@ -1827,10 +1824,9 @@ define void @vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2(ptr %i
;
; AVX512BW-FAST-LABEL: vec256_i64_widen_to_i128_factor2_broadcast_to_v2i128_factor2:
; AVX512BW-FAST: # %bb.0:
; AVX512BW-FAST-NEXT: vmovdqa (%rdi), %ymm0
; AVX512BW-FAST-NEXT: vpmovsxbq {{.*#+}} ymm1 = [0,5,0,7]
; AVX512BW-FAST-NEXT: vpermi2q 32(%rdi), %ymm0, %ymm1
; AVX512BW-FAST-NEXT: vpaddb (%rsi), %zmm1, %zmm0
; AVX512BW-FAST-NEXT: vpmovsxbq {{.*#+}} ymm0 = [0,5,0,7]
; AVX512BW-FAST-NEXT: vpermq (%rdi), %zmm0, %zmm0
; AVX512BW-FAST-NEXT: vpaddb (%rsi), %zmm0, %zmm0
; AVX512BW-FAST-NEXT: vmovdqa64 %zmm0, (%rdx)
; AVX512BW-FAST-NEXT: vzeroupper
; AVX512BW-FAST-NEXT: retq