Skip to content

[SanbdoxVec][BottomUpVec] Fix diamond shuffle with multiple vector inputs #126965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2025

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Feb 12, 2025

When the operand comes from multiple inputs then we need additional packing code. When the operands are scalar then we can use a single InsertElementInst. But when the operands are vectors then we need a chain of ExtractElementInst and InsertElementInst instructions to insert the vector value into the destination vector. This is what this patch implements.

@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

When the operand comes from multiple inputs then we need additional packing code. When the operands are scalar then we can use a single InsertElementInst. But when the operands are vectors then we need a chain of ExtractElementInst and InsertElementInst instructions to insert the vector value into the destination vector. This is what this patch implements.


Full diff: https://github.com/llvm/llvm-project/pull/126965.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (+1-4)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+29-6)
  • (modified) llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll (+33)
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 366243231379f..e8331933594da 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -202,20 +202,17 @@ CollectDescr
 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
   Vec.reserve(Bndl.size());
-  uint32_t LaneAccum;
   for (auto [Elm, V] : enumerate(Bndl)) {
-    uint32_t VLanes = VecUtils::getNumLanes(V);
     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
       // If there is a vector containing `V`, then get the lane it came from.
       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
       // This could be a vector, like <2 x float> in which case the mask needs
       // to enumerate all lanes.
-      for (int Ln = 0; Ln != VLanes; ++Ln)
+      for (uint32_t Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
         Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
     } else {
       Vec.emplace_back(V);
     }
-    LaneAccum += VLanes;
   }
   return CollectDescr(std::move(Vec));
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 4fb029d3344b8..0ccef5aecd28b 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -335,7 +335,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
   case LegalityResultID::DiamondReuseMultiInput: {
     const auto &Descr =
         cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
-    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+    Type *ResTy = VecUtils::getWideType(Bndl[0]->getType(), Bndl.size());
 
     // TODO: Try to get WhereIt without creating a vector.
     SmallVector<Value *, 4> DescrInstrs;
@@ -347,7 +347,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
         getInsertPointAfterInstrs(DescrInstrs, UserBB);
 
     Value *LastV = PoisonValue::get(ResTy);
-    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+    unsigned Lane = 0;
+    for (const auto &ElmDescr : Descr.getDescrs()) {
       Value *VecOp = ElmDescr.getValue();
       Context &Ctx = VecOp->getContext();
       Value *ValueToInsert;
@@ -359,10 +360,32 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
       } else {
         ValueToInsert = VecOp;
       }
-      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
-      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
-                                             WhereIt, Ctx, "VIns");
-      LastV = Ins;
+      auto NumLanesToInsert = VecUtils::getNumLanes(ValueToInsert);
+      if (NumLanesToInsert == 1) {
+        // If we are inserting a scalar element then we need a single insert.
+        //   %VIns = insert %DstVec,  %SrcScalar, Lane
+        ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
+        LastV = InsertElementInst::create(LastV, ValueToInsert, LaneC, WhereIt,
+                                          Ctx, "VIns");
+      } else {
+        // If we are inserting a vector element then we need to extract and
+        // insert each vector element one by one with a chain of extracts and
+        // inserts, for example:
+        //   %VExt0 = extract %SrcVec, 0
+        //   %VIns0 = insert  %DstVec, %Vect0, Lane + 0
+        //   %VExt1 = extract %SrcVec, 1
+        //   %VIns1 = insert  %VIns0,  %Vect0, Lane + 1
+        for (unsigned LnCnt = 0; LnCnt != NumLanesToInsert; ++LnCnt) {
+          auto *ExtrIdxC = ConstantInt::get(Type::getInt32Ty(Ctx), LnCnt);
+          auto *ExtrI = ExtractElementInst::create(ValueToInsert, ExtrIdxC,
+                                                   WhereIt, Ctx, "VExt");
+          unsigned InsLane = Lane + LnCnt;
+          auto *InsLaneC = ConstantInt::get(Type::getInt32Ty(Ctx), InsLane);
+          LastV = InsertElementInst::create(LastV, ExtrI, InsLaneC, WhereIt,
+                                            Ctx, "VIns");
+        }
+      }
+      Lane += NumLanesToInsert;
     }
     NewVec = LastV;
     break;
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 301d6649669f4..6b18d4069e0ae 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -292,6 +292,39 @@ define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
   ret void
 }
 
+; Same but vectorizing <2 x float> vectors instead of scalars.
+define void @diamondMultiInputVector(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInputVector(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[LDX:%.*]] = load <2 x float>, ptr [[PTRX]], align 8
+; CHECK-NEXT:    [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT:    [[VEXT:%.*]] = extractelement <2 x float> [[LDX]], i32 0
+; CHECK-NEXT:    [[INSI:%.*]] = insertelement <4 x float> poison, float [[VEXT]], i32 0
+; CHECK-NEXT:    [[VEXT1:%.*]] = extractelement <2 x float> [[LDX]], i32 1
+; CHECK-NEXT:    [[INSI2:%.*]] = insertelement <4 x float> [[INSI]], float [[VEXT1]], i32 1
+; CHECK-NEXT:    [[VEXT3:%.*]] = extractelement <4 x float> [[VECL]], i32 0
+; CHECK-NEXT:    [[VINS4:%.*]] = insertelement <4 x float> [[INSI2]], float [[VEXT3]], i32 2
+; CHECK-NEXT:    [[VEXT4:%.*]] = extractelement <4 x float> [[VECL]], i32 1
+; CHECK-NEXT:    [[VINS5:%.*]] = insertelement <4 x float> [[VINS4]], float [[VEXT4]], i32 3
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VINS5]]
+; CHECK-NEXT:    store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+  %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+  %ld0 = load <2 x float>, ptr %ptr0
+  %ld1 = load <2 x float>, ptr %ptr1
+
+  %ldX = load <2 x float>, ptr %ptrX
+
+  %sub0 = fsub <2 x float> %ld0, %ldX
+  %sub1 = fsub <2 x float> %ld1, %ld0
+  store <2 x float> %sub0, ptr %ptr0
+  store <2 x float> %sub1, ptr %ptr1
+  ret void
+}
+
 define void @diamondWithConstantVector(ptr %ptr) {
 ; CHECK-LABEL: define void @diamondWithConstantVector(
 ; CHECK-SAME: ptr [[PTR:%.*]]) {

…puts

When the operand comes from multiple inputs then we need additional packing
code. When the operands are scalar then we can use a single InsertElementInst.
But when the operands are vectors then we need a chain of ExtractElementInst
and InsertElementInst instructions to insert the vector value into the
destination vector. This is what this patch implements.
@vporpo vporpo merged commit 31cb807 into llvm:main Feb 12, 2025
6 of 8 checks passed
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
…puts (llvm#126965)

When the operand comes from multiple inputs then we need additional
packing code. When the operands are scalar then we can use a single
InsertElementInst. But when the operands are vectors then we need a
chain of ExtractElementInst and InsertElementInst instructions to insert
the vector value into the destination vector. This is what this patch
implements.
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…puts (llvm#126965)

When the operand comes from multiple inputs then we need additional
packing code. When the operands are scalar then we can use a single
InsertElementInst. But when the operands are vectors then we need a
chain of ExtractElementInst and InsertElementInst instructions to insert
the vector value into the destination vector. This is what this patch
implements.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…puts (llvm#126965)

When the operand comes from multiple inputs then we need additional
packing code. When the operands are scalar then we can use a single
InsertElementInst. But when the operands are vectors then we need a
chain of ExtractElementInst and InsertElementInst instructions to insert
the vector value into the destination vector. This is what this patch
implements.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants