Skip to content

Commit fe2119a

Browse files
committed
[VectorCombine] foldBitcastShuffle - include the cost of bitcasts in the comparison
This makes no real difference currently as we only fold unary shuffles, but I'm hoping to handle binary shuffles in a future patch.
1 parent 6086937 commit fe2119a

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
684684
/// destination type followed by shuffle. This can enable further transforms by
685685
/// moving bitcasts or shuffles together.
686686
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
687-
Value *V;
687+
Value *V0;
688688
ArrayRef<int> Mask;
689-
if (!match(&I, m_BitCast(
690-
m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
689+
if (!match(&I, m_BitCast(m_OneUse(
690+
m_Shuffle(m_Value(V0), m_Undef(), m_Mask(Mask))))))
691691
return false;
692692

693693
// 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
@@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
696696
// 2) Disallow non-vector casts.
697697
// TODO: We could allow any shuffle.
698698
auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
699-
auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
699+
auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
700700
if (!DestTy || !SrcTy)
701701
return false;
702702

@@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
724724
// Bitcast the shuffle src - keep its original width but using the destination
725725
// scalar type.
726726
unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
727-
auto *ShuffleTy = FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
728-
729-
// The new shuffle must not cost more than the old shuffle. The bitcast is
730-
// moved ahead of the shuffle, so assume that it has the same cost as before.
731-
InstructionCost DestCost = TTI.getShuffleCost(
732-
TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
727+
auto *NewShuffleTy =
728+
FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
729+
auto *OldShuffleTy =
730+
FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
731+
732+
// The new shuffle must not cost more than the old shuffle.
733+
TargetTransformInfo::TargetCostKind CK =
734+
TargetTransformInfo::TCK_RecipThroughput;
735+
TargetTransformInfo::ShuffleKind SK =
736+
TargetTransformInfo::SK_PermuteSingleSrc;
737+
738+
InstructionCost DestCost =
739+
TTI.getShuffleCost(SK, NewShuffleTy, NewMask, CK) +
740+
TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
741+
TargetTransformInfo::CastContextHint::None, CK);
733742
InstructionCost SrcCost =
734-
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
743+
TTI.getShuffleCost(SK, SrcTy, Mask, CK) +
744+
TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
745+
TargetTransformInfo::CastContextHint::None, CK);
735746
if (DestCost > SrcCost || !DestCost.isValid())
736747
return false;
737748

738-
// bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
749+
// bitcast (shuf V0, MaskC) --> shuf (bitcast V0), MaskC'
739750
++NumShufOfBitcast;
740-
Value *CastV = Builder.CreateBitCast(V, ShuffleTy);
751+
Value *CastV = Builder.CreateBitCast(V0, NewShuffleTy);
741752
Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
742753
replaceValue(I, *Shuf);
743754
return true;

0 commit comments

Comments
 (0)