Skip to content

Commit fbcc84a

Browse files
committed
[SandboxVec][BottomUpVec] Implement InstrMaps
InstrMaps is a helper data structure that maps scalars to vectors and the reverse. This is used by the vectorizer to figure out which vectors it can extract scalar values from.
1 parent 7c51c31 commit fbcc84a

File tree

11 files changed

+512
-104
lines changed

11 files changed

+512
-104
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===- InstructionMaps.h ----------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRUCTIONMAPS_H
10+
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRUCTIONMAPS_H
11+
12+
#include "llvm/ADT/ArrayRef.h"
13+
#include "llvm/ADT/DenseMap.h"
14+
#include "llvm/ADT/SmallSet.h"
15+
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/SandboxIR/Value.h"
17+
#include "llvm/Support/Casting.h"
18+
#include "llvm/Support/raw_ostream.h"
19+
20+
namespace llvm::sandboxir {
21+
22+
/// Maps the original instructions to the vectorized instrs and the reverse.
23+
/// For now an original instr can only map to a single vector.
24+
class InstrMaps {
25+
/// A map from the original values that got combined into vectors, to the
26+
/// vector value(s).
27+
DenseMap<Value *, Value *> OrigToVectorMap;
28+
/// A map from the vector value to a map of the original value to its lane.
29+
/// Please note that for constant vectors, there may multiple original values
30+
/// with the same lane, as they may be coming from vectorizing different
31+
/// original values.
32+
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
33+
34+
public:
35+
/// \Returns all the vector value that we got from vectorizing \p Orig, or
36+
/// nullptr if not found.
37+
Value *getVectorForOrig(Value *Orig) const {
38+
auto It = OrigToVectorMap.find(Orig);
39+
return It != OrigToVectorMap.end() ? It->second : nullptr;
40+
}
41+
/// \Returns the lane of \p Orig before it got vectorized into \p Vec, or
42+
/// nullopt if not found.
43+
std::optional<int> getOrigLane(Value *Vec, Value *Orig) const {
44+
auto It1 = VectorToOrigLaneMap.find(Vec);
45+
if (It1 == VectorToOrigLaneMap.end())
46+
return std::nullopt;
47+
const auto &OrigToLaneMap = It1->second;
48+
auto It2 = OrigToLaneMap.find(Orig);
49+
if (It2 == OrigToLaneMap.end())
50+
return std::nullopt;
51+
return It2->second;
52+
}
53+
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
54+
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
55+
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
56+
for (auto [Lane, Orig] : enumerate(Origs)) {
57+
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
58+
assert(Pair.second && "Orig already exists in the map!");
59+
OrigToLaneMap[Orig] = Lane;
60+
}
61+
}
62+
/// Clear all state.
63+
void clear();
64+
65+
#ifndef NDEBUG
66+
void print(raw_ostream &OS) const;
67+
LLVM_DUMP_METHOD void dump() const;
68+
#endif
69+
};
70+
} // namespace llvm::sandboxir
71+
72+
#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRUCTIONMAPS_H

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ namespace llvm::sandboxir {
2323

2424
class LegalityAnalysis;
2525
class Value;
26+
class InstrMaps;
2627

2728
enum class LegalityResultID {
28-
Pack, ///> Collect scalar values.
29-
Widen, ///> Vectorize by combining scalars to a vector.
29+
Pack, ///> Collect scalar values.
30+
Widen, ///> Vectorize by combining scalars to a vector.
31+
DiamondReuse, ///> Don't generate new code, reuse existing vector.
3032
};
3133

3234
/// The reason for vectorizing or not vectorizing.
@@ -50,6 +52,8 @@ struct ToStr {
5052
return "Pack";
5153
case LegalityResultID::Widen:
5254
return "Widen";
55+
case LegalityResultID::DiamondReuse:
56+
return "DiamondReuse";
5357
}
5458
llvm_unreachable("Unknown LegalityResultID enum");
5559
}
@@ -137,6 +141,19 @@ class Widen final : public LegalityResult {
137141
}
138142
};
139143

144+
class DiamondReuse final : public LegalityResult {
145+
friend class LegalityAnalysis;
146+
Value *Vec;
147+
DiamondReuse(Value *Vec)
148+
: LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
149+
150+
public:
151+
static bool classof(const LegalityResult *From) {
152+
return From->getSubclassID() == LegalityResultID::DiamondReuse;
153+
}
154+
Value *getVector() const { return Vec; }
155+
};
156+
140157
class Pack final : public LegalityResultWithReason {
141158
Pack(ResultReason Reason)
142159
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -148,6 +165,57 @@ class Pack final : public LegalityResultWithReason {
148165
}
149166
};
150167

168+
/// Describes how to collect the values needed by each lane.
169+
class CollectDescr {
170+
public:
171+
/// Describes how to get a value element. If the value is a vector then it
172+
/// also provides the index to extract it from.
173+
class ExtractElementDescr {
174+
Value *V;
175+
/// The index in `V` that the value can be extracted from.
176+
/// This is nullopt if we need to use `V` as a whole.
177+
std::optional<int> ExtractIdx;
178+
179+
public:
180+
ExtractElementDescr(Value *V, int ExtractIdx)
181+
: V(V), ExtractIdx(ExtractIdx) {}
182+
ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {}
183+
Value *getValue() const { return V; }
184+
bool needsExtract() const { return ExtractIdx.has_value(); }
185+
int getExtractIdx() const { return *ExtractIdx; }
186+
};
187+
188+
using DescrVecT = SmallVector<ExtractElementDescr, 4>;
189+
DescrVecT Descrs;
190+
191+
public:
192+
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
193+
: Descrs(std::move(Descrs)) {}
194+
std::optional<std::pair<Value *, bool>> getSingleInput() const {
195+
const auto &Descr0 = *Descrs.begin();
196+
Value *V0 = Descr0.getValue();
197+
if (!Descr0.needsExtract())
198+
return std::nullopt;
199+
bool NeedsShuffle = Descr0.getExtractIdx() != 0;
200+
int Lane = 1;
201+
for (const auto &Descr : drop_begin(Descrs)) {
202+
if (!Descr.needsExtract())
203+
return std::nullopt;
204+
if (Descr.getValue() != V0)
205+
return std::nullopt;
206+
if (Descr.getExtractIdx() != Lane++)
207+
NeedsShuffle = true;
208+
}
209+
return std::make_pair(V0, NeedsShuffle);
210+
}
211+
bool hasVectorInputs() const {
212+
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
213+
}
214+
const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
215+
return Descrs;
216+
}
217+
};
218+
151219
/// Performs the legality analysis and returns a LegalityResult object.
152220
class LegalityAnalysis {
153221
Scheduler Sched;
@@ -160,11 +228,17 @@ class LegalityAnalysis {
160228

161229
ScalarEvolution &SE;
162230
const DataLayout &DL;
231+
InstrMaps &IMaps;
232+
233+
/// Finds how we can collect the values in \p Bndl from the vectorized or
234+
/// non-vectorized code. It returns a map of the value we should extract from
235+
/// and the corresponding shuffle mask we need to use.
236+
CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
163237

164238
public:
165239
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
166-
Context &Ctx)
167-
: Sched(AA, Ctx), SE(SE), DL(DL) {}
240+
Context &Ctx, InstrMaps &IMaps)
241+
: Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
168242
/// A LegalityResult factory.
169243
template <typename ResultT, typename... ArgsT>
170244
ResultT &createLegalityResult(ArgsT... Args) {
@@ -177,7 +251,7 @@ class LegalityAnalysis {
177251
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
178252
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
179253
bool SkipScheduling = false);
180-
void clear() { Sched.clear(); }
254+
void clear();
181255
};
182256

183257
} // namespace llvm::sandboxir

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/SandboxIR/Pass.h"
1919
#include "llvm/SandboxIR/PassManager.h"
2020
#include "llvm/Support/raw_ostream.h"
21+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
2122
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
2223

2324
namespace llvm::sandboxir {
@@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass {
2627
bool Change = false;
2728
std::unique_ptr<LegalityAnalysis> Legality;
2829
DenseSet<Instruction *> DeadInstrCandidates;
30+
/// Maps scalars to vectors.
31+
InstrMaps IMaps;
2932

3033
/// Creates and returns a vector instruction that replaces the instructions in
3134
/// \p Bndl. \p Operands are the already vectorized operands.

llvm/lib/Transforms/Vectorize/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize
44
LoopVectorizationLegality.cpp
55
LoopVectorize.cpp
66
SandboxVectorizer/DependencyGraph.cpp
7+
SandboxVectorizer/InstrMaps.cpp
78
SandboxVectorizer/Interval.cpp
89
SandboxVectorizer/Legality.cpp
910
SandboxVectorizer/Passes/BottomUpVec.cpp
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- InstructionMaps.cpp - Maps scalars to vectors and reverse ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
10+
#include "llvm/Support/Debug.h"
11+
12+
namespace llvm::sandboxir {
13+
14+
void InstrMaps::clear() {
15+
OrigToVectorMap.clear();
16+
VectorToOrigLaneMap.clear();
17+
}
18+
19+
#ifndef NDEBUG
20+
void InstrMaps::print(raw_ostream &OS) const {
21+
for (auto &[Vec, Map] : VectorToOrigLaneMap) {
22+
OS << *Vec << "\n";
23+
SmallVector<std::pair<Value *, unsigned>> SortedOrigLanePairs;
24+
for (auto [Orig, Lane] : Map)
25+
SortedOrigLanePairs.push_back({Orig, Lane});
26+
sort(SortedOrigLanePairs, [](const auto &Pair1, const auto &Pair2) {
27+
int Lane1 = Pair1.second;
28+
int Lane2 = Pair2.second;
29+
return Lane1 < Lane2;
30+
});
31+
for (auto [Orig, Lane] : SortedOrigLanePairs)
32+
OS.indent(4) << "Lane " << Lane << " : " << *Orig << "\n";
33+
}
34+
}
35+
36+
void InstrMaps::dump() const {
37+
print(dbgs());
38+
dbgs() << "\n";
39+
}
40+
#endif // NDEBUG
41+
42+
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/SandboxIR/Utils.h"
1313
#include "llvm/SandboxIR/Value.h"
1414
#include "llvm/Support/Debug.h"
15+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
1516
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1617

1718
namespace llvm::sandboxir {
@@ -184,6 +185,22 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
184185
}
185186
#endif // NDEBUG
186187

188+
CollectDescr
189+
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
190+
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
191+
Vec.reserve(Bndl.size());
192+
for (auto [Lane, V] : enumerate(Bndl)) {
193+
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
194+
// If there is a vector containing `V`, then get the lane it came from.
195+
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
196+
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
197+
} else {
198+
Vec.emplace_back(V);
199+
}
200+
}
201+
return CollectDescr(std::move(Vec));
202+
}
203+
187204
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
188205
bool SkipScheduling) {
189206
// If Bndl contains values other than instructions, we need to Pack.
@@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
193210
return createLegalityResult<Pack>(ResultReason::NotInstructions);
194211
}
195212

213+
auto CollectDescrs = getHowToCollectValues(Bndl);
214+
if (CollectDescrs.hasVectorInputs()) {
215+
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
216+
auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
217+
if (!NeedsShuffle)
218+
return createLegalityResult<DiamondReuse>(Vec);
219+
llvm_unreachable("TODO: Unimplemented");
220+
} else {
221+
llvm_unreachable("TODO: Unimplemented");
222+
}
223+
}
224+
196225
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
197226
return createLegalityResult<Pack>(*ReasonOpt);
198227

199-
// TODO: Check for existing vectors containing values in Bndl.
200-
201228
if (!SkipScheduling) {
202229
// TODO: Try to remove the IBndl vector.
203230
SmallVector<Instruction *, 8> IBndl;
@@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
210237

211238
return createLegalityResult<Widen>();
212239
}
240+
241+
void LegalityAnalysis::clear() {
242+
Sched.clear();
243+
IMaps.clear();
244+
}
213245
} // namespace llvm::sandboxir

0 commit comments

Comments
 (0)