-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[IR2Vec] Adding unit tests #141873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR2Vec] Adding unit tests #141873
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) ChangesThis PR adds unit tests for IR2Vec (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/141873.diff 4 Files Affected:
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 5f3114dcdeeaa..683f05d5beb04 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -77,6 +77,7 @@ Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
}
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
+ assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
std::plus<double>());
}
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 67f0b043e4f68..cd04a779b9467 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -32,6 +32,7 @@ set(ANALYSIS_TEST_SOURCES
GlobalsModRefTest.cpp
FunctionPropertiesAnalysisTest.cpp
InlineCostTest.cpp
+ IR2VecTest.cpp
IRSimilarityIdentifierTest.cpp
IVDescriptorsTest.cpp
LastRunTrackingAnalysisTest.cpp
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
new file mode 100644
index 0000000000000..e03e7a2032628
--- /dev/null
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -0,0 +1,254 @@
+//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Error.h"
+
+#include "gtest/gtest.h"
+#include <map>
+#include <vector>
+
+using namespace llvm;
+using namespace ir2vec;
+
+namespace {
+
+class TestableEmbedder : public Embedder {
+public:
+ TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
+ : Embedder(F, V, Dim) {}
+ void computeEmbeddings() const override {}
+ using Embedder::lookupVocab;
+ static void addVectors(Embedding &Dst, const Embedding &Src) {
+ Embedder::addVectors(Dst, Src);
+ }
+ static void addScaledVector(Embedding &Dst, const Embedding &Src,
+ float Factor) {
+ Embedder::addScaledVector(Dst, Src, Factor);
+ }
+};
+
+class IR2VecTest : public ::testing::Test {
+protected:
+ void SetUp() override {}
+ void TearDown() override {}
+};
+
+TEST_F(IR2VecTest, CreateSymbolicEmbedder) {
+ Vocab V = {{"foo", {1.0, 2.0}}};
+
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ EXPECT_TRUE(static_cast<bool>(Result));
+
+ auto *Emb = Result->get();
+ EXPECT_NE(Emb, nullptr);
+}
+
+TEST_F(IR2VecTest, CreateInvalidMode) {
+ Vocab V = {{"foo", {1.0, 2.0}}};
+
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ // static_cast an invalid int to IR2VecKind
+ auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
+ EXPECT_FALSE(static_cast<bool>(Result));
+
+ std::string ErrMsg;
+ llvm::handleAllErrors(
+ Result.takeError(),
+ [&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
+ EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
+}
+
+TEST_F(IR2VecTest, AddVectors) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {0.5, 1.5, -1.0};
+
+ TestableEmbedder::addVectors(E1, E2);
+ EXPECT_DOUBLE_EQ(E1[0], 1.5);
+ EXPECT_DOUBLE_EQ(E1[1], 3.5);
+ EXPECT_DOUBLE_EQ(E1[2], 2.0);
+
+ // Check that E2 is unchanged
+ EXPECT_DOUBLE_EQ(E2[0], 0.5);
+ EXPECT_DOUBLE_EQ(E2[1], 1.5);
+ EXPECT_DOUBLE_EQ(E2[2], -1.0);
+}
+
+TEST_F(IR2VecTest, AddScaledVector) {
+ Embedding E1 = {1.0, 2.0, 3.0};
+ Embedding E2 = {2.0, 0.5, -1.0};
+
+ TestableEmbedder::addScaledVector(E1, E2, 0.5f);
+ EXPECT_DOUBLE_EQ(E1[0], 2.0);
+ EXPECT_DOUBLE_EQ(E1[1], 2.25);
+ EXPECT_DOUBLE_EQ(E1[2], 2.5);
+
+ // Check that E2 is unchanged
+ EXPECT_DOUBLE_EQ(E2[0], 2.0);
+ EXPECT_DOUBLE_EQ(E2[1], 0.5);
+ EXPECT_DOUBLE_EQ(E2[2], -1.0);
+}
+
+#if GTEST_HAS_DEATH_TEST
+TEST_F(IR2VecTest, MismatchedDimensionsAddVectors) {
+ Embedding E1 = {1.0, 2.0};
+ Embedding E2 = {1.0};
+ EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
+ "Vectors must have the same dimension");
+}
+
+TEST_F(IR2VecTest, MismatchedDimensionsAddScaledVector) {
+ Embedding E1 = {1.0, 2.0};
+ Embedding E2 = {1.0};
+ EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+ "Vectors must have the same dimension");
+}
+#endif
+
+TEST_F(IR2VecTest, LookupVocab) {
+ Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ TestableEmbedder E(*F, V, 2);
+ auto V_foo = E.lookupVocab("foo");
+ EXPECT_EQ(V_foo.size(), 2u);
+ EXPECT_DOUBLE_EQ(V_foo[0], 1.0);
+ EXPECT_DOUBLE_EQ(V_foo[1], 2.0);
+
+ auto V_missing = E.lookupVocab("missing");
+ EXPECT_EQ(V_missing.size(), 2u);
+ EXPECT_DOUBLE_EQ(V_missing[0], 0.0);
+ EXPECT_DOUBLE_EQ(V_missing[1], 0.0);
+}
+
+TEST_F(IR2VecTest, ZeroDimensionEmbedding) {
+ Embedding E1;
+ Embedding E2;
+ // Should be no-op, but not crash
+ TestableEmbedder::addVectors(E1, E2);
+ TestableEmbedder::addScaledVector(E1, E2, 1.0f);
+ EXPECT_TRUE(E1.empty());
+}
+
+TEST_F(IR2VecTest, IR2VecVocabResultValidity) {
+ // Default constructed is invalid
+ IR2VecVocabResult invalidResult;
+ EXPECT_FALSE(invalidResult.isValid());
+#if GTEST_HAS_DEATH_TEST
+ EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
+ EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
+#endif
+
+ // Valid vocab
+ Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
+ IR2VecVocabResult validResult(std::move(V));
+ EXPECT_TRUE(validResult.isValid());
+ EXPECT_EQ(validResult.getDimension(), 2u);
+}
+
+// Helper to create a minimal function and embedder for getter tests
+struct GetterTestEnv {
+ Vocab V;
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M;
+ Function *F;
+ BasicBlock *BB;
+ Instruction *Add;
+ Instruction *Ret;
+ std::unique_ptr<Embedder> Emb;
+
+ GetterTestEnv() {
+ V = {{"add", {1.0, 2.0}},
+ {"integerTy", {0.5, 0.5}},
+ {"constant", {0.2, 0.3}},
+ {"variable", {0.0, 0.0}},
+ {"unknownTy", {0.0, 0.0}}};
+
+ M = std::make_unique<Module>("M", Ctx);
+ FunctionType *FTy = FunctionType::get(
+ Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
+ false);
+ F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
+ BB = BasicBlock::Create(Ctx, "entry", F);
+ Argument *Arg = F->getArg(0);
+ Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+
+ Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+ Ret = ReturnInst::Create(Ctx, Add, BB);
+
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ EXPECT_TRUE(static_cast<bool>(Result));
+ Emb = std::move(*Result);
+ }
+};
+
+TEST_F(IR2VecTest, GetInstVecMap) {
+ GetterTestEnv Env;
+ const auto &InstMap = Env.Emb->getInstVecMap();
+
+ EXPECT_EQ(InstMap.size(), 2u);
+ EXPECT_TRUE(InstMap.count(Env.Add));
+ EXPECT_TRUE(InstMap.count(Env.Ret));
+
+ EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
+ EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+
+ // Check values for add: {1.29, 2.31}
+ EXPECT_NEAR(InstMap.at(Env.Add)[0], 1.29, 1e-6);
+ EXPECT_NEAR(InstMap.at(Env.Add)[1], 2.31, 1e-6);
+
+ // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
+ // vocab
+ EXPECT_NEAR(InstMap.at(Env.Ret)[0], 0.0, 1e-6);
+ EXPECT_NEAR(InstMap.at(Env.Ret)[1], 0.0, 1e-6);
+}
+
+TEST_F(IR2VecTest, GetBBVecMap) {
+ GetterTestEnv Env;
+ const auto &BBMap = Env.Emb->getBBVecMap();
+
+ EXPECT_EQ(BBMap.size(), 1u);
+ EXPECT_TRUE(BBMap.count(Env.BB));
+ EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+
+ // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
+ // {1.29, 2.31}
+ EXPECT_NEAR(BBMap.at(Env.BB)[0], 1.29, 1e-6);
+ EXPECT_NEAR(BBMap.at(Env.BB)[1], 2.31, 1e-6);
+}
+
+TEST_F(IR2VecTest, GetFunctionVector) {
+ GetterTestEnv Env;
+ const auto &FuncVec = Env.Emb->getFunctionVector();
+
+ EXPECT_EQ(FuncVec.size(), 2u);
+
+ // Function vector should match BB vector (only one BB): {1.29, 2.31}
+ EXPECT_NEAR(FuncVec[0], 1.29, 1e-6);
+ EXPECT_NEAR(FuncVec[1], 2.31, 1e-6);
+}
+
+} // end anonymous namespace
diff --git a/llvm/utils/gn/secondary/llvm/unittests/Analysis/BUILD.gn b/llvm/utils/gn/secondary/llvm/unittests/Analysis/BUILD.gn
index 7b91aeae2e322..82411b8b5bdd6 100644
--- a/llvm/utils/gn/secondary/llvm/unittests/Analysis/BUILD.gn
+++ b/llvm/utils/gn/secondary/llvm/unittests/Analysis/BUILD.gn
@@ -31,6 +31,7 @@ unittest("AnalysisTests") {
"FunctionPropertiesAnalysisTest.cpp",
"GlobalsModRefTest.cpp",
"GraphWriterTest.cpp",
+ "IR2VecTest.cpp",
"IRSimilarityIdentifierTest.cpp",
"IVDescriptorsTest.cpp",
"InlineCostTest.cpp",
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, some nits.
eb95b9d
to
378ea53
Compare
378ea53
to
77440cb
Compare
77440cb
to
c51a650
Compare
Merge activity
|
This PR adds unit tests for IR2Vec (Tracking issue - llvm#141817)
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/64/builds/3921 Here is the relevant piece of the build log for the reference
|
This PR adds unit tests for IR2Vec (Tracking issue - llvm#141817)
This PR adds unit tests for IR2Vec
(Tracking issue - #141817)