Skip to content

Commit aba407e

Browse files
committed
[LoadStoreVectorizer] Postprocess and merge equivalence classes (llvm#114501)
This patch introduces a new method: void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const The method is called at the end of Vectorizer::collectEquivalenceClasses() and is needed to merge equivalence classes that differ only by their underlying objects (UO1 and UO2), where UO1 is 1-level-indirection underlying base for UO2. This situation arises due to the limited lookup depth used during the search of underlying bases with llvm::getUnderlyingObject(ptr). Using any fixed lookup depth can result into creation of multiple equivalence classes that only differ by 1-level indirection bases. The new approach merges equivalence classes if they have adjacent bases (1-level indirection). If a series of equivalence classes form ladder formed of 1-step/level indirections, they are all merged into a single equivalence class. This provides more opportunities for the load-store vectorizer to generate better vectors. --------- Signed-off-by: Klochkov, Vyacheslav N <[email protected]>
1 parent f6365a4 commit aba407e

File tree

2 files changed

+265
-0
lines changed

2 files changed

+265
-0
lines changed

llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ class Vectorizer {
324324
Instruction *ChainElem, Instruction *ChainBegin,
325325
const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets);
326326

327+
/// Merges the equivalence classes if they have underlying objects that differ
328+
/// by one level of indirection (i.e., one is a getelementptr and the other is
329+
/// the base pointer in that getelementptr).
330+
void mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const;
331+
327332
/// Collects loads and stores grouped by "equivalence class", where:
328333
/// - all elements in an eq class are a load or all are a store,
329334
/// - they all load/store the same element size (it's OK to have e.g. i8 and
@@ -1305,6 +1310,123 @@ std::optional<APInt> Vectorizer::getConstantOffsetSelects(
13051310
return std::nullopt;
13061311
}
13071312

1313+
void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
1314+
if (EQClasses.size() < 2) // There is nothing to merge.
1315+
return;
1316+
1317+
// The reduced key has all elements of the ECClassKey except the underlying
1318+
// object. Check that EqClassKey has 4 elements and define the reduced key.
1319+
static_assert(std::tuple_size_v<EqClassKey> == 4,
1320+
"EqClassKey has changed - EqClassReducedKey needs changes too");
1321+
using EqClassReducedKey =
1322+
std::tuple<std::tuple_element_t<1, EqClassKey> /* AddrSpace */,
1323+
std::tuple_element_t<2, EqClassKey> /* Element size */,
1324+
std::tuple_element_t<3, EqClassKey> /* IsLoad; */>;
1325+
using ECReducedKeyToUnderlyingObjectMap =
1326+
MapVector<EqClassReducedKey,
1327+
SmallPtrSet<std::tuple_element_t<0, EqClassKey>, 4>>;
1328+
1329+
// Form a map from the reduced key (without the underlying object) to the
1330+
// underlying objects: 1 reduced key to many underlying objects, to form
1331+
// groups of potentially merge-able equivalence classes.
1332+
ECReducedKeyToUnderlyingObjectMap RedKeyToUOMap;
1333+
bool FoundPotentiallyOptimizableEC = false;
1334+
for (const auto &EC : EQClasses) {
1335+
const auto &Key = EC.first;
1336+
EqClassReducedKey RedKey{std::get<1>(Key), std::get<2>(Key),
1337+
std::get<3>(Key)};
1338+
RedKeyToUOMap[RedKey].insert(std::get<0>(Key));
1339+
if (RedKeyToUOMap[RedKey].size() > 1)
1340+
FoundPotentiallyOptimizableEC = true;
1341+
}
1342+
if (!FoundPotentiallyOptimizableEC)
1343+
return;
1344+
1345+
LLVM_DEBUG({
1346+
dbgs() << "LSV: mergeEquivalenceClasses: before merging:\n";
1347+
for (const auto &EC : EQClasses) {
1348+
dbgs() << " Key: ([" << std::get<0>(EC.first)
1349+
<< "]: " << *std::get<0>(EC.first) << ", " << std::get<1>(EC.first)
1350+
<< ", " << std::get<2>(EC.first) << ", "
1351+
<< static_cast<int>(std::get<3>(EC.first)) << ")\n";
1352+
for (const auto &Inst : EC.second)
1353+
dbgs() << "\tInst: " << *Inst << '\n';
1354+
}
1355+
});
1356+
LLVM_DEBUG({
1357+
dbgs() << "LSV: mergeEquivalenceClasses: RedKeyToUOMap:\n";
1358+
for (const auto &RedKeyToUO : RedKeyToUOMap) {
1359+
dbgs() << " Reduced key: (" << std::get<0>(RedKeyToUO.first) << ", "
1360+
<< std::get<1>(RedKeyToUO.first) << ", "
1361+
<< static_cast<int>(std::get<2>(RedKeyToUO.first)) << ") --> "
1362+
<< RedKeyToUO.second.size() << " underlying objects:\n";
1363+
for (auto UObject : RedKeyToUO.second)
1364+
dbgs() << " [" << UObject << "]: " << *UObject << '\n';
1365+
}
1366+
});
1367+
1368+
using UObjectToUObjectMap = DenseMap<const Value *, const Value *>;
1369+
1370+
// Compute the ultimate targets for a set of underlying objects.
1371+
auto GetUltimateTargets =
1372+
[](SmallPtrSetImpl<const Value *> &UObjects) -> UObjectToUObjectMap {
1373+
UObjectToUObjectMap IndirectionMap;
1374+
for (const auto *UObject : UObjects) {
1375+
const unsigned MaxLookupDepth = 1; // look for 1-level indirections only
1376+
const auto *UltimateTarget = getUnderlyingObject(UObject, MaxLookupDepth);
1377+
if (UltimateTarget != UObject)
1378+
IndirectionMap[UObject] = UltimateTarget;
1379+
}
1380+
UObjectToUObjectMap UltimateTargetsMap;
1381+
for (const auto *UObject : UObjects) {
1382+
auto Target = UObject;
1383+
auto It = IndirectionMap.find(Target);
1384+
for (; It != IndirectionMap.end(); It = IndirectionMap.find(Target))
1385+
Target = It->second;
1386+
UltimateTargetsMap[UObject] = Target;
1387+
}
1388+
return UltimateTargetsMap;
1389+
};
1390+
1391+
// For each item in RedKeyToUOMap, if it has more than one underlying object,
1392+
// try to merge the equivalence classes.
1393+
for (auto &[RedKey, UObjects] : RedKeyToUOMap) {
1394+
if (UObjects.size() < 2)
1395+
continue;
1396+
auto UTMap = GetUltimateTargets(UObjects);
1397+
for (const auto &[UObject, UltimateTarget] : UTMap) {
1398+
if (UObject == UltimateTarget)
1399+
continue;
1400+
1401+
EqClassKey KeyFrom{UObject, std::get<0>(RedKey), std::get<1>(RedKey),
1402+
std::get<2>(RedKey)};
1403+
EqClassKey KeyTo{UltimateTarget, std::get<0>(RedKey), std::get<1>(RedKey),
1404+
std::get<2>(RedKey)};
1405+
const auto &VecFrom = EQClasses[KeyFrom];
1406+
const auto &VecTo = EQClasses[KeyTo];
1407+
SmallVector<Instruction *, 8> MergedVec;
1408+
std::merge(VecFrom.begin(), VecFrom.end(), VecTo.begin(), VecTo.end(),
1409+
std::back_inserter(MergedVec),
1410+
[](Instruction *A, Instruction *B) {
1411+
return A && B && A->comesBefore(B);
1412+
});
1413+
EQClasses[KeyTo] = std::move(MergedVec);
1414+
EQClasses.erase(KeyFrom);
1415+
}
1416+
}
1417+
LLVM_DEBUG({
1418+
dbgs() << "LSV: mergeEquivalenceClasses: after merging:\n";
1419+
for (const auto &EC : EQClasses) {
1420+
dbgs() << " Key: ([" << std::get<0>(EC.first)
1421+
<< "]: " << *std::get<0>(EC.first) << ", " << std::get<1>(EC.first)
1422+
<< ", " << std::get<2>(EC.first) << ", "
1423+
<< static_cast<int>(std::get<3>(EC.first)) << ")\n";
1424+
for (const auto &Inst : EC.second)
1425+
dbgs() << "\tInst: " << *Inst << '\n';
1426+
}
1427+
});
1428+
}
1429+
13081430
EquivalenceClassMap
13091431
Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
13101432
BasicBlock::iterator End) {
@@ -1377,6 +1499,7 @@ Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
13771499
.emplace_back(&I);
13781500
}
13791501

1502+
mergeEquivalenceClasses(Ret);
13801503
return Ret;
13811504
}
13821505

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt %s -mtriple=x86_64-unknown-linux-gnu -passes=load-store-vectorizer -mcpu=skx -S -o - | FileCheck %s
3+
4+
; This test verifies that the vectorizer can handle an extended sequence of
5+
; getelementptr instructions and generate longer vectors. With special handling,
6+
; some elements can still be vectorized even if they require looking up the
7+
; common underlying object deeper than 6 levels from the original pointer.
8+
9+
; The test below is the simplified version of actual performance oriented
10+
; workload; the offsets in getelementptr instructions are similar or same for
11+
; the test simplicity.
12+
13+
define void @v1_v2_v4_v1_to_v8_levels_6_7_8_8(i32 %arg0, ptr align 16 %arg1) {
14+
; CHECK-LABEL: define void @v1_v2_v4_v1_to_v8_levels_6_7_8_8(
15+
; CHECK-SAME: i32 [[ARG0:%.*]], ptr align 16 [[ARG1:%.*]]) #[[ATTR0:[0-9]+]] {
16+
; CHECK-NEXT: [[LEVEL1:%.*]] = getelementptr i8, ptr [[ARG1]], i32 917504
17+
; CHECK-NEXT: [[LEVEL2:%.*]] = getelementptr i8, ptr [[LEVEL1]], i32 [[ARG0]]
18+
; CHECK-NEXT: [[LEVEL3:%.*]] = getelementptr i8, ptr [[LEVEL2]], i32 32768
19+
; CHECK-NEXT: [[LEVEL4:%.*]] = getelementptr i8, ptr [[LEVEL3]], i32 [[ARG0]]
20+
; CHECK-NEXT: [[LEVEL5:%.*]] = getelementptr i8, ptr [[LEVEL4]], i32 [[ARG0]]
21+
; CHECK-NEXT: [[A6:%.*]] = getelementptr i8, ptr [[LEVEL5]], i32 [[ARG0]]
22+
; CHECK-NEXT: store <8 x half> zeroinitializer, ptr [[A6]], align 16
23+
; CHECK-NEXT: ret void
24+
;
25+
26+
%level1 = getelementptr i8, ptr %arg1, i32 917504
27+
%level2 = getelementptr i8, ptr %level1, i32 %arg0
28+
%level3 = getelementptr i8, ptr %level2, i32 32768
29+
%level4 = getelementptr i8, ptr %level3, i32 %arg0
30+
%level5 = getelementptr i8, ptr %level4, i32 %arg0
31+
32+
%a6 = getelementptr i8, ptr %level5, i32 %arg0
33+
%b7 = getelementptr i8, ptr %a6, i32 2
34+
%c8 = getelementptr i8, ptr %b7, i32 8
35+
%d8 = getelementptr i8, ptr %b7, i32 12
36+
37+
store half 0xH0000, ptr %a6, align 16
38+
store <4 x half> zeroinitializer, ptr %b7, align 2
39+
store <2 x half> zeroinitializer, ptr %c8, align 2
40+
store half 0xH0000, ptr %d8, align 2
41+
ret void
42+
}
43+
44+
define void @v1x8_levels_6_7_8_9_10_11_12_13(i32 %arg0, ptr align 16 %arg1) {
45+
; CHECK-LABEL: define void @v1x8_levels_6_7_8_9_10_11_12_13(
46+
; CHECK-SAME: i32 [[ARG0:%.*]], ptr align 16 [[ARG1:%.*]]) #[[ATTR0]] {
47+
; CHECK-NEXT: [[LEVEL1:%.*]] = getelementptr i8, ptr [[ARG1]], i32 917504
48+
; CHECK-NEXT: [[LEVEL2:%.*]] = getelementptr i8, ptr [[LEVEL1]], i32 [[ARG0]]
49+
; CHECK-NEXT: [[LEVEL3:%.*]] = getelementptr i8, ptr [[LEVEL2]], i32 32768
50+
; CHECK-NEXT: [[LEVEL4:%.*]] = getelementptr i8, ptr [[LEVEL3]], i32 [[ARG0]]
51+
; CHECK-NEXT: [[LEVEL5:%.*]] = getelementptr i8, ptr [[LEVEL4]], i32 [[ARG0]]
52+
; CHECK-NEXT: [[A6:%.*]] = getelementptr i8, ptr [[LEVEL5]], i32 [[ARG0]]
53+
; CHECK-NEXT: store <8 x half> zeroinitializer, ptr [[A6]], align 16
54+
; CHECK-NEXT: ret void
55+
;
56+
57+
%level1 = getelementptr i8, ptr %arg1, i32 917504
58+
%level2 = getelementptr i8, ptr %level1, i32 %arg0
59+
%level3 = getelementptr i8, ptr %level2, i32 32768
60+
%level4 = getelementptr i8, ptr %level3, i32 %arg0
61+
%level5 = getelementptr i8, ptr %level4, i32 %arg0
62+
63+
%a6 = getelementptr i8, ptr %level5, i32 %arg0
64+
%b7 = getelementptr i8, ptr %a6, i32 2
65+
%c8 = getelementptr i8, ptr %b7, i32 2
66+
%d9 = getelementptr i8, ptr %c8, i32 2
67+
%e10 = getelementptr i8, ptr %d9, i32 2
68+
%f11 = getelementptr i8, ptr %e10, i32 2
69+
%g12 = getelementptr i8, ptr %f11, i32 2
70+
%h13 = getelementptr i8, ptr %g12, i32 2
71+
72+
store half 0xH0000, ptr %a6, align 16
73+
store half 0xH0000, ptr %b7, align 2
74+
store half 0xH0000, ptr %c8, align 2
75+
store half 0xH0000, ptr %d9, align 2
76+
store half 0xH0000, ptr %e10, align 8
77+
store half 0xH0000, ptr %f11, align 2
78+
store half 0xH0000, ptr %g12, align 2
79+
store half 0xH0000, ptr %h13, align 2
80+
ret void
81+
}
82+
83+
define void @v1_4_4_4_2_1_to_v8_8_levels_6_7(i32 %arg0, ptr addrspace(3) align 16 %arg1_ptr, i32 %arg2, i32 %arg3, i32 %arg4, i32 %arg5, half %arg6_half, half %arg7_half, <2 x half> %arg8_2xhalf) {
84+
; CHECK-LABEL: define void @v1_4_4_4_2_1_to_v8_8_levels_6_7(
85+
; CHECK-SAME: i32 [[ARG0:%.*]], ptr addrspace(3) align 16 [[ARG1_PTR:%.*]], i32 [[ARG2:%.*]], i32 [[ARG3:%.*]], i32 [[ARG4:%.*]], i32 [[ARG5:%.*]], half [[ARG6_HALF:%.*]], half [[ARG7_HALF:%.*]], <2 x half> [[ARG8_2XHALF:%.*]]) #[[ATTR0]] {
86+
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[ARG1_PTR]], i32 458752
87+
; CHECK-NEXT: br [[DOTPREHEADER11_PREHEADER:label %.*]]
88+
; CHECK: [[_PREHEADER11_PREHEADER:.*:]]
89+
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw nsw i32 [[ARG0]], 6
90+
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP1]], i32 [[TMP2]]
91+
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP3]], i32 [[ARG2]]
92+
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP4]], i32 [[ARG3]]
93+
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[ARG0]], 2
94+
; CHECK-NEXT: br i1 [[CMP]], [[DOTLR_PH:label %.*]], [[DOTEXIT_POINT:label %.*]]
95+
; CHECK: [[_LR_PH:.*:]]
96+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[TMP5]], i32 [[ARG4]]
97+
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds i8, ptr addrspace(3) [[GEP]], i32 [[ARG5]]
98+
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <8 x half> poison, half [[ARG6_HALF]], i32 0
99+
; CHECK-NEXT: [[TMP8:%.*]] = insertelement <8 x half> [[TMP7]], half 0xH0000, i32 1
100+
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <8 x half> [[TMP8]], half 0xH0000, i32 2
101+
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <8 x half> [[TMP9]], half 0xH0000, i32 3
102+
; CHECK-NEXT: [[TMP11:%.*]] = insertelement <8 x half> [[TMP10]], half 0xH0000, i32 4
103+
; CHECK-NEXT: [[TMP12:%.*]] = extractelement <2 x half> [[ARG8_2XHALF]], i32 0
104+
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <8 x half> [[TMP11]], half [[TMP12]], i32 5
105+
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x half> [[ARG8_2XHALF]], i32 1
106+
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <8 x half> [[TMP13]], half [[TMP14]], i32 6
107+
; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x half> [[TMP15]], half [[ARG7_HALF]], i32 7
108+
; CHECK-NEXT: store <8 x half> [[TMP16]], ptr addrspace(3) [[TMP6]], align 2
109+
; CHECK-NEXT: br [[DOTEXIT_POINT]]
110+
; CHECK: [[_EXIT_POINT:.*:]]
111+
; CHECK-NEXT: ret void
112+
;
113+
%base1 = getelementptr inbounds i8, ptr addrspace(3) %arg1_ptr, i32 458752
114+
br label %.preheader11.preheader
115+
116+
.preheader11.preheader:
117+
%base2 = shl nuw nsw i32 %arg0, 6
118+
%base3 = getelementptr inbounds i8, ptr addrspace(3) %base1, i32 %base2
119+
120+
%base4 = getelementptr inbounds i8, ptr addrspace(3) %base3, i32 %arg2
121+
%base5 = getelementptr inbounds i8, ptr addrspace(3) %base4, i32 %arg3
122+
123+
%cmp = icmp sgt i32 %arg0, 2
124+
br i1 %cmp, label %.lr.ph, label %.exit_point
125+
126+
.lr.ph:
127+
%gep = getelementptr inbounds i8, ptr addrspace(3) %base5, i32 %arg4
128+
129+
%dst = getelementptr inbounds i8, ptr addrspace(3) %gep, i32 %arg5
130+
%dst_off2 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 2
131+
%dst_off10 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 10
132+
%dst_off14 = getelementptr inbounds i8, ptr addrspace(3) %dst, i32 14
133+
134+
store half %arg6_half, ptr addrspace(3) %dst, align 2
135+
store <4 x half> zeroinitializer, ptr addrspace(3) %dst_off2, align 2
136+
store <2 x half> %arg8_2xhalf, ptr addrspace(3) %dst_off10, align 2
137+
store half %arg7_half, ptr addrspace(3) %dst_off14, align 2
138+
br label %.exit_point
139+
140+
.exit_point:
141+
ret void
142+
}

0 commit comments

Comments
 (0)