[SandboxVec][Legality] Fix mask on diamond reuse with shuffle #126963

Merged: 1 commit merged into llvm:main on Feb 12, 2025

@vporpo vporpo commented Feb 12, 2025

This patch fixes a bug in how shuffle masks are created for a diamond reuse with shuffle when the values being vectorized are themselves vectors. The mask must enumerate every lane of each vector rather than treating each original vector value as a single element. For example, when vectorizing two <2 x float> values into a <4 x float>, the mask needs 4 indices, not just 2.
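
The indexing rule described above can be sketched without any LLVM types. This is a minimal illustration, not the code from this patch: `BundleElem`, `buildMask`, and the field names are hypothetical, and the sketch assumes each bundle element records the first lane it occupies in the already-vectorized value plus the number of lanes it spans.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for one value in a bundle: the first lane it occupies
// in the existing wide vector, and how many lanes it spans (1 for a scalar).
struct BundleElem {
  int FirstLane;
  int NumLanes;
};

// Builds a shuffle mask with one index per lane rather than one per element.
std::vector<int> buildMask(const std::vector<BundleElem> &Bndl) {
  std::vector<int> Mask;
  for (const BundleElem &E : Bndl) {
    assert(E.FirstLane >= 0 && E.NumLanes > 0 && "Malformed bundle element");
    for (int Ln = 0; Ln != E.NumLanes; ++Ln)
      Mask.push_back(E.FirstLane + Ln);
  }
  return Mask;
}
```

For a bundle that reuses two <2 x float> values in swapped order, `buildMask({{2, 2}, {0, 2}})` produces the four indices 2, 3, 0, 1. Counting each vector as a single element would produce only two indices, which cannot describe a <4 x float> shuffle.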

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

This patch fixes a bug in the creation of shuffle masks when vectorizing vectors in case of a diamond reuse with shuffle. The mask needs to enumerate all elements of a vector, not treat the original vector value as a single element. That is: if vectorizing two <2 x float> vectors into a <4 x float> the mask needs to have 4 indices, not just 2.


Full diff: https://github.com/llvm/llvm-project/pull/126963.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h (+4-1)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (+8-2)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+2)
  • (modified) llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll (+22)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp (+27)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index c931319d3b002..9bdf940fc77b7 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -18,6 +18,7 @@
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 #include <algorithm>
 
 namespace llvm::sandboxir {
@@ -85,11 +86,13 @@ class InstrMaps {
   /// Update the map to reflect that \p Origs got vectorized into \p Vec.
   void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
     auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
-    for (auto [Lane, Orig] : enumerate(Origs)) {
+    unsigned Lane = 0;
+    for (Value *Orig : Origs) {
       auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
       assert(Pair.second && "Orig already exists in the map!");
       (void)Pair;
       OrigToLaneMap[Orig] = Lane;
+      Lane += VecUtils::getNumLanes(Orig);
     }
   }
   void clear() {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index c9329c24e1f4c..366243231379f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -202,14 +202,20 @@ CollectDescr
 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
   Vec.reserve(Bndl.size());
-  for (auto [Lane, V] : enumerate(Bndl)) {
+  uint32_t LaneAccum;
+  for (auto [Elm, V] : enumerate(Bndl)) {
+    uint32_t VLanes = VecUtils::getNumLanes(V);
     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
       // If there is a vector containing `V`, then get the lane it came from.
       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
-      Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
+      // This could be a vector, like <2 x float> in which case the mask needs
+      // to enumerate all lanes.
+      for (int Ln = 0; Ln != VLanes; ++Ln)
+        Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
     } else {
       Vec.emplace_back(V);
     }
+    LaneAccum += VLanes;
   }
   return CollectDescr(std::move(Vec));
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 507d163240127..4fb029d3344b8 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
     const ShuffleMask &Mask =
         cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
     NewVec = createShuffle(VecOp, Mask, UserBB);
+    assert(NewVec->getType() == VecOp->getType() &&
+           "Expected same type! Bad mask ?");
     break;
   }
   case LegalityResultID::DiamondReuseMultiInput: {
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 45b937dc1b1b6..301d6649669f4 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
   ret void
 }
 
+; Same but with <2 x float> elements instead of scalars.
+define void @diamondWithShuffleFromVec(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffleFromVec(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT:    [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT:    store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+  %ld0 = load <2 x float>, ptr %ptr0
+  %ld1 = load <2 x float>, ptr %ptr1
+  %sub0 = fsub <2 x float> %ld0, %ld1
+  %sub1 = fsub <2 x float> %ld1, %ld0
+  store <2 x float> %sub0, ptr %ptr0
+  store <2 x float> %sub1, ptr %ptr1
+  ret void
+}
+
 define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
 ; CHECK-LABEL: define void @diamondMultiInput(
 ; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index 1d7c8f9cdde04..5b033f0edcb02 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
   EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
 }
+
+TEST_F(InstrMapsTest, VectorLanes) {
+  parseIR(C, R"IR(
+define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
+  %vadd0 = add <2 x i8> %v0, %v1
+  %vadd1 = add <2 x i8> %v0, %v1
+  %vadd2 = add <4 x i8> %v2, %v3
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+
+  auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
+
+  sandboxir::InstrMaps IMaps(Ctx);
+
+  // Check that the vector lanes are calculated correctly.
+  IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
+}

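The lane bookkeeping changed in `registerVector` can be illustrated with a standalone sketch. It is not the LLVM implementation: `registerLanes` is a hypothetical helper, values are identified by plain strings, and each value's lane count is passed in directly instead of being queried through `VecUtils::getNumLanes`.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Maps each original value (named by a string here) to the first lane it
// occupies in the combined vector. The offset advances by the value's lane
// count, so a 2-lane value pushes the next value two lanes further.
std::map<std::string, unsigned>
registerLanes(const std::vector<std::pair<std::string, unsigned>> &Origs) {
  std::map<std::string, unsigned> OrigToLane;
  unsigned Lane = 0;
  for (const auto &[Name, NumLanes] : Origs) {
    bool Inserted = OrigToLane.emplace(Name, Lane).second;
    assert(Inserted && "Orig already exists in the map!");
    (void)Inserted;
    Lane += NumLanes;
  }
  return OrigToLane;
}
```

With two 2-lane values, `registerLanes({{"vadd0", 2}, {"vadd1", 2}})` maps `vadd0` to lane 0 and `vadd1` to lane 2, mirroring the expectations in the `VectorLanes` unit test. Indexing by element position instead would have placed `vadd1` at lane 1.
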
llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)


@vporpo vporpo merged commit 7a7f919 into llvm:main Feb 12, 2025
8 of 10 checks passed
@@ -202,14 +202,20 @@ CollectDescr
 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
   Vec.reserve(Bndl.size());
-  for (auto [Lane, V] : enumerate(Bndl)) {
+  uint32_t LaneAccum;
Contributor

We use this uninitialized in LaneAccum += VLanes;

Contributor Author

Yeah this is dead code that I accidentally removed in the follow-up patch instead of this one, while maintaining the patch chain. I already pushed the fix: e75e617

Contributor

Ah, I see you sent e75e617 to fix that
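
The concern raised in this thread can be shown in isolation. The following is only a sketch with a hypothetical helper, `totalLanes`; it assumes a running lane total is actually wanted, whereas in the patch the accumulator was dead code and was simply removed by the follow-up commit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: sums the lane counts of the values in a bundle.
std::uint32_t totalLanes(const std::vector<std::uint32_t> &LanesPerValue) {
  // Start from zero. Declaring the accumulator without an initializer and
  // then applying `+=` would read an indeterminate value.
  std::uint32_t LaneAccum = 0;
  for (std::uint32_t VLanes : LanesPerValue) {
    assert(VLanes > 0 && "A value spans at least one lane");
    LaneAccum += VLanes;
  }
  return LaneAccum;
}
```

For a bundle of two <2 x float> values, `totalLanes({2, 2})` returns 4, the lane count of the resulting <4 x float>.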

flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025