Skip to content

Commit eb95b9d

Browse files
committed
unit tests
1 parent 893ef7f commit eb95b9d

File tree

4 files changed

+257
-0
lines changed

4 files changed

+257
-0
lines changed

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
7777
}
7878

7979
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
80+
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
8081
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
8182
std::plus<double>());
8283
}

llvm/unittests/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ set(ANALYSIS_TEST_SOURCES
3232
GlobalsModRefTest.cpp
3333
FunctionPropertiesAnalysisTest.cpp
3434
InlineCostTest.cpp
35+
IR2VecTest.cpp
3536
IRSimilarityIdentifierTest.cpp
3637
IVDescriptorsTest.cpp
3738
LastRunTrackingAnalysisTest.cpp
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "llvm/Analysis/IR2Vec.h"
10+
#include "llvm/IR/Constants.h"
11+
#include "llvm/IR/Function.h"
12+
#include "llvm/IR/Instructions.h"
13+
#include "llvm/IR/LLVMContext.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/IR/Type.h"
16+
#include "llvm/Support/Error.h"
17+
18+
#include "gtest/gtest.h"
19+
#include <map>
20+
#include <vector>
21+
22+
using namespace llvm;
23+
using namespace ir2vec;
24+
25+
namespace {
26+
27+
class TestableEmbedder : public Embedder {
28+
public:
29+
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
30+
: Embedder(F, V, Dim) {}
31+
void computeEmbeddings() const override {}
32+
using Embedder::lookupVocab;
33+
static void addVectors(Embedding &Dst, const Embedding &Src) {
34+
Embedder::addVectors(Dst, Src);
35+
}
36+
static void addScaledVector(Embedding &Dst, const Embedding &Src,
37+
float Factor) {
38+
Embedder::addScaledVector(Dst, Src, Factor);
39+
}
40+
};
41+
42+
class IR2VecTest : public ::testing::Test {
43+
protected:
44+
void SetUp() override {}
45+
void TearDown() override {}
46+
};
47+
48+
TEST_F(IR2VecTest, CreateSymbolicEmbedder) {
49+
Vocab V = {{"foo", {1.0, 2.0}}};
50+
51+
LLVMContext Ctx;
52+
Module M("M", Ctx);
53+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
54+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
55+
56+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
57+
EXPECT_TRUE(static_cast<bool>(Result));
58+
59+
auto *Emb = Result->get();
60+
EXPECT_NE(Emb, nullptr);
61+
}
62+
63+
TEST_F(IR2VecTest, CreateInvalidMode) {
64+
Vocab V = {{"foo", {1.0, 2.0}}};
65+
66+
LLVMContext Ctx;
67+
Module M("M", Ctx);
68+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
69+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
70+
71+
// static_cast an invalid int to IR2VecKind
72+
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
73+
EXPECT_FALSE(static_cast<bool>(Result));
74+
75+
std::string ErrMsg;
76+
llvm::handleAllErrors(
77+
Result.takeError(),
78+
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
79+
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
80+
}
81+
82+
TEST_F(IR2VecTest, AddVectors) {
83+
Embedding E1 = {1.0, 2.0, 3.0};
84+
Embedding E2 = {0.5, 1.5, -1.0};
85+
86+
TestableEmbedder::addVectors(E1, E2);
87+
EXPECT_DOUBLE_EQ(E1[0], 1.5);
88+
EXPECT_DOUBLE_EQ(E1[1], 3.5);
89+
EXPECT_DOUBLE_EQ(E1[2], 2.0);
90+
91+
// Check that E2 is unchanged
92+
EXPECT_DOUBLE_EQ(E2[0], 0.5);
93+
EXPECT_DOUBLE_EQ(E2[1], 1.5);
94+
EXPECT_DOUBLE_EQ(E2[2], -1.0);
95+
}
96+
97+
TEST_F(IR2VecTest, AddScaledVector) {
98+
Embedding E1 = {1.0, 2.0, 3.0};
99+
Embedding E2 = {2.0, 0.5, -1.0};
100+
101+
TestableEmbedder::addScaledVector(E1, E2, 0.5f);
102+
EXPECT_DOUBLE_EQ(E1[0], 2.0);
103+
EXPECT_DOUBLE_EQ(E1[1], 2.25);
104+
EXPECT_DOUBLE_EQ(E1[2], 2.5);
105+
106+
// Check that E2 is unchanged
107+
EXPECT_DOUBLE_EQ(E2[0], 2.0);
108+
EXPECT_DOUBLE_EQ(E2[1], 0.5);
109+
EXPECT_DOUBLE_EQ(E2[2], -1.0);
110+
}
111+
112+
#if GTEST_HAS_DEATH_TEST
113+
TEST_F(IR2VecTest, MismatchedDimensionsAddVectors) {
114+
Embedding E1 = {1.0, 2.0};
115+
Embedding E2 = {1.0};
116+
EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
117+
"Vectors must have the same dimension");
118+
}
119+
120+
TEST_F(IR2VecTest, MismatchedDimensionsAddScaledVector) {
121+
Embedding E1 = {1.0, 2.0};
122+
Embedding E2 = {1.0};
123+
EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
124+
"Vectors must have the same dimension");
125+
}
126+
#endif
127+
128+
TEST_F(IR2VecTest, LookupVocab) {
129+
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
130+
LLVMContext Ctx;
131+
Module M("M", Ctx);
132+
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
133+
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
134+
135+
TestableEmbedder E(*F, V, 2);
136+
auto V_foo = E.lookupVocab("foo");
137+
EXPECT_EQ(V_foo.size(), 2u);
138+
EXPECT_DOUBLE_EQ(V_foo[0], 1.0);
139+
EXPECT_DOUBLE_EQ(V_foo[1], 2.0);
140+
141+
auto V_missing = E.lookupVocab("missing");
142+
EXPECT_EQ(V_missing.size(), 2u);
143+
EXPECT_DOUBLE_EQ(V_missing[0], 0.0);
144+
EXPECT_DOUBLE_EQ(V_missing[1], 0.0);
145+
}
146+
147+
TEST_F(IR2VecTest, ZeroDimensionEmbedding) {
148+
Embedding E1;
149+
Embedding E2;
150+
// Should be no-op, but not crash
151+
TestableEmbedder::addVectors(E1, E2);
152+
TestableEmbedder::addScaledVector(E1, E2, 1.0f);
153+
EXPECT_TRUE(E1.empty());
154+
}
155+
156+
TEST_F(IR2VecTest, IR2VecVocabResultValidity) {
157+
// Default constructed is invalid
158+
IR2VecVocabResult invalidResult;
159+
EXPECT_FALSE(invalidResult.isValid());
160+
#if GTEST_HAS_DEATH_TEST
161+
EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
162+
EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
163+
#endif
164+
165+
// Valid vocab
166+
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
167+
IR2VecVocabResult validResult(std::move(V));
168+
EXPECT_TRUE(validResult.isValid());
169+
EXPECT_EQ(validResult.getDimension(), 2u);
170+
}
171+
172+
// Helper to create a minimal function and embedder for getter tests
173+
struct GetterTestEnv {
174+
Vocab V;
175+
LLVMContext Ctx;
176+
std::unique_ptr<Module> M;
177+
Function *F;
178+
BasicBlock *BB;
179+
Instruction *Add;
180+
Instruction *Ret;
181+
std::unique_ptr<Embedder> Emb;
182+
183+
GetterTestEnv() {
184+
V = {{"add", {1.0, 2.0}},
185+
{"integerTy", {0.5, 0.5}},
186+
{"constant", {0.2, 0.3}},
187+
{"variable", {0.0, 0.0}},
188+
{"unknownTy", {0.0, 0.0}}};
189+
190+
M = std::make_unique<Module>("M", Ctx);
191+
FunctionType *FTy = FunctionType::get(
192+
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
193+
false);
194+
F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
195+
BB = BasicBlock::Create(Ctx, "entry", F);
196+
Argument *Arg = F->getArg(0);
197+
Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
198+
199+
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
200+
Ret = ReturnInst::Create(Ctx, Add, BB);
201+
202+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
203+
EXPECT_TRUE(static_cast<bool>(Result));
204+
Emb = std::move(*Result);
205+
}
206+
};
207+
208+
TEST_F(IR2VecTest, GetInstVecMap) {
209+
GetterTestEnv Env;
210+
const auto &InstMap = Env.Emb->getInstVecMap();
211+
212+
EXPECT_EQ(InstMap.size(), 2u);
213+
EXPECT_TRUE(InstMap.count(Env.Add));
214+
EXPECT_TRUE(InstMap.count(Env.Ret));
215+
216+
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
217+
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
218+
219+
// Check values for add: {1.29, 2.31}
220+
EXPECT_NEAR(InstMap.at(Env.Add)[0], 1.29, 1e-6);
221+
EXPECT_NEAR(InstMap.at(Env.Add)[1], 2.31, 1e-6);
222+
223+
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
224+
// vocab
225+
EXPECT_NEAR(InstMap.at(Env.Ret)[0], 0.0, 1e-6);
226+
EXPECT_NEAR(InstMap.at(Env.Ret)[1], 0.0, 1e-6);
227+
}
228+
229+
TEST_F(IR2VecTest, GetBBVecMap) {
230+
GetterTestEnv Env;
231+
const auto &BBMap = Env.Emb->getBBVecMap();
232+
233+
EXPECT_EQ(BBMap.size(), 1u);
234+
EXPECT_TRUE(BBMap.count(Env.BB));
235+
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
236+
237+
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
238+
// {1.29, 2.31}
239+
EXPECT_NEAR(BBMap.at(Env.BB)[0], 1.29, 1e-6);
240+
EXPECT_NEAR(BBMap.at(Env.BB)[1], 2.31, 1e-6);
241+
}
242+
243+
TEST_F(IR2VecTest, GetFunctionVector) {
244+
GetterTestEnv Env;
245+
const auto &FuncVec = Env.Emb->getFunctionVector();
246+
247+
EXPECT_EQ(FuncVec.size(), 2u);
248+
249+
// Function vector should match BB vector (only one BB): {1.29, 2.31}
250+
EXPECT_NEAR(FuncVec[0], 1.29, 1e-6);
251+
EXPECT_NEAR(FuncVec[1], 2.31, 1e-6);
252+
}
253+
254+
} // end anonymous namespace

llvm/utils/gn/secondary/llvm/unittests/Analysis/BUILD.gn

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ unittest("AnalysisTests") {
3131
"FunctionPropertiesAnalysisTest.cpp",
3232
"GlobalsModRefTest.cpp",
3333
"GraphWriterTest.cpp",
34+
"IR2VecTest.cpp",
3435
"IRSimilarityIdentifierTest.cpp",
3536
"IVDescriptorsTest.cpp",
3637
"InlineCostTest.cpp",

0 commit comments

Comments
 (0)