Skip to content

Commit efa8463

Browse files
authored
[VectorCombine] Add free concats to shuffleToIdentity. (#94954)
This is another relatively small adjustment to shuffleToIdentity, which has had a few knock-on effects that required a few more changes. It attempts to detect free concats that will be legalized to multiple vector operations. For example, if the lanes are '[a[0], a[1], b[0], b[1]]' and a and b are v2f64 under aarch64. In order to do this: - isFreeConcat detects whether the input has piece-wise identities from multiple inputs that can become a concat. - A tree of concat shuffles is created to concatenate the input values into a single vector. This is a little different to most other inputs, as they are created from multiple values that are being combined together, and we cannot rely on the Lane0 insert location always being valid. - The insert location is changed to the original location instead of being updated per item, which ensures it is valid due to the order in which we visit and create items.
1 parent 4c91b49 commit efa8463

File tree

4 files changed

+162
-206
lines changed

4 files changed

+162
-206
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 102 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,23 +1703,73 @@ generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
17031703
return NItem;
17041704
}
17051705

1706+
/// Detect concat of multiple values into a vector
1707+
static bool isFreeConcat(ArrayRef<InstLane> Item,
1708+
const TargetTransformInfo &TTI) {
1709+
auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
1710+
unsigned NumElts = Ty->getNumElements();
1711+
if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
1712+
return false;
1713+
1714+
// Check that the concat is free, usually meaning that the type will be split
1715+
// during legalization.
1716+
SmallVector<int, 16> ConcatMask(NumElts * 2);
1717+
std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1718+
if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, Ty, ConcatMask,
1719+
TTI::TCK_RecipThroughput) != 0)
1720+
return false;
1721+
1722+
unsigned NumSlices = Item.size() / NumElts;
1723+
// Currently we generate a tree of shuffles for the concats, which limits us
1724+
// to a power2.
1725+
if (!isPowerOf2_32(NumSlices))
1726+
return false;
1727+
for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
1728+
Use *SliceV = Item[Slice * NumElts].first;
1729+
if (!SliceV || SliceV->get()->getType() != Ty)
1730+
return false;
1731+
for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
1732+
auto [V, Lane] = Item[Slice * NumElts + Elt];
1733+
if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
1734+
return false;
1735+
}
1736+
}
1737+
return true;
1738+
}
1739+
17061740
static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
17071741
const SmallPtrSet<Use *, 4> &IdentityLeafs,
17081742
const SmallPtrSet<Use *, 4> &SplatLeafs,
1743+
const SmallPtrSet<Use *, 4> &ConcatLeafs,
17091744
IRBuilder<> &Builder) {
17101745
auto [FrontU, FrontLane] = Item.front();
17111746

17121747
if (IdentityLeafs.contains(FrontU)) {
17131748
return FrontU->get();
17141749
}
17151750
if (SplatLeafs.contains(FrontU)) {
1716-
if (auto *ILI = dyn_cast<Instruction>(FrontU))
1717-
Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
1718-
else if (auto *Arg = dyn_cast<Argument>(FrontU))
1719-
Builder.SetInsertPointPastAllocas(Arg->getParent());
17201751
SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
17211752
return Builder.CreateShuffleVector(FrontU->get(), Mask);
17221753
}
1754+
if (ConcatLeafs.contains(FrontU)) {
1755+
unsigned NumElts =
1756+
cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
1757+
SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
1758+
for (unsigned S = 0; S < Values.size(); ++S)
1759+
Values[S] = Item[S * NumElts].first->get();
1760+
1761+
while (Values.size() > 1) {
1762+
NumElts *= 2;
1763+
SmallVector<int, 16> Mask(NumElts, 0);
1764+
std::iota(Mask.begin(), Mask.end(), 0);
1765+
SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
1766+
for (unsigned S = 0; S < NewValues.size(); ++S)
1767+
NewValues[S] =
1768+
Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
1769+
Values = NewValues;
1770+
}
1771+
return Values[0];
1772+
}
17231773

17241774
auto *I = cast<Instruction>(FrontU->get());
17251775
auto *II = dyn_cast<IntrinsicInst>(I);
@@ -1730,16 +1780,16 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
17301780
Ops[Idx] = II->getOperand(Idx);
17311781
continue;
17321782
}
1733-
Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
1734-
Ty, IdentityLeafs, SplatLeafs, Builder);
1783+
Ops[Idx] =
1784+
generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx), Ty,
1785+
IdentityLeafs, SplatLeafs, ConcatLeafs, Builder);
17351786
}
17361787

17371788
SmallVector<Value *, 8> ValueList;
17381789
for (const auto &Lane : Item)
17391790
if (Lane.first)
17401791
ValueList.push_back(Lane.first->get());
17411792

1742-
Builder.SetInsertPoint(I);
17431793
Type *DstTy =
17441794
FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
17451795
if (auto *BI = dyn_cast<BinaryOperator>(I)) {
@@ -1790,7 +1840,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
17901840

17911841
SmallVector<SmallVector<InstLane>> Worklist;
17921842
Worklist.push_back(Start);
1793-
SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs;
1843+
SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
17941844
unsigned NumVisited = 0;
17951845

17961846
while (!Worklist.empty()) {
@@ -1839,7 +1889,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
18391889

18401890
// We need each element to be the same type of value, and check that each
18411891
// element has a single use.
1842-
if (!all_of(drop_begin(Item), [Item](InstLane IL) {
1892+
if (all_of(drop_begin(Item), [Item](InstLane IL) {
18431893
Value *FrontV = Item.front().first->get();
18441894
if (!IL.first)
18451895
return true;
@@ -1860,48 +1910,59 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
18601910
return !II || (isa<IntrinsicInst>(FrontV) &&
18611911
II->getIntrinsicID() ==
18621912
cast<IntrinsicInst>(FrontV)->getIntrinsicID());
1863-
}))
1864-
return false;
1865-
1866-
// Check the operator is one that we support. We exclude div/rem in case
1867-
// they hit UB from poison lanes.
1868-
if ((isa<BinaryOperator>(FrontU) &&
1869-
!cast<BinaryOperator>(FrontU)->isIntDivRem()) ||
1870-
isa<CmpInst>(FrontU)) {
1871-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1872-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
1873-
} else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
1874-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1875-
} else if (isa<SelectInst>(FrontU)) {
1876-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1877-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
1878-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
1879-
} else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
1880-
II && isTriviallyVectorizable(II->getIntrinsicID())) {
1881-
for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
1882-
if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
1883-
if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
1884-
Value *FrontV = Item.front().first->get();
1885-
Use *U = IL.first;
1886-
return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
1887-
cast<Instruction>(FrontV)->getOperand(Op));
1888-
}))
1889-
return false;
1890-
continue;
1913+
})) {
1914+
// Check the operator is one that we support.
1915+
if (isa<BinaryOperator, CmpInst>(FrontU)) {
1916+
// We exclude div/rem in case they hit UB from poison lanes.
1917+
if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
1918+
BO && BO->isIntDivRem())
1919+
return false;
1920+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1921+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
1922+
continue;
1923+
} else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontU)) {
1924+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1925+
continue;
1926+
} else if (isa<SelectInst>(FrontU)) {
1927+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
1928+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
1929+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
1930+
continue;
1931+
} else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
1932+
II && isTriviallyVectorizable(II->getIntrinsicID())) {
1933+
for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
1934+
if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
1935+
if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
1936+
Value *FrontV = Item.front().first->get();
1937+
Use *U = IL.first;
1938+
return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
1939+
cast<Instruction>(FrontV)->getOperand(Op));
1940+
}))
1941+
return false;
1942+
continue;
1943+
}
1944+
Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
18911945
}
1892-
Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
1946+
continue;
18931947
}
1894-
} else {
1895-
return false;
18961948
}
1949+
1950+
if (isFreeConcat(Item, TTI)) {
1951+
ConcatLeafs.insert(FrontU);
1952+
continue;
1953+
}
1954+
1955+
return false;
18971956
}
18981957

18991958
if (NumVisited <= 1)
19001959
return false;
19011960

19021961
// If we got this far, we know the shuffles are superfluous and can be
19031962
// removed. Scan through again and generate the new tree of instructions.
1904-
Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
1963+
Builder.SetInsertPoint(&I);
1964+
Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
1965+
ConcatLeafs, Builder);
19051966
replaceValue(I, *V);
19061967
return true;
19071968
}

llvm/test/Transforms/PhaseOrdering/AArch64/interleavevectorization.ll

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ define void @add4(ptr noalias noundef %x, ptr noalias noundef %y, i32 noundef %n
2222
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <32 x i16>, ptr [[TMP0]], align 2
2323
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
2424
; CHECK-NEXT: [[WIDE_VEC24:%.*]] = load <32 x i16>, ptr [[TMP1]], align 2
25-
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
2625
; CHECK-NEXT: [[TMP2:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
2726
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP2]]
27+
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[WIDE_VEC24]], [[WIDE_VEC]]
2828
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
2929
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
3030
; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
@@ -403,12 +403,12 @@ define void @addmul(ptr noalias noundef %x, ptr noundef %y, ptr noundef %z, i32
403403
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <32 x i16>, ptr [[TMP0]], align 2
404404
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i16, ptr [[Z:%.*]], i64 [[OFFSET_IDX]]
405405
; CHECK-NEXT: [[WIDE_VEC31:%.*]] = load <32 x i16>, ptr [[TMP1]], align 2
406-
; CHECK-NEXT: [[TMP2:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
407-
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
408-
; CHECK-NEXT: [[WIDE_VEC36:%.*]] = load <32 x i16>, ptr [[TMP3]], align 2
409-
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[TMP2]], [[WIDE_VEC36]]
410-
; CHECK-NEXT: [[TMP4:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
411-
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP4]]
406+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i16, ptr [[X]], i64 [[OFFSET_IDX]]
407+
; CHECK-NEXT: [[WIDE_VEC36:%.*]] = load <32 x i16>, ptr [[TMP2]], align 2
408+
; CHECK-NEXT: [[TMP3:%.*]] = or disjoint i64 [[OFFSET_IDX]], 3
409+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[INVARIANT_GEP]], i64 [[TMP3]]
410+
; CHECK-NEXT: [[TMP4:%.*]] = mul <32 x i16> [[WIDE_VEC31]], [[WIDE_VEC]]
411+
; CHECK-NEXT: [[INTERLEAVED_VEC:%.*]] = add <32 x i16> [[TMP4]], [[WIDE_VEC36]]
412412
; CHECK-NEXT: store <32 x i16> [[INTERLEAVED_VEC]], ptr [[GEP]], align 2
413413
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
414414
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256

0 commit comments

Comments
 (0)