Skip to content

[SandboxVec][VecUtils] Implement VecUtils::getLowest() #124024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2025

Conversation

vporpo
Copy link
Contributor

@vporpo vporpo commented Jan 22, 2025

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jan 22, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: vporpo (vporpo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/124024.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h (+29)
  • (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (+1-5)
  • (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp (+49-8)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 6cbbb396ea823f..4e3ca2bccfe6fd 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -100,6 +100,8 @@ class VecUtils {
     }
     return FixedVectorType::get(ElemTy, NumElts);
   }
+  /// \Returns the instruction in \p Instrs that is lowest in the BB. Expects
+  /// that all instructions are in the same BB.
   static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
     Instruction *LowestI = Instrs.front();
     for (auto *I : drop_begin(Instrs)) {
@@ -108,6 +110,33 @@ class VecUtils {
     }
     return LowestI;
   }
+  /// \Returns the lowest instruction in \p Vals, or nullptr if no instructions
+  /// are found or if not in the same BB.
+  static Instruction *getLowest(ArrayRef<Value *> Vals) {
+    // Find the first Instruction in Vals.
+    auto It = find_if(Vals, [](Value *V) { return isa<Instruction>(V); });
+    // If we couldn't find an instruction return nullptr.
+    if (It == Vals.end())
+      return nullptr;
+    Instruction *FirstI = cast<Instruction>(*It);
+    // Now look for the lowest instruction in Vals starting from one position
+    // after FirstI.
+    Instruction *LowestI = FirstI;
+    auto *LowestBB = LowestI->getParent();
+    for (auto *V : make_range(std::next(It), Vals.end())) {
+      auto *I = dyn_cast<Instruction>(V);
+      // Skip non-instructions.
+      if (I == nullptr)
+        continue;
+      // If the instructions are in different BBs return nullptr.
+      if (I->getParent() != LowestBB)
+        return nullptr;
+      // If `LowestI` comes before `I` then `I` is the new lowest.
+      if (LowestI->comesBefore(I))
+        LowestI = I;
+    }
+    return LowestI;
+  }
   /// If all values in \p Bndl are of the same scalar type then return it,
   /// otherwise return nullptr.
   static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index c6ab3c1942c330..8432b4c6c469ae 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -45,11 +45,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
 
 static BasicBlock::iterator
 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
-  // TODO: Use the VecUtils function for getting the bottom instr once it lands.
-  auto *BotI = cast<Instruction>(
-      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
-        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
-      }));
+  auto *BotI = VecUtils::getLowest(Instrs);
   // If Bndl contains Arguments or Constants, use the beginning of the BB.
   return std::next(BotI->getIterator());
 }
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 8661dcd5067c0a..b69172738d36a5 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -50,6 +50,14 @@ struct VecUtilsTest : public testing::Test {
   }
 };
 
+sandboxir::BasicBlock &getBasicBlockByName(sandboxir::Function &F,
+                                           StringRef Name) {
+  for (sandboxir::BasicBlock &BB : F)
+    if (BB.getName() == Name)
+      return BB;
+  llvm_unreachable("Expected to find basic block!");
+}
+
 TEST_F(VecUtilsTest, GetNumElements) {
   sandboxir::Context Ctx(C);
   auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
@@ -415,9 +423,11 @@ TEST_F(VecUtilsTest, GetLowest) {
   parseIR(R"IR(
 define void @foo(i8 %v) {
 bb0:
-  %A = add i8 %v, %v
-  %B = add i8 %v, %v
-  %C = add i8 %v, %v
+  br label %bb1
+bb1:
+  %A = add i8 %v, 1
+  %B = add i8 %v, 2
+  %C = add i8 %v, 3
   ret void
 }
 )IR");
@@ -425,11 +435,21 @@ define void @foo(i8 %v) {
 
   sandboxir::Context Ctx(C);
   auto &F = *Ctx.createFunction(&LLVMF);
-  auto &BB = *F.begin();
-  auto It = BB.begin();
-  auto *IA = &*It++;
-  auto *IB = &*It++;
-  auto *IC = &*It++;
+  auto &BB0 = getBasicBlockByName(F, "bb0");
+  auto It = BB0.begin();
+  auto *BB0I = cast<sandboxir::BranchInst>(&*It++);
+
+  auto &BB = getBasicBlockByName(F, "bb1");
+  It = BB.begin();
+  auto *IA = cast<sandboxir::Instruction>(&*It++);
+  auto *C1 = cast<sandboxir::Constant>(IA->getOperand(1));
+  auto *IB = cast<sandboxir::Instruction>(&*It++);
+  auto *C2 = cast<sandboxir::Constant>(IB->getOperand(1));
+  auto *IC = cast<sandboxir::Instruction>(&*It++);
+  auto *C3 = cast<sandboxir::Constant>(IC->getOperand(1));
+  // Check getLowest(ArrayRef<Instruction *>)
+  SmallVector<sandboxir::Instruction *> A({IA});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(A), IA);
   SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
   EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
   SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
@@ -438,6 +458,27 @@ define void @foo(i8 %v) {
   EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
   SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
   EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
+
+  // Check getLowest(ArrayRef<Value *>)
+  SmallVector<sandboxir::Value *> C1Only({C1});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(C1Only), nullptr);
+  SmallVector<sandboxir::Value *> AOnly({IA});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(AOnly), IA);
+  SmallVector<sandboxir::Value *> AC1({IA, C1});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1), IA);
+  SmallVector<sandboxir::Value *> C1A({C1, IA});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(C1A), IA);
+  SmallVector<sandboxir::Value *> AC1B({IA, C1, IB});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1B), IB);
+  SmallVector<sandboxir::Value *> ABC1({IA, IB, C1});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC1), IB);
+  SmallVector<sandboxir::Value *> AC1C2({IA, C1, C2});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1C2), IA);
+  SmallVector<sandboxir::Value *> C1C2C3({C1, C2, C3});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(C1C2C3), nullptr);
+
+  SmallVector<sandboxir::Value *> DiffBBs({BB0I, IA});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(DiffBBs), nullptr);
 }
 
 TEST_F(VecUtilsTest, GetCommonScalarType) {

Copy link
Member

@tmsri tmsri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The title is mis-leading.

VecUtils::getLowest(Valse) returns the lowest instruction in the BB among Vals.
If the instructions are not in the same BB, or if none of them is an
instruction it returns nullptr.
@vporpo vporpo changed the title [SandboxVec][VecUtils] Implement getLowest() for non-instr values [SandboxVec][VecUtils] Implement VecUtils::getLowest() Jan 22, 2025
@vporpo
Copy link
Contributor Author

vporpo commented Jan 22, 2025

I updated the title and will add a short description.

@vporpo vporpo merged commit 2dc1c95 into llvm:main Jan 23, 2025
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants