@@ -114,6 +114,7 @@ class VectorCombine {
114
114
bool foldShuffleOfBinops (Instruction &I);
115
115
bool foldShuffleOfCastops (Instruction &I);
116
116
bool foldShuffleOfShuffles (Instruction &I);
117
+ bool foldShuffleToIdentity (Instruction &I);
117
118
bool foldShuffleFromReductions (Instruction &I);
118
119
bool foldTruncFromReductions (Instruction &I);
119
120
bool foldSelectShuffle (Instruction &I, bool FromReduction = false );
@@ -1667,6 +1668,148 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
1667
1668
return true ;
1668
1669
}
1669
1670
1671
+ // Starting from a shuffle, look up through operands tracking the shuffled index
1672
+ // of each lane. If we can simplify away the shuffles to identities then
1673
+ // do so.
1674
+ bool VectorCombine::foldShuffleToIdentity (Instruction &I) {
1675
+ FixedVectorType *Ty = dyn_cast<FixedVectorType>(I.getType ());
1676
+ if (!Ty || !isa<Instruction>(I.getOperand (0 )) ||
1677
+ !isa<Instruction>(I.getOperand (1 )))
1678
+ return false ;
1679
+
1680
+ using InstLane = std::pair<Value *, int >;
1681
+
1682
+ auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
1683
+ while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
1684
+ unsigned NumElts =
1685
+ cast<FixedVectorType>(SV->getOperand (0 )->getType ())->getNumElements ();
1686
+ int M = SV->getMaskValue (Lane);
1687
+ if (M < 0 )
1688
+ return {nullptr , -1 };
1689
+ else if (M < (int )NumElts) {
1690
+ V = SV->getOperand (0 );
1691
+ Lane = M;
1692
+ } else {
1693
+ V = SV->getOperand (1 );
1694
+ Lane = M - NumElts;
1695
+ }
1696
+ }
1697
+ return InstLane{V, Lane};
1698
+ };
1699
+
1700
+ auto GenerateInstLaneVectorFromOperand =
1701
+ [&LookThroughShuffles](const SmallVector<InstLane> &Item, int Op) {
1702
+ SmallVector<InstLane> NItem;
1703
+ for (InstLane V : Item) {
1704
+ NItem.emplace_back (
1705
+ !V.first
1706
+ ? InstLane{nullptr , -1 }
1707
+ : LookThroughShuffles (
1708
+ cast<Instruction>(V.first )->getOperand (Op), V.second ));
1709
+ }
1710
+ return NItem;
1711
+ };
1712
+
1713
+ SmallVector<InstLane> Start;
1714
+ for (unsigned M = 0 ; M < Ty->getNumElements (); ++M)
1715
+ Start.push_back (LookThroughShuffles (&I, M));
1716
+
1717
+ SmallVector<SmallVector<InstLane>> Worklist;
1718
+ Worklist.push_back (Start);
1719
+ SmallPtrSet<Value *, 4 > IdentityLeafs, SplatLeafs;
1720
+ unsigned NumVisited = 0 ;
1721
+
1722
+ while (!Worklist.empty ()) {
1723
+ SmallVector<InstLane> Item = Worklist.pop_back_val ();
1724
+ if (++NumVisited > MaxInstrsToScan)
1725
+ return false ;
1726
+
1727
+ // If we found an undef first lane then bail out to keep things simple.
1728
+ if (!Item[0 ].first )
1729
+ return false ;
1730
+
1731
+ // Look for an identity value.
1732
+ if (Item[0 ].second == 0 && Item[0 ].first ->getType () == Ty &&
1733
+ all_of (drop_begin (enumerate(Item)), [&](const auto &E) {
1734
+ return !E.value ().first || (E.value ().first == Item[0 ].first &&
1735
+ E.value ().second == (int )E.index ());
1736
+ })) {
1737
+ IdentityLeafs.insert (Item[0 ].first );
1738
+ continue ;
1739
+ }
1740
+ // Look for a splat value.
1741
+ if (all_of (drop_begin (Item), [&](InstLane &IL) {
1742
+ return !IL.first ||
1743
+ (IL.first == Item[0 ].first && IL.second == Item[0 ].second );
1744
+ })) {
1745
+ SplatLeafs.insert (Item[0 ].first );
1746
+ continue ;
1747
+ }
1748
+
1749
+ // We need each element to be the same type of value, and check that each
1750
+ // element has a single use.
1751
+ if (!all_of (drop_begin (Item), [&](InstLane IL) {
1752
+ if (!IL.first )
1753
+ return true ;
1754
+ if (isa<Instruction>(IL.first ) &&
1755
+ !cast<Instruction>(IL.first )->hasOneUse ())
1756
+ return false ;
1757
+ return IL.first ->getValueID () == Item[0 ].first ->getValueID () &&
1758
+ (!isa<IntrinsicInst>(IL.first ) ||
1759
+ cast<IntrinsicInst>(IL.first )->getIntrinsicID () ==
1760
+ cast<IntrinsicInst>(Item[0 ].first )->getIntrinsicID ());
1761
+ }))
1762
+ return false ;
1763
+
1764
+ // Check the operator is one that we support.
1765
+ if (isa<BinaryOperator>(Item[0 ].first )) {
1766
+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 0 ));
1767
+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 1 ));
1768
+ } else if (isa<UnaryOperator>(Item[0 ].first )) {
1769
+ Worklist.push_back (GenerateInstLaneVectorFromOperand (Item, 0 ));
1770
+ } else {
1771
+ return false ;
1772
+ }
1773
+ }
1774
+
1775
+ // If we got this far, we know the shuffles are superfluous and can be
1776
+ // removed. Scan through again and generate the new tree of instructions.
1777
+ std::function<Value *(const SmallVector<InstLane> &)> generate =
1778
+ [&](const SmallVector<InstLane> &Item) -> Value * {
1779
+ if (IdentityLeafs.contains (Item[0 ].first ) &&
1780
+ all_of (drop_begin (enumerate(Item)), [&](const auto &E) {
1781
+ return !E.value ().first || (E.value ().first == Item[0 ].first &&
1782
+ E.value ().second == (int )E.index ());
1783
+ })) {
1784
+ return Item[0 ].first ;
1785
+ } else if (SplatLeafs.contains (Item[0 ].first )) {
1786
+ if (auto ILI = dyn_cast<Instruction>(Item[0 ].first ))
1787
+ Builder.SetInsertPoint (*ILI->getInsertionPointAfterDef ());
1788
+ else if (isa<Argument>(Item[0 ].first ))
1789
+ Builder.SetInsertPointPastAllocas (I.getParent ()->getParent ());
1790
+ SmallVector<int , 16 > Mask (Ty->getNumElements (), Item[0 ].second );
1791
+ return Builder.CreateShuffleVector (Item[0 ].first , Mask);
1792
+ }
1793
+
1794
+ auto *I = cast<Instruction>(Item[0 ].first );
1795
+ SmallVector<Value *> Ops;
1796
+ unsigned E = I->getNumOperands ();
1797
+ for (unsigned Idx = 0 ; Idx < E; Idx++)
1798
+ Ops.push_back (generate (GenerateInstLaneVectorFromOperand (Item, Idx)));
1799
+ Builder.SetInsertPoint (I);
1800
+ if (auto BI = dyn_cast<BinaryOperator>(I))
1801
+ return Builder.CreateBinOp ((Instruction::BinaryOps)BI->getOpcode (),
1802
+ Ops[0 ], Ops[1 ]);
1803
+ if (auto UI = dyn_cast<UnaryOperator>(I))
1804
+ return Builder.CreateUnOp ((Instruction::UnaryOps)UI->getOpcode (), Ops[0 ]);
1805
+ llvm_unreachable (" Unhandled instruction in generate" );
1806
+ };
1807
+
1808
+ Value *V = generate (Start);
1809
+ replaceValue (I, *V);
1810
+ return true ;
1811
+ }
1812
+
1670
1813
// / Given a commutative reduction, the order of the input lanes does not alter
1671
1814
// / the results. We can use this to remove certain shuffles feeding the
1672
1815
// / reduction, removing the need to shuffle at all.
@@ -2224,6 +2367,7 @@ bool VectorCombine::run() {
2224
2367
MadeChange |= foldShuffleOfCastops (I);
2225
2368
MadeChange |= foldShuffleOfShuffles (I);
2226
2369
MadeChange |= foldSelectShuffle (I);
2370
+ MadeChange |= foldShuffleToIdentity (I);
2227
2371
break ;
2228
2372
case Instruction::BitCast:
2229
2373
MadeChange |= foldBitcastShuffle (I);
0 commit comments