@@ -1743,6 +1743,36 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1743
1743
TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, BinResTy,
1744
1744
OldMask, CostKind, 0 , nullptr , {LHS, RHS}, &I);
1745
1745
1746
+ // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
1747
+ // where one use shuffles have gotten split across the binop/cmp. These
1748
+ // often allow a major reduction in total cost that wouldn't happen as
1749
+ // individual folds.
1750
+ auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int > Mask,
1751
+ TTI::TargetCostKind CostKind) -> bool {
1752
+ Value *InnerOp;
1753
+ ArrayRef<int > InnerMask;
1754
+ if (match (Op, m_OneUse (m_Shuffle (m_Value (InnerOp), m_Undef (),
1755
+ m_Mask (InnerMask)))) &&
1756
+ InnerOp->getType () == Op->getType () &&
1757
+ all_of (InnerMask,
1758
+ [NumSrcElts](int M) { return M < (int )NumSrcElts; })) {
1759
+ for (int &M : Mask)
1760
+ if (Offset <= M && M < (int )(Offset + NumSrcElts)) {
1761
+ M = InnerMask[M - Offset];
1762
+ M = 0 <= M ? M + Offset : M;
1763
+ }
1764
+ OldCost += TTI.getInstructionCost (cast<Instruction>(Op), CostKind);
1765
+ Op = InnerOp;
1766
+ return true ;
1767
+ }
1768
+ return false ;
1769
+ };
1770
+ bool ReducedInstCount = false ;
1771
+ ReducedInstCount |= MergeInner (X, 0 , NewMask0, CostKind);
1772
+ ReducedInstCount |= MergeInner (Y, 0 , NewMask1, CostKind);
1773
+ ReducedInstCount |= MergeInner (Z, NumSrcElts, NewMask0, CostKind);
1774
+ ReducedInstCount |= MergeInner (W, NumSrcElts, NewMask1, CostKind);
1775
+
1746
1776
InstructionCost NewCost =
1747
1777
TTI.getShuffleCost (SK0, BinOpTy, NewMask0, CostKind, 0 , nullptr , {X, Z}) +
1748
1778
TTI.getShuffleCost (SK1, BinOpTy, NewMask1, CostKind, 0 , nullptr , {Y, W});
@@ -1763,8 +1793,8 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
1763
1793
1764
1794
// If either shuffle will constant fold away, then fold for the same cost as
1765
1795
// we will reduce the instruction count.
1766
- bool ReducedInstCount = (isa<Constant>(X) && isa<Constant>(Z)) ||
1767
- (isa<Constant>(Y) && isa<Constant>(W));
1796
+ ReducedInstCount | = (isa<Constant>(X) && isa<Constant>(Z)) ||
1797
+ (isa<Constant>(Y) && isa<Constant>(W));
1768
1798
if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
1769
1799
return false ;
1770
1800
0 commit comments