Skip to content

Commit e902c69

Browse files
authored
[SandboxVec][BottomUpVec] Implement InstrMaps (#122848)
InstrMaps is a helper data structure that maps scalars to vectors and the reverse. This is used by the vectorizer to figure out which vectors it can extract scalar values from.
1 parent 4e9f04c commit e902c69

File tree

11 files changed

+498
-104
lines changed

11 files changed

+498
-104
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===- InstrMaps.h ----------------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
10+
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H
11+
12+
#include "llvm/ADT/ArrayRef.h"
13+
#include "llvm/ADT/DenseMap.h"
14+
#include "llvm/ADT/SmallSet.h"
15+
#include "llvm/ADT/SmallVector.h"
16+
#include "llvm/SandboxIR/Value.h"
17+
#include "llvm/Support/Casting.h"
18+
#include "llvm/Support/raw_ostream.h"
19+
20+
namespace llvm::sandboxir {
21+
22+
/// Maps the original instructions to the vectorized instrs and the reverse.
23+
/// For now an original instr can only map to a single vector.
24+
class InstrMaps {
25+
/// A map from the original values that got combined into vectors, to the
26+
/// vector value(s).
27+
DenseMap<Value *, Value *> OrigToVectorMap;
28+
/// A map from the vector value to a map of the original value to its lane.
29+
/// Please note that for constant vectors, there may multiple original values
30+
/// with the same lane, as they may be coming from vectorizing different
31+
/// original values.
32+
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
33+
34+
public:
35+
/// \Returns the vector value that we got from vectorizing \p Orig, or
36+
/// nullptr if not found.
37+
Value *getVectorForOrig(Value *Orig) const {
38+
auto It = OrigToVectorMap.find(Orig);
39+
return It != OrigToVectorMap.end() ? It->second : nullptr;
40+
}
41+
/// \Returns the lane of \p Orig before it got vectorized into \p Vec, or
42+
/// nullopt if not found.
43+
std::optional<unsigned> getOrigLane(Value *Vec, Value *Orig) const {
44+
auto It1 = VectorToOrigLaneMap.find(Vec);
45+
if (It1 == VectorToOrigLaneMap.end())
46+
return std::nullopt;
47+
const auto &OrigToLaneMap = It1->second;
48+
auto It2 = OrigToLaneMap.find(Orig);
49+
if (It2 == OrigToLaneMap.end())
50+
return std::nullopt;
51+
return It2->second;
52+
}
53+
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
54+
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
55+
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
56+
for (auto [Lane, Orig] : enumerate(Origs)) {
57+
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
58+
assert(Pair.second && "Orig already exists in the map!");
59+
OrigToLaneMap[Orig] = Lane;
60+
}
61+
}
62+
void clear() {
63+
OrigToVectorMap.clear();
64+
VectorToOrigLaneMap.clear();
65+
}
66+
#ifndef NDEBUG
67+
void print(raw_ostream &OS) const {
68+
OS << "OrigToVectorMap:\n";
69+
for (auto [Orig, Vec] : OrigToVectorMap)
70+
OS << *Orig << " : " << *Vec << "\n";
71+
}
72+
LLVM_DUMP_METHOD void dump() const;
73+
#endif
74+
};
75+
} // namespace llvm::sandboxir
76+
77+
#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVEC_PASSES_INSTRMAPS_H

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ namespace llvm::sandboxir {
2323

2424
class LegalityAnalysis;
2525
class Value;
26+
class InstrMaps;
2627

2728
enum class LegalityResultID {
28-
Pack, ///> Collect scalar values.
29-
Widen, ///> Vectorize by combining scalars to a vector.
29+
Pack, ///> Collect scalar values.
30+
Widen, ///> Vectorize by combining scalars to a vector.
31+
DiamondReuse, ///> Don't generate new code, reuse existing vector.
3032
};
3133

3234
/// The reason for vectorizing or not vectorizing.
@@ -50,6 +52,8 @@ struct ToStr {
5052
return "Pack";
5153
case LegalityResultID::Widen:
5254
return "Widen";
55+
case LegalityResultID::DiamondReuse:
56+
return "DiamondReuse";
5357
}
5458
llvm_unreachable("Unknown LegalityResultID enum");
5559
}
@@ -137,6 +141,19 @@ class Widen final : public LegalityResult {
137141
}
138142
};
139143

144+
class DiamondReuse final : public LegalityResult {
145+
friend class LegalityAnalysis;
146+
Value *Vec;
147+
DiamondReuse(Value *Vec)
148+
: LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
149+
150+
public:
151+
static bool classof(const LegalityResult *From) {
152+
return From->getSubclassID() == LegalityResultID::DiamondReuse;
153+
}
154+
Value *getVector() const { return Vec; }
155+
};
156+
140157
class Pack final : public LegalityResultWithReason {
141158
Pack(ResultReason Reason)
142159
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -148,6 +165,59 @@ class Pack final : public LegalityResultWithReason {
148165
}
149166
};
150167

168+
/// Describes how to collect the values needed by each lane.
169+
class CollectDescr {
170+
public:
171+
/// Describes how to get a value element. If the value is a vector then it
172+
/// also provides the index to extract it from.
173+
class ExtractElementDescr {
174+
Value *V;
175+
/// The index in `V` that the value can be extracted from.
176+
/// This is nullopt if we need to use `V` as a whole.
177+
std::optional<int> ExtractIdx;
178+
179+
public:
180+
ExtractElementDescr(Value *V, int ExtractIdx)
181+
: V(V), ExtractIdx(ExtractIdx) {}
182+
ExtractElementDescr(Value *V) : V(V), ExtractIdx(std::nullopt) {}
183+
Value *getValue() const { return V; }
184+
bool needsExtract() const { return ExtractIdx.has_value(); }
185+
int getExtractIdx() const { return *ExtractIdx; }
186+
};
187+
188+
using DescrVecT = SmallVector<ExtractElementDescr, 4>;
189+
DescrVecT Descrs;
190+
191+
public:
192+
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
193+
: Descrs(std::move(Descrs)) {}
194+
/// If all elements come from a single vector input, then return that vector
195+
/// and whether we need a shuffle to get them in order.
196+
std::optional<std::pair<Value *, bool>> getSingleInput() const {
197+
const auto &Descr0 = *Descrs.begin();
198+
Value *V0 = Descr0.getValue();
199+
if (!Descr0.needsExtract())
200+
return std::nullopt;
201+
bool NeedsShuffle = Descr0.getExtractIdx() != 0;
202+
int Lane = 1;
203+
for (const auto &Descr : drop_begin(Descrs)) {
204+
if (!Descr.needsExtract())
205+
return std::nullopt;
206+
if (Descr.getValue() != V0)
207+
return std::nullopt;
208+
if (Descr.getExtractIdx() != Lane++)
209+
NeedsShuffle = true;
210+
}
211+
return std::make_pair(V0, NeedsShuffle);
212+
}
213+
bool hasVectorInputs() const {
214+
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
215+
}
216+
const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
217+
return Descrs;
218+
}
219+
};
220+
151221
/// Performs the legality analysis and returns a LegalityResult object.
152222
class LegalityAnalysis {
153223
Scheduler Sched;
@@ -160,11 +230,17 @@ class LegalityAnalysis {
160230

161231
ScalarEvolution &SE;
162232
const DataLayout &DL;
233+
InstrMaps &IMaps;
234+
235+
/// Finds how we can collect the values in \p Bndl from the vectorized or
236+
/// non-vectorized code. It returns a map of the value we should extract from
237+
/// and the corresponding shuffle mask we need to use.
238+
CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
163239

164240
public:
165241
LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
166-
Context &Ctx)
167-
: Sched(AA, Ctx), SE(SE), DL(DL) {}
242+
Context &Ctx, InstrMaps &IMaps)
243+
: Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
168244
/// A LegalityResult factory.
169245
template <typename ResultT, typename... ArgsT>
170246
ResultT &createLegalityResult(ArgsT... Args) {
@@ -177,7 +253,7 @@ class LegalityAnalysis {
177253
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
178254
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
179255
bool SkipScheduling = false);
180-
void clear() { Sched.clear(); }
256+
void clear();
181257
};
182258

183259
} // namespace llvm::sandboxir

llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "llvm/SandboxIR/Pass.h"
1919
#include "llvm/SandboxIR/PassManager.h"
2020
#include "llvm/Support/raw_ostream.h"
21+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
2122
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
2223

2324
namespace llvm::sandboxir {
@@ -26,6 +27,8 @@ class BottomUpVec final : public FunctionPass {
2627
bool Change = false;
2728
std::unique_ptr<LegalityAnalysis> Legality;
2829
DenseSet<Instruction *> DeadInstrCandidates;
30+
/// Maps scalars to vectors.
31+
InstrMaps IMaps;
2932

3033
/// Creates and returns a vector instruction that replaces the instructions in
3134
/// \p Bndl. \p Operands are the already vectorized operands.

llvm/lib/Transforms/Vectorize/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_llvm_component_library(LLVMVectorize
44
LoopVectorizationLegality.cpp
55
LoopVectorize.cpp
66
SandboxVectorizer/DependencyGraph.cpp
7+
SandboxVectorizer/InstrMaps.cpp
78
SandboxVectorizer/Interval.cpp
89
SandboxVectorizer/Legality.cpp
910
SandboxVectorizer/Passes/BottomUpVec.cpp
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- InstructionMaps.cpp - Maps scalars to vectors and reverse ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
10+
#include "llvm/Support/Debug.h"
11+
12+
namespace llvm::sandboxir {
13+
14+
#ifndef NDEBUG
15+
void InstrMaps::dump() const {
16+
print(dbgs());
17+
dbgs() << "\n";
18+
}
19+
#endif // NDEBUG
20+
21+
} // namespace llvm::sandboxir

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/SandboxIR/Utils.h"
1313
#include "llvm/SandboxIR/Value.h"
1414
#include "llvm/Support/Debug.h"
15+
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
1516
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
1617

1718
namespace llvm::sandboxir {
@@ -184,6 +185,22 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
184185
}
185186
#endif // NDEBUG
186187

188+
CollectDescr
189+
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
190+
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
191+
Vec.reserve(Bndl.size());
192+
for (auto [Lane, V] : enumerate(Bndl)) {
193+
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
194+
// If there is a vector containing `V`, then get the lane it came from.
195+
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
196+
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
197+
} else {
198+
Vec.emplace_back(V);
199+
}
200+
}
201+
return CollectDescr(std::move(Vec));
202+
}
203+
187204
const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
188205
bool SkipScheduling) {
189206
// If Bndl contains values other than instructions, we need to Pack.
@@ -193,11 +210,21 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
193210
return createLegalityResult<Pack>(ResultReason::NotInstructions);
194211
}
195212

213+
auto CollectDescrs = getHowToCollectValues(Bndl);
214+
if (CollectDescrs.hasVectorInputs()) {
215+
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
216+
auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
217+
if (!NeedsShuffle)
218+
return createLegalityResult<DiamondReuse>(Vec);
219+
llvm_unreachable("TODO: Unimplemented");
220+
} else {
221+
llvm_unreachable("TODO: Unimplemented");
222+
}
223+
}
224+
196225
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
197226
return createLegalityResult<Pack>(*ReasonOpt);
198227

199-
// TODO: Check for existing vectors containing values in Bndl.
200-
201228
if (!SkipScheduling) {
202229
// TODO: Try to remove the IBndl vector.
203230
SmallVector<Instruction *, 8> IBndl;
@@ -210,4 +237,9 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
210237

211238
return createLegalityResult<Widen>();
212239
}
240+
241+
void LegalityAnalysis::clear() {
242+
Sched.clear();
243+
IMaps.clear();
244+
}
213245
} // namespace llvm::sandboxir

0 commit comments

Comments
 (0)