Skip to content

[VectorCombine] Allow shuffling between vectors the same type but different element sizes #121216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3067,42 +3067,73 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
m_ConstantInt(InsIdx))))
return false;

auto *VecTy = dyn_cast<FixedVectorType>(I.getType());
if (!VecTy || SrcVec->getType() != VecTy)
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
// We can try combining vectors with different element sizes.
if (!DstVecTy || !SrcVecTy ||
SrcVecTy->getElementType() != DstVecTy->getElementType())
return false;

unsigned NumElts = VecTy->getNumElements();
if (ExtIdx >= NumElts || InsIdx >= NumElts)
unsigned NumDstElts = DstVecTy->getNumElements();
unsigned NumSrcElts = SrcVecTy->getNumElements();
if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
return false;

// Insertion into poison is a cheaper single operand shuffle.
TargetTransformInfo::ShuffleKind SK;
SmallVector<int> Mask(NumElts, PoisonMaskElem);
if (isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
SmallVector<int> Mask(NumDstElts, PoisonMaskElem);

bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
bool IsExtIdxInBounds = ExtIdx < NumDstElts;
bool NeedDstSrcSwap = isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec);
if (NeedDstSrcSwap) {
SK = TargetTransformInfo::SK_PermuteSingleSrc;
Mask[InsIdx] = ExtIdx;
if (!IsExtIdxInBounds && NeedExpOrNarrow)
Mask[InsIdx] = 0;
else
Mask[InsIdx] = ExtIdx;
std::swap(DstVec, SrcVec);
} else {
SK = TargetTransformInfo::SK_PermuteTwoSrc;
std::iota(Mask.begin(), Mask.end(), 0);
Mask[InsIdx] = ExtIdx + NumElts;
if (!IsExtIdxInBounds && NeedExpOrNarrow)
Mask[InsIdx] = NumDstElts;
else
Mask[InsIdx] = ExtIdx + NumDstElts;
}

// Cost
auto *Ins = cast<InsertElementInst>(&I);
auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
InstructionCost InsCost =
TTI.getVectorInstrCost(*Ins, VecTy, CostKind, InsIdx);
TTI.getVectorInstrCost(*Ins, DstVecTy, CostKind, InsIdx);
InstructionCost ExtCost =
TTI.getVectorInstrCost(*Ext, VecTy, CostKind, ExtIdx);
TTI.getVectorInstrCost(*Ext, DstVecTy, CostKind, ExtIdx);
InstructionCost OldCost = ExtCost + InsCost;

// Ignore 'free' identity insertion shuffle.
// TODO: getShuffleCost should return TCC_Free for Identity shuffles.
InstructionCost NewCost = 0;
if (!ShuffleVectorInst::isIdentityMask(Mask, NumElts))
NewCost += TTI.getShuffleCost(SK, VecTy, Mask, CostKind, 0, nullptr,
{DstVec, SrcVec});
SmallVector<int> ExtToVecMask;
if (!NeedExpOrNarrow) {
// Ignore 'free' identity insertion shuffle.
// TODO: getShuffleCost should return TCC_Free for Identity shuffles.
if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
NewCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind, 0, nullptr,
{DstVec, SrcVec});
} else {
// When creating length-changing-vector, always create with a Mask whose
// first element has an ExtIdx, so that the first element of the vector
// being created is always the target to be extracted.
ExtToVecMask.assign(NumDstElts, PoisonMaskElem);
if (IsExtIdxInBounds)
ExtToVecMask[ExtIdx] = ExtIdx;
else
ExtToVecMask[0] = ExtIdx;
// Add cost for expanding or narrowing
NewCost = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
DstVecTy, ExtToVecMask, CostKind);
NewCost += TTI.getShuffleCost(SK, DstVecTy, Mask, CostKind);
}

if (!Ext->hasOneUse())
NewCost += ExtCost;

Expand All @@ -3113,9 +3144,16 @@ bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
if (OldCost < NewCost)
return false;

if (NeedExpOrNarrow) {
if (!NeedDstSrcSwap)
SrcVec = Builder.CreateShuffleVector(SrcVec, ExtToVecMask);
else
DstVec = Builder.CreateShuffleVector(DstVec, ExtToVecMask);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't you already swapped the SrcVec/DstVec?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I wrote the code like this because the position of SrcVec changes depending on when it was swapped and when it was not.

}

// Canonicalize undef param to RHS to help further folds.
if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
ShuffleVectorInst::commuteShuffleMask(Mask, NumDstElts);
std::swap(DstVec, SrcVec);
}

Expand Down
196 changes: 196 additions & 0 deletions llvm/test/Transforms/VectorCombine/X86/extract-insert-poison.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- -mattr=SSE2 | FileCheck %s --check-prefixes=CHECK,SSE
; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64-- -mattr=AVX2 | FileCheck %s --check-prefixes=CHECK,AVX


define <4 x double> @src_ins0_v4f64_ext0_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; CHECK-LABEL: @src_ins0_v4f64_ext0_v2f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <2 x double> [[B:%.*]], i32 0
; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x double> poison, double [[EXT]], i32 0
; CHECK-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 0
%ins = insertelement <4 x double> poison, double %ext, i32 0
ret <4 x double> %ins
}

define <4 x double> @src_ins1_v4f64_ext0_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; CHECK-LABEL: @src_ins1_v4f64_ext0_v2f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <2 x double> [[B:%.*]], i32 0
; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x double> poison, double [[EXT]], i32 1
; CHECK-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 0
%ins = insertelement <4 x double> poison, double %ext, i32 1
ret <4 x double> %ins
}

define <4 x double> @src_ins2_v4f64_ext0_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; SSE-LABEL: @src_ins2_v4f64_ext0_v2f64(
; SSE-NEXT: [[EXT:%.*]] = extractelement <2 x double> [[B:%.*]], i32 0
; SSE-NEXT: [[INS:%.*]] = insertelement <4 x double> poison, double [[EXT]], i32 2
; SSE-NEXT: ret <4 x double> [[INS]]
;
; AVX-LABEL: @src_ins2_v4f64_ext0_v2f64(
; AVX-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>
; AVX-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 poison, i32 poison, i32 0, i32 poison>
; AVX-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 0
%ins = insertelement <4 x double> poison, double %ext, i32 2
ret <4 x double> %ins
}

define <4 x double> @src_ins3_v4f64_ext0_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; SSE-LABEL: @src_ins3_v4f64_ext0_v2f64(
; SSE-NEXT: [[EXT:%.*]] = extractelement <2 x double> [[B:%.*]], i32 0
; SSE-NEXT: [[INS:%.*]] = insertelement <4 x double> poison, double [[EXT]], i32 3
; SSE-NEXT: ret <4 x double> [[INS]]
;
; AVX-LABEL: @src_ins3_v4f64_ext0_v2f64(
; AVX-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 0, i32 poison, i32 poison, i32 poison>
; AVX-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 poison, i32 poison, i32 poison, i32 0>
; AVX-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 0
%ins = insertelement <4 x double> poison, double %ext, i32 3
ret <4 x double> %ins
}

define <4 x double> @src_ins0_v4f64_ext1_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; SSE-LABEL: @src_ins0_v4f64_ext1_v2f64(
; SSE-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 poison, i32 1, i32 poison, i32 poison>
; SSE-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
; SSE-NEXT: ret <4 x double> [[INS]]
;
; AVX-LABEL: @src_ins0_v4f64_ext1_v2f64(
; AVX-NEXT: [[EXT:%.*]] = extractelement <2 x double> [[B:%.*]], i32 1
; AVX-NEXT: [[INS:%.*]] = insertelement <4 x double> poison, double [[EXT]], i32 0
; AVX-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 1
%ins = insertelement <4 x double> poison, double %ext, i32 0
ret <4 x double> %ins
}

define <4 x double> @src_ins1_v4f64_ext1_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; CHECK-LABEL: @src_ins1_v4f64_ext1_v2f64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 poison, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 poison, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 1
%ins = insertelement <4 x double> poison, double %ext, i32 1
ret <4 x double> %ins
}

define <4 x double> @src_ins2_v4f64_ext1_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; CHECK-LABEL: @src_ins2_v4f64_ext1_v2f64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 poison, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 poison, i32 poison, i32 1, i32 poison>
; CHECK-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 1
%ins = insertelement <4 x double> poison, double %ext, i32 2
ret <4 x double> %ins
}

define <4 x double> @src_ins3_v4f64_ext1_v2f64(<4 x double> %a, <2 x double> %b) #0 {
; CHECK-LABEL: @src_ins3_v4f64_ext1_v2f64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x double> [[B:%.*]], <2 x double> poison, <4 x i32> <i32 poison, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[INS:%.*]] = shufflevector <4 x double> [[TMP1]], <4 x double> poison, <4 x i32> <i32 poison, i32 poison, i32 poison, i32 1>
; CHECK-NEXT: ret <4 x double> [[INS]]
;
%ext = extractelement <2 x double> %b, i32 1
%ins = insertelement <4 x double> poison, double %ext, i32 3
ret <4 x double> %ins
}

define <2 x double> @src_ins0_v2f64_ext0_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins0_v2f64_ext0_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 0
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 0
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 0
%ins = insertelement <2 x double> poison, double %ext, i32 0
ret <2 x double> %ins
}

define <2 x double> @src_ins0_v2f64_ext1_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins0_v2f64_ext1_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 1
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 0
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 1
%ins = insertelement <2 x double> poison, double %ext, i32 0
ret <2 x double> %ins
}

define <2 x double> @src_ins0_v2f64_ext2_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins0_v2f64_ext2_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 2
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 0
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 2
%ins = insertelement <2 x double> poison, double %ext, i32 0
ret <2 x double> %ins
}

define <2 x double> @src_ins0_v2f64_ext3_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins0_v2f64_ext3_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 3
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 0
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 3
%ins = insertelement <2 x double> poison, double %ext, i32 0
ret <2 x double> %ins
}

define <2 x double> @src_ins1_v2f64_ext0_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins1_v2f64_ext0_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 0
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 1
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 0
%ins = insertelement <2 x double> poison, double %ext, i32 1
ret <2 x double> %ins
}

define <2 x double> @src_ins1_v2f64_ext1_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins1_v2f64_ext1_v4f64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[B:%.*]], <4 x double> poison, <2 x i32> <i32 poison, i32 1>
; CHECK-NEXT: [[INS:%.*]] = shufflevector <2 x double> [[TMP1]], <2 x double> poison, <2 x i32> <i32 poison, i32 1>
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 1
%ins = insertelement <2 x double> poison, double %ext, i32 1
ret <2 x double> %ins
}

define <2 x double> @src_ins1_v2f64_ext2_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins1_v2f64_ext2_v4f64(
; CHECK-NEXT: [[EXT:%.*]] = extractelement <4 x double> [[B:%.*]], i32 2
; CHECK-NEXT: [[INS:%.*]] = insertelement <2 x double> poison, double [[EXT]], i32 1
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 2
%ins = insertelement <2 x double> poison, double %ext, i32 1
ret <2 x double> %ins
}

define <2 x double> @src_ins1_v2f64_ext3_v4f64(<2 x double> %a, <4 x double> %b) {
; CHECK-LABEL: @src_ins1_v2f64_ext3_v4f64(
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[B:%.*]], <4 x double> poison, <2 x i32> <i32 3, i32 poison>
; CHECK-NEXT: [[INS:%.*]] = shufflevector <2 x double> [[TMP1]], <2 x double> poison, <2 x i32> <i32 poison, i32 0>
; CHECK-NEXT: ret <2 x double> [[INS]]
;
%ext = extractelement <4 x double> %b, i32 3
%ins = insertelement <2 x double> poison, double %ext, i32 1
ret <2 x double> %ins
}

Loading