@@ -684,10 +684,10 @@ bool VectorCombine::foldInsExtFNeg(Instruction &I) {
684
684
// / destination type followed by shuffle. This can enable further transforms by
685
685
// / moving bitcasts or shuffles together.
686
686
bool VectorCombine::foldBitcastShuffle (Instruction &I) {
687
- Value *V ;
687
+ Value *V0 ;
688
688
ArrayRef<int > Mask;
689
- if (!match (&I, m_BitCast (
690
- m_OneUse ( m_Shuffle (m_Value (V ), m_Undef (), m_Mask (Mask))))))
689
+ if (!match (&I, m_BitCast (m_OneUse (
690
+ m_Shuffle (m_Value (V0 ), m_Undef (), m_Mask (Mask))))))
691
691
return false ;
692
692
693
693
// 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
@@ -696,7 +696,7 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
696
696
// 2) Disallow non-vector casts.
697
697
// TODO: We could allow any shuffle.
698
698
auto *DestTy = dyn_cast<FixedVectorType>(I.getType ());
699
- auto *SrcTy = dyn_cast<FixedVectorType>(V ->getType ());
699
+ auto *SrcTy = dyn_cast<FixedVectorType>(V0 ->getType ());
700
700
if (!DestTy || !SrcTy)
701
701
return false ;
702
702
@@ -724,20 +724,31 @@ bool VectorCombine::foldBitcastShuffle(Instruction &I) {
724
724
// Bitcast the shuffle src - keep its original width but using the destination
725
725
// scalar type.
726
726
unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits () / DestEltSize;
727
- auto *ShuffleTy = FixedVectorType::get (DestTy->getScalarType (), NumSrcElts);
728
-
729
- // The new shuffle must not cost more than the old shuffle. The bitcast is
730
- // moved ahead of the shuffle, so assume that it has the same cost as before.
731
- InstructionCost DestCost = TTI.getShuffleCost (
732
- TargetTransformInfo::SK_PermuteSingleSrc, ShuffleTy, NewMask);
727
+ auto *NewShuffleTy =
728
+ FixedVectorType::get (DestTy->getScalarType (), NumSrcElts);
729
+ auto *OldShuffleTy =
730
+ FixedVectorType::get (SrcTy->getScalarType (), Mask.size ());
731
+
732
+ // The new shuffle must not cost more than the old shuffle.
733
+ TargetTransformInfo::TargetCostKind CK =
734
+ TargetTransformInfo::TCK_RecipThroughput;
735
+ TargetTransformInfo::ShuffleKind SK =
736
+ TargetTransformInfo::SK_PermuteSingleSrc;
737
+
738
+ InstructionCost DestCost =
739
+ TTI.getShuffleCost (SK, NewShuffleTy, NewMask, CK) +
740
+ TTI.getCastInstrCost (Instruction::BitCast, NewShuffleTy, SrcTy,
741
+ TargetTransformInfo::CastContextHint::None, CK);
733
742
InstructionCost SrcCost =
734
- TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, SrcTy, Mask);
743
+ TTI.getShuffleCost (SK, SrcTy, Mask, CK) +
744
+ TTI.getCastInstrCost (Instruction::BitCast, DestTy, OldShuffleTy,
745
+ TargetTransformInfo::CastContextHint::None, CK);
735
746
if (DestCost > SrcCost || !DestCost.isValid ())
736
747
return false ;
737
748
738
- // bitcast (shuf V , MaskC) --> shuf (bitcast V ), MaskC'
749
+ // bitcast (shuf V0 , MaskC) --> shuf (bitcast V0 ), MaskC'
739
750
++NumShufOfBitcast;
740
- Value *CastV = Builder.CreateBitCast (V, ShuffleTy );
751
+ Value *CastV = Builder.CreateBitCast (V0, NewShuffleTy );
741
752
Value *Shuf = Builder.CreateShuffleVector (CastV, NewMask);
742
753
replaceValue (I, *Shuf);
743
754
return true ;
0 commit comments