Skip to content

Commit e74b45e

Browse files
authored
[IR2Vec] Adding unit tests (#141873)
This PR adds unit tests for IR2Vec (Tracking issue - #141817)
1 parent a8c6a50 commit e74b45e

File tree

3 files changed

+245
-0
lines changed

3 files changed

+245
-0
lines changed

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
7777
}
7878

7979
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
80+
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
8081
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
8182
std::plus<double>());
8283
}

llvm/unittests/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ set(ANALYSIS_TEST_SOURCES
3232
GlobalsModRefTest.cpp
3333
FunctionPropertiesAnalysisTest.cpp
3434
InlineCostTest.cpp
35+
IR2VecTest.cpp
3536
IRSimilarityIdentifierTest.cpp
3637
IVDescriptorsTest.cpp
3738
LastRunTrackingAnalysisTest.cpp
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Analysis/IR2Vec.h"
10+
#include "llvm/IR/Constants.h"
11+
#include "llvm/IR/Function.h"
12+
#include "llvm/IR/Instructions.h"
13+
#include "llvm/IR/LLVMContext.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/IR/Type.h"
16+
#include "llvm/Support/Error.h"
17+
18+
#include "gmock/gmock.h"
19+
#include "gtest/gtest.h"
20+
#include <map>
21+
#include <vector>
22+
23+
using namespace llvm;
24+
using namespace ir2vec;
25+
using namespace ::testing;
26+
27+
namespace {
28+
29+
class TestableEmbedder : public Embedder {
30+
public:
31+
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
32+
: Embedder(F, V, Dim) {}
33+
void computeEmbeddings() const override {}
34+
using Embedder::lookupVocab;
35+
static void addVectors(Embedding &Dst, const Embedding &Src) {
36+
Embedder::addVectors(Dst, Src);
37+
}
38+
static void addScaledVector(Embedding &Dst, const Embedding &Src,
39+
float Factor) {
40+
Embedder::addScaledVector(Dst, Src, Factor);
41+
}
42+
};
43+
44+
TEST(IR2VecTest, CreateSymbolicEmbedder) {
45+
Vocab V = {{"foo", {1.0, 2.0}}};
46+
47+
LLVMContext Ctx;
48+
Module M("M", Ctx);
49+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
50+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
51+
52+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
53+
EXPECT_TRUE(static_cast<bool>(Result));
54+
55+
auto *Emb = Result->get();
56+
EXPECT_NE(Emb, nullptr);
57+
}
58+
59+
TEST(IR2VecTest, CreateInvalidMode) {
60+
Vocab V = {{"foo", {1.0, 2.0}}};
61+
62+
LLVMContext Ctx;
63+
Module M("M", Ctx);
64+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
65+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
66+
67+
// static_cast an invalid int to IR2VecKind
68+
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
69+
EXPECT_FALSE(static_cast<bool>(Result));
70+
71+
std::string ErrMsg;
72+
llvm::handleAllErrors(
73+
Result.takeError(),
74+
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
75+
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
76+
}
77+
78+
TEST(IR2VecTest, AddVectors) {
79+
Embedding E1 = {1.0, 2.0, 3.0};
80+
Embedding E2 = {0.5, 1.5, -1.0};
81+
82+
TestableEmbedder::addVectors(E1, E2);
83+
EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
84+
85+
// Check that E2 is unchanged
86+
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
87+
}
88+
89+
TEST(IR2VecTest, AddScaledVector) {
90+
Embedding E1 = {1.0, 2.0, 3.0};
91+
Embedding E2 = {2.0, 0.5, -1.0};
92+
93+
TestableEmbedder::addScaledVector(E1, E2, 0.5f);
94+
EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
95+
96+
// Check that E2 is unchanged
97+
EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
98+
}
99+
100+
#if GTEST_HAS_DEATH_TEST
101+
#ifndef NDEBUG
102+
TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
103+
Embedding E1 = {1.0, 2.0};
104+
Embedding E2 = {1.0};
105+
EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
106+
"Vectors must have the same dimension");
107+
}
108+
109+
TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
110+
Embedding E1 = {1.0, 2.0};
111+
Embedding E2 = {1.0};
112+
EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
113+
"Vectors must have the same dimension");
114+
}
115+
#endif // NDEBUG
116+
#endif // GTEST_HAS_DEATH_TEST
117+
118+
TEST(IR2VecTest, LookupVocab) {
119+
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
120+
LLVMContext Ctx;
121+
Module M("M", Ctx);
122+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
123+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
124+
125+
TestableEmbedder E(*F, V, 2);
126+
auto V_foo = E.lookupVocab("foo");
127+
EXPECT_EQ(V_foo.size(), 2u);
128+
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
129+
130+
auto V_missing = E.lookupVocab("missing");
131+
EXPECT_EQ(V_missing.size(), 2u);
132+
EXPECT_THAT(V_missing, ElementsAre(0.0, 0.0));
133+
}
134+
135+
TEST(IR2VecTest, ZeroDimensionEmbedding) {
136+
Embedding E1;
137+
Embedding E2;
138+
// Should be no-op, but not crash
139+
TestableEmbedder::addVectors(E1, E2);
140+
TestableEmbedder::addScaledVector(E1, E2, 1.0f);
141+
EXPECT_TRUE(E1.empty());
142+
}
143+
144+
TEST(IR2VecTest, IR2VecVocabResultValidity) {
145+
// Default constructed is invalid
146+
IR2VecVocabResult invalidResult;
147+
EXPECT_FALSE(invalidResult.isValid());
148+
#if GTEST_HAS_DEATH_TEST
149+
#ifndef NDEBUG
150+
EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
151+
EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
152+
#endif // NDEBUG
153+
#endif // GTEST_HAS_DEATH_TEST
154+
155+
// Valid vocab
156+
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
157+
IR2VecVocabResult validResult(std::move(V));
158+
EXPECT_TRUE(validResult.isValid());
159+
EXPECT_EQ(validResult.getDimension(), 2u);
160+
}
161+
162+
// Helper to create a minimal function and embedder for getter tests
163+
struct GetterTestEnv {
164+
Vocab V = {};
165+
LLVMContext Ctx;
166+
std::unique_ptr<Module> M = nullptr;
167+
Function *F = nullptr;
168+
BasicBlock *BB = nullptr;
169+
Instruction *Add = nullptr;
170+
Instruction *Ret = nullptr;
171+
std::unique_ptr<Embedder> Emb = nullptr;
172+
173+
GetterTestEnv() {
174+
V = {{"add", {1.0, 2.0}},
175+
{"integerTy", {0.5, 0.5}},
176+
{"constant", {0.2, 0.3}},
177+
{"variable", {0.0, 0.0}},
178+
{"unknownTy", {0.0, 0.0}}};
179+
180+
M = std::make_unique<Module>("M", Ctx);
181+
FunctionType *FTy = FunctionType::get(
182+
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
183+
false);
184+
F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
185+
BB = BasicBlock::Create(Ctx, "entry", F);
186+
Argument *Arg = F->getArg(0);
187+
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
188+
189+
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
190+
Ret = ReturnInst::Create(Ctx, Add, BB);
191+
192+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
193+
EXPECT_TRUE(static_cast<bool>(Result));
194+
Emb = std::move(*Result);
195+
}
196+
};
197+
198+
TEST(IR2VecTest, GetInstVecMap) {
199+
GetterTestEnv Env;
200+
const auto &InstMap = Env.Emb->getInstVecMap();
201+
202+
EXPECT_EQ(InstMap.size(), 2u);
203+
EXPECT_TRUE(InstMap.count(Env.Add));
204+
EXPECT_TRUE(InstMap.count(Env.Ret));
205+
206+
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
207+
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
208+
209+
// Check values for add: {1.29, 2.31}
210+
EXPECT_THAT(InstMap.at(Env.Add),
211+
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
212+
213+
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
214+
// vocab
215+
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
216+
}
217+
218+
TEST(IR2VecTest, GetBBVecMap) {
219+
GetterTestEnv Env;
220+
const auto &BBMap = Env.Emb->getBBVecMap();
221+
222+
EXPECT_EQ(BBMap.size(), 1u);
223+
EXPECT_TRUE(BBMap.count(Env.BB));
224+
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
225+
226+
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
227+
// {1.29, 2.31}
228+
EXPECT_THAT(BBMap.at(Env.BB),
229+
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
230+
}
231+
232+
TEST(IR2VecTest, GetFunctionVector) {
233+
GetterTestEnv Env;
234+
const auto &FuncVec = Env.Emb->getFunctionVector();
235+
236+
EXPECT_EQ(FuncVec.size(), 2u);
237+
238+
// Function vector should match BB vector (only one BB): {1.29, 2.31}
239+
EXPECT_THAT(FuncVec,
240+
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
241+
}
242+
243+
} // end anonymous namespace

0 commit comments

Comments
 (0)