Skip to content

[SandboxVec][VecUtils] Implement VecUtils::getLowest() #124024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class VecUtils {
}
return FixedVectorType::get(ElemTy, NumElts);
}
/// \Returns the instruction in \p Instrs that is lowest in the BB. Expects
/// that all instructions are in the same BB.
static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
Instruction *LowestI = Instrs.front();
for (auto *I : drop_begin(Instrs)) {
Expand All @@ -108,6 +110,33 @@ class VecUtils {
}
return LowestI;
}
/// \Returns the lowest instruction in \p Vals, or nullptr if no instructions
/// are found or if not in the same BB.
static Instruction *getLowest(ArrayRef<Value *> Vals) {
// Find the first Instruction in Vals.
auto It = find_if(Vals, [](Value *V) { return isa<Instruction>(V); });
// If we couldn't find an instruction return nullptr.
if (It == Vals.end())
return nullptr;
Instruction *FirstI = cast<Instruction>(*It);
// Now look for the lowest instruction in Vals starting from one position
// after FirstI.
Instruction *LowestI = FirstI;
auto *LowestBB = LowestI->getParent();
for (auto *V : make_range(std::next(It), Vals.end())) {
auto *I = dyn_cast<Instruction>(V);
// Skip non-instructions.
if (I == nullptr)
continue;
// If the instructions are in different BBs return nullptr.
if (I->getParent() != LowestBB)
return nullptr;
// If `LowestI` comes before `I` then `I` is the new lowest.
if (LowestI->comesBefore(I))
LowestI = I;
}
return LowestI;
}
/// If all values in \p Bndl are of the same scalar type then return it,
/// otherwise return nullptr.
static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,

static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
// TODO: Use the VecUtils function for getting the bottom instr once it lands.
auto *BotI = cast<Instruction>(
*std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
}));
auto *BotI = VecUtils::getLowest(Instrs);
// If Bndl contains Arguments or Constants, use the beginning of the BB.
return std::next(BotI->getIterator());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ struct VecUtilsTest : public testing::Test {
}
};

sandboxir::BasicBlock &getBasicBlockByName(sandboxir::Function &F,
StringRef Name) {
for (sandboxir::BasicBlock &BB : F)
if (BB.getName() == Name)
return BB;
llvm_unreachable("Expected to find basic block!");
}

TEST_F(VecUtilsTest, GetNumElements) {
sandboxir::Context Ctx(C);
auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
Expand Down Expand Up @@ -415,21 +423,33 @@ TEST_F(VecUtilsTest, GetLowest) {
parseIR(R"IR(
define void @foo(i8 %v) {
bb0:
%A = add i8 %v, %v
%B = add i8 %v, %v
%C = add i8 %v, %v
br label %bb1
bb1:
%A = add i8 %v, 1
%B = add i8 %v, 2
%C = add i8 %v, 3
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");

sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto It = BB.begin();
auto *IA = &*It++;
auto *IB = &*It++;
auto *IC = &*It++;
auto &BB0 = getBasicBlockByName(F, "bb0");
auto It = BB0.begin();
auto *BB0I = cast<sandboxir::BranchInst>(&*It++);

auto &BB = getBasicBlockByName(F, "bb1");
It = BB.begin();
auto *IA = cast<sandboxir::Instruction>(&*It++);
auto *C1 = cast<sandboxir::Constant>(IA->getOperand(1));
auto *IB = cast<sandboxir::Instruction>(&*It++);
auto *C2 = cast<sandboxir::Constant>(IB->getOperand(1));
auto *IC = cast<sandboxir::Instruction>(&*It++);
auto *C3 = cast<sandboxir::Constant>(IC->getOperand(1));
// Check getLowest(ArrayRef<Instruction *>)
SmallVector<sandboxir::Instruction *> A({IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(A), IA);
SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
Expand All @@ -438,6 +458,27 @@ define void @foo(i8 %v) {
EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);

// Check getLowest(ArrayRef<Value *>)
SmallVector<sandboxir::Value *> C1Only({C1});
EXPECT_EQ(sandboxir::VecUtils::getLowest(C1Only), nullptr);
SmallVector<sandboxir::Value *> AOnly({IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(AOnly), IA);
SmallVector<sandboxir::Value *> AC1({IA, C1});
EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1), IA);
SmallVector<sandboxir::Value *> C1A({C1, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(C1A), IA);
SmallVector<sandboxir::Value *> AC1B({IA, C1, IB});
EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1B), IB);
SmallVector<sandboxir::Value *> ABC1({IA, IB, C1});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC1), IB);
SmallVector<sandboxir::Value *> AC1C2({IA, C1, C2});
EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1C2), IA);
SmallVector<sandboxir::Value *> C1C2C3({C1, C2, C3});
EXPECT_EQ(sandboxir::VecUtils::getLowest(C1C2C3), nullptr);

SmallVector<sandboxir::Value *> DiffBBs({BB0I, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(DiffBBs), nullptr);
}

TEST_F(VecUtilsTest, GetCommonScalarType) {
Expand Down
Loading