-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SandboxVec][Legality] Implement ShuffleMask #123404
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This patch implements a helper ShuffleMask data structure that helps describe shuffles of elements across lanes.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers Author: vporpo (vporpo). Changes: This patch implements a helper ShuffleMask data structure that helps describe shuffles of elements across lanes. Full diff: https://github.com/llvm/llvm-project/pull/123404.diff 6 Files Affected:
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index c03e7a10397ad2..4858ebaf0770aa 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -25,10 +25,62 @@ class LegalityAnalysis;
class Value;
class InstrMaps;
+class ShuffleMask { // An ordered list of int lane indices describing a shuffle.
+public:
+ using IndicesVecT = SmallVector<int, 8>; // Backing storage type for the indices.
+
+private:
+ IndicesVecT Indices; // The mask's indices, in lane order.
+
+public:
+ ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {} // Takes over the contents of \p Indices.
+ ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {} // Allows ShuffleMask({0, 1, 2, 3}).
+ explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {} // Copies the indices out of the ArrayRef.
+ operator ArrayRef<int>() const { return Indices; } // Implicit conversion to an ArrayRef over the stored indices.
+ /// Creates and returns an identity shuffle mask of size \p Sz.
+ /// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
+ static ShuffleMask getIdentity(unsigned Sz) {
+ IndicesVecT Indices;
+ Indices.reserve(Sz);
+ for (auto Idx : seq<int>(0, (int)Sz)) // Idx runs over 0 .. Sz-1.
+ Indices.push_back(Idx);
+ return ShuffleMask(std::move(Indices));
+ }
+ /// \returns true if the mask is a perfect identity mask with consecutive
+ /// indices, i.e., performs no lane shuffling, like 0,1,2,3...
+ bool isIdentity() const { // An empty mask also yields true: the loop body never runs.
+ for (auto [Idx, Elm] : enumerate(Indices)) {
+ if ((int)Idx != Elm) // Position and stored index disagree.
+ return false;
+ }
+ return true;
+ }
+ bool operator==(const ShuffleMask &Other) const { // Equal when the index vectors compare equal.
+ return Indices == Other.Indices;
+ }
+ bool operator!=(const ShuffleMask &Other) const { return !(*this == Other); }
+ size_t size() const { return Indices.size(); } // Number of indices in the mask.
+ int operator[](int Idx) const { return Indices[Idx]; } // Index stored at position \p Idx; no bounds check is done here.
+ using const_iterator = IndicesVecT::const_iterator;
+ const_iterator begin() const { return Indices.begin(); } // Read-only iteration over the indices.
+ const_iterator end() const { return Indices.end(); }
+#ifndef NDEBUG
+ friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) { // Streams the mask by forwarding to print().
+ Mask.print(OS);
+ return OS;
+ }
+ void print(raw_ostream &OS) const { // Writes the indices separated by "," with no spaces, e.g. 0,1,2,3.
+ interleave(Indices, OS, [&OS](auto Elm) { OS << Elm; }, ",");
+ }
+ LLVM_DUMP_METHOD void dump() const; // Declared here; the definition is out of line.
+#endif
+};
+
enum class LegalityResultID {
- Pack, ///> Collect scalar values.
- Widen, ///> Vectorize by combining scalars to a vector.
- DiamondReuse, ///> Don't generate new code, reuse existing vector.
+ Pack, ///< Collect scalar values.
+ Widen, ///< Vectorize by combining scalars to a vector.
+ DiamondReuse, ///< Don't generate new code, reuse existing vector.
+ DiamondReuseWithShuffle, ///< Reuse the existing vector but add a shuffle.
};
/// The reason for vectorizing or not vectorizing.
@@ -54,6 +106,8 @@ struct ToStr {
return "Widen";
case LegalityResultID::DiamondReuse:
return "DiamondReuse";
+ case LegalityResultID::DiamondReuseWithShuffle:
+ return "DiamondReuseWithShuffle";
}
llvm_unreachable("Unknown LegalityResultID enum");
}
@@ -154,6 +208,22 @@ class DiamondReuse final : public LegalityResult {
Value *getVector() const { return Vec; }
};
+class DiamondReuseWithShuffle final : public LegalityResult { // Legality result that carries a vector value plus a ShuffleMask.
+ friend class LegalityAnalysis; // The constructor below is private; LegalityAnalysis is granted access.
+ Value *Vec; // The vector value returned by getVector().
+ ShuffleMask Mask; // Stored by value: the constructor copies its argument.
+ DiamondReuseWithShuffle(Value *Vec, const ShuffleMask &Mask)
+ : LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(Vec),
+ Mask(Mask) {}
+
+public:
+ static bool classof(const LegalityResult *From) { // True when \p From is tagged DiamondReuseWithShuffle.
+ return From->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
+ }
+ Value *getVector() const { return Vec; }
+ const ShuffleMask &getMask() const { return Mask; } // Reference to the mask stored in this object.
+};
+
class Pack final : public LegalityResultWithReason {
Pack(ResultReason Reason)
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -192,23 +262,22 @@ class CollectDescr {
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
: Descrs(std::move(Descrs)) {}
/// If all elements come from a single vector input, then return that vector
- /// and whether we need a shuffle to get them in order.
- std::optional<std::pair<Value *, bool>> getSingleInput() const {
+ /// and also the shuffle mask required to get them in order.
+ std::optional<std::pair<Value *, ShuffleMask>> getSingleInput() const {
const auto &Descr0 = *Descrs.begin(); // NOTE(review): begin() is dereferenced unchecked; presumably Descrs is never empty -- confirm.
Value *V0 = Descr0.getValue();
if (!Descr0.needsExtract())
return std::nullopt; // The first element is not extracted from a vector.
- bool NeedsShuffle = Descr0.getExtractIdx() != 0;
- int Lane = 1;
+ ShuffleMask::IndicesVecT MaskIndices;
+ MaskIndices.push_back(Descr0.getExtractIdx()); // Extract index of the first element.
for (const auto &Descr : drop_begin(Descrs)) {
if (!Descr.needsExtract())
return std::nullopt;
if (Descr.getValue() != V0)
return std::nullopt; // This element comes from a different value than the first one.
- if (Descr.getExtractIdx() != Lane++)
- NeedsShuffle = true;
+ MaskIndices.push_back(Descr.getExtractIdx()); // Extract indices are appended in element order.
}
- return std::make_pair(V0, NeedsShuffle);
+ return std::make_pair(V0, ShuffleMask(std::move(MaskIndices)));
}
bool hasVectorInputs() const {
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index dd3012f7c9b556..ac051c3b6570ff 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -36,6 +36,8 @@ class BottomUpVec final : public FunctionPass {
/// Erases all dead instructions from the dead instruction candidates
/// collected during vectorization.
void tryEraseDeadInstrs();
+ /// Creates a shuffle instruction that shuffles \p VecOp according to \p Mask.
+ Value *createShuffle(Value *VecOp, const ShuffleMask &Mask);
/// Packs all elements of \p ToPack into a vector and returns that vector.
Value *createPack(ArrayRef<Value *> ToPack);
void collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index f8149c5bc66363..ad3e38e2f1d923 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -20,6 +20,11 @@ namespace llvm::sandboxir {
#define DEBUG_TYPE "SBVec:Legality"
#ifndef NDEBUG
+void ShuffleMask::dump() const { // Built only when NDEBUG is not defined.
+ print(dbgs()); // Writes the mask to the dbgs() stream.
+ dbgs() << "\n"; // print() emits no newline, so add one here.
+}
+
void LegalityResult::dump() const {
print(dbgs());
dbgs() << "\n";
@@ -213,13 +218,12 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
auto CollectDescrs = getHowToCollectValues(Bndl);
if (CollectDescrs.hasVectorInputs()) {
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
- auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
- if (!NeedsShuffle)
+ auto [Vec, Mask] = *ValueShuffleOpt;
+ if (Mask.isIdentity())
return createLegalityResult<DiamondReuse>(Vec);
- llvm_unreachable("TODO: Unimplemented");
- } else {
- llvm_unreachable("TODO: Unimplemented");
+ return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
}
+ llvm_unreachable("TODO: Unimplemented");
}
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index b8e2697839a3c2..d62023ea018846 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -179,6 +179,12 @@ void BottomUpVec::tryEraseDeadInstrs() {
DeadInstrCandidates.clear();
}
+Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) { // Creates a shuffle of \p VecOp by \p Mask and returns it.
+ BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}); // Insertion point derived from VecOp alone.
+ return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt, // VecOp is passed as both shuffle operands.
+ VecOp->getContext(), "VShuf");
+}
+
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
@@ -295,6 +301,13 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
NewVec = cast<DiamondReuse>(LegalityRes).getVector();
break;
}
+ case LegalityResultID::DiamondReuseWithShuffle: {
+ auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
+ const ShuffleMask &Mask =
+ cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
+ NewVec = createShuffle(VecOp, Mask);
+ break;
+ }
case LegalityResultID::Pack: {
// If we can't vectorize the seeds then just return.
if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 7bc6e5ac3d7605..a3798af8399087 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -221,3 +221,24 @@ define void @diamond(ptr %ptr) {
store float %sub1, ptr %ptr1
ret void
}
+
+define void @diamondWithShuffle(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffle(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <2 x float> [[VECL]], <2 x float> [[VECL]], <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ %sub0 = fsub float %ld0, %ld1 ; operands are (ld0, ld1)
+ %sub1 = fsub float %ld1, %ld0 ; same two loads with the operand order swapped
+ store float %sub0, ptr %ptr0
+ store float %sub1, ptr %ptr1
+ ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 069bfdba0a7cdb..b421d08bc6b020 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -19,6 +19,7 @@
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -321,7 +322,7 @@ define void @foo(ptr %ptr) {
sandboxir::CollectDescr CD(std::move(Descrs));
EXPECT_TRUE(CD.getSingleInput());
EXPECT_EQ(CD.getSingleInput()->first, VLd);
- EXPECT_EQ(CD.getSingleInput()->second, false);
+ EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1));
EXPECT_TRUE(CD.hasVectorInputs());
}
{
@@ -331,7 +332,7 @@ define void @foo(ptr %ptr) {
sandboxir::CollectDescr CD(std::move(Descrs));
EXPECT_TRUE(CD.getSingleInput());
EXPECT_EQ(CD.getSingleInput()->first, VLd);
- EXPECT_EQ(CD.getSingleInput()->second, true);
+ EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0));
EXPECT_TRUE(CD.hasVectorInputs());
}
{
@@ -352,3 +353,95 @@ define void @foo(ptr %ptr) {
EXPECT_FALSE(CD.hasVectorInputs());
}
}
+
+TEST_F(LegalityTest, ShuffleMask) { // Exercises each public member of sandboxir::ShuffleMask in its own scope.
+ {
+ // Check SmallVector constructor.
+ SmallVector<int> Indices({0, 1, 2, 3});
+ sandboxir::ShuffleMask Mask(std::move(Indices));
+ EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+ }
+ {
+ // Check initializer_list constructor.
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+ }
+ {
+ // Check ArrayRef constructor.
+ sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3}));
+ EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+ }
+ {
+ // Check operator ArrayRef<int>().
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ ArrayRef<int> Array = Mask;
+ EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3));
+ }
+ {
+ // Check getIdentity().
+ auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4);
+ EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3));
+ EXPECT_TRUE(IdentityMask.isIdentity());
+ }
+ {
+ // Check isIdentity().
+ sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+ EXPECT_TRUE(Mask1.isIdentity());
+ sandboxir::ShuffleMask Mask2({1, 2, 3, 4});
+ EXPECT_FALSE(Mask2.isIdentity());
+ }
+ {
+ // Check operator==().
+ sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+ sandboxir::ShuffleMask Mask2({0, 1, 2, 3});
+ EXPECT_TRUE(Mask1 == Mask2);
+ EXPECT_FALSE(Mask1 != Mask2);
+ }
+ {
+ // Check operator!=().
+ sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+ sandboxir::ShuffleMask Mask2({0, 1, 2, 4});
+ EXPECT_TRUE(Mask1 != Mask2);
+ EXPECT_FALSE(Mask1 == Mask2);
+ }
+ {
+ // Check size().
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ EXPECT_EQ(Mask.size(), 4u);
+ }
+ {
+ // Check operator[].
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ for (auto [Idx, Elm] : enumerate(Mask)) {
+ EXPECT_EQ(Elm, Mask[Idx]);
+ }
+ }
+ {
+ // Check begin(), end().
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ sandboxir::ShuffleMask::const_iterator Begin = Mask.begin();
+ sandboxir::ShuffleMask::const_iterator End = Mask.end(); // Must be end(): with begin() the loop below would run zero times.
+ int Idx = 0;
+ for (auto It = Begin; It != End; ++It) {
+ EXPECT_EQ(*It, Mask[Idx++]);
+ }
+ }
+#ifndef NDEBUG
+ {
+ // Check print(OS).
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ std::string Str;
+ raw_string_ostream OS(Str);
+ Mask.print(OS);
+ EXPECT_EQ(Str, "0,1,2,3");
+ }
+ {
+ // Check operator<<().
+ sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+ std::string Str;
+ raw_string_ostream OS(Str);
+ OS << Mask;
+ EXPECT_EQ(Str, "0,1,2,3");
+ }
+#endif // NDEBUG
+}
|
tmsri
approved these changes
Jan 17, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This patch implements a helper ShuffleMask data structure that helps describe shuffles of elements across lanes.