-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[IR2Vec] Restructuring Vocabulary #145119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
|
||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/IR/Type.h" | ||
#include "llvm/Support/CommandLine.h" | ||
#include "llvm/Support/Compiler.h" | ||
#include "llvm/Support/ErrorOr.h" | ||
|
@@ -43,10 +44,10 @@ class Module; | |
class BasicBlock; | ||
class Instruction; | ||
class Function; | ||
class Type; | ||
class Value; | ||
class raw_ostream; | ||
class LLVMContext; | ||
class IR2VecVocabAnalysis; | ||
|
||
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware. | ||
/// Symbolic embeddings capture the "syntactic" and "statistical correlation" | ||
|
@@ -128,9 +129,73 @@ struct Embedding { | |
|
||
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>; | ||
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>; | ||
// FIXME: Currently the keys are strings. This can be changed to | ||
// use integers for cheaper lookups. | ||
using Vocab = std::map<std::string, Embedding>; | ||
|
||
/// Class for storing and accessing the IR2Vec vocabulary. | ||
/// Encapsulates all vocabulary-related constants, logic, and access methods. | ||
class Vocabulary { | ||
friend class llvm::IR2VecVocabAnalysis; | ||
using VocabVector = std::vector<ir2vec::Embedding>; | ||
VocabVector Vocab; | ||
bool Valid = false; | ||
|
||
/// Operand kinds supported by IR2Vec Vocabulary | ||
#define OPERAND_KINDS \ | ||
OPERAND_KIND(FunctionID, "Function") \ | ||
OPERAND_KIND(PointerID, "Pointer") \ | ||
OPERAND_KIND(ConstantID, "Constant") \ | ||
OPERAND_KIND(VariableID, "Variable") | ||
|
||
enum class OperandKind : unsigned { | ||
#define OPERAND_KIND(Name, Str) Name, | ||
OPERAND_KINDS | ||
#undef OPERAND_KIND | ||
MaxOperandKind | ||
}; | ||
|
||
#undef OPERAND_KINDS | ||
|
||
/// Vocabulary layout constants | ||
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM; | ||
#include "llvm/IR/Instruction.def" | ||
#undef LAST_OTHER_INST | ||
|
||
static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: |
||
static constexpr unsigned MaxOperandKinds = | ||
static_cast<unsigned>(OperandKind::MaxOperandKind); | ||
|
||
/// Helper function to get vocabulary key for a given OperandKind | ||
static StringRef getVocabKeyForOperandKind(OperandKind Kind); | ||
|
||
/// Helper function to classify an operand into OperandKind | ||
static OperandKind getOperandKind(const Value *Op); | ||
|
||
/// Helper function to get vocabulary key for a given TypeID | ||
static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); | ||
|
||
public: | ||
Vocabulary() = default; | ||
Vocabulary(VocabVector &&Vocab); | ||
|
||
bool isValid() const; | ||
unsigned getDimension() const; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
unsigned size() const; | ||
|
||
const ir2vec::Embedding &at(unsigned Position) const; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. do you need |
||
const ir2vec::Embedding &operator[](unsigned Opcode) const; | ||
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const; | ||
const ir2vec::Embedding &operator[](const Value *Arg) const; | ||
|
||
/// Returns the string key for a given index position in the vocabulary. | ||
/// This is useful for debugging or printing the vocabulary. Do not use this | ||
/// for embedding generation as string based lookups are inefficient. | ||
static StringRef getStringKey(unsigned Pos); | ||
|
||
/// Create a dummy vocabulary for testing purposes. | ||
static VocabVector createDummyVocabForTest(unsigned Dim = 1); | ||
|
||
bool invalidate(Module &M, const PreservedAnalyses &PA, | ||
ModuleAnalysisManager::Invalidator &Inv) const; | ||
}; | ||
|
||
/// Embedder provides the interface to generate embeddings (vector | ||
/// representations) for instructions, basic blocks, and functions. The | ||
|
@@ -141,7 +206,7 @@ using Vocab = std::map<std::string, Embedding>; | |
class Embedder { | ||
protected: | ||
const Function &F; | ||
const Vocab &Vocabulary; | ||
const Vocabulary &Vocab; | ||
|
||
/// Dimension of the vector representation; captured from the input vocabulary | ||
const unsigned Dimension; | ||
|
@@ -156,7 +221,7 @@ class Embedder { | |
mutable BBEmbeddingsMap BBVecMap; | ||
mutable InstEmbeddingsMap InstVecMap; | ||
|
||
LLVM_ABI Embedder(const Function &F, const Vocab &Vocabulary); | ||
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab); | ||
|
||
/// Helper function to compute embeddings. It generates embeddings for all | ||
/// the instructions and basic blocks in the function F. Logic of computing | ||
|
@@ -167,16 +232,12 @@ class Embedder { | |
/// Specific to the kind of embeddings being computed. | ||
virtual void computeEmbeddings(const BasicBlock &BB) const = 0; | ||
|
||
/// Lookup vocabulary for a given Key. If the key is not found, it returns a | ||
/// zero vector. | ||
LLVM_ABI Embedding lookupVocab(const std::string &Key) const; | ||
|
||
public: | ||
virtual ~Embedder() = default; | ||
|
||
/// Factory method to create an Embedder object. | ||
LLVM_ABI static std::unique_ptr<Embedder> | ||
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary); | ||
create(IR2VecKind Mode, const Function &F, const Vocabulary &Vocab); | ||
|
||
/// Returns a map containing instructions and the corresponding embeddings for | ||
/// the function F if it has been computed. If not, it computes the embeddings | ||
|
@@ -202,56 +263,40 @@ class Embedder { | |
/// representations obtained from the Vocabulary. | ||
class LLVM_ABI SymbolicEmbedder : public Embedder { | ||
private: | ||
/// Utility function to compute the embedding for a given type. | ||
Embedding getTypeEmbedding(const Type *Ty) const; | ||
|
||
/// Utility function to compute the embedding for a given operand. | ||
Embedding getOperandEmbedding(const Value *Op) const; | ||
|
||
void computeEmbeddings() const override; | ||
void computeEmbeddings(const BasicBlock &BB) const override; | ||
|
||
public: | ||
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary) | ||
: Embedder(F, Vocabulary) { | ||
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab) | ||
: Embedder(F, Vocab) { | ||
FuncVector = Embedding(Dimension, 0); | ||
} | ||
}; | ||
|
||
} // namespace ir2vec | ||
|
||
/// Class for storing the result of the IR2VecVocabAnalysis. | ||
class IR2VecVocabResult { | ||
ir2vec::Vocab Vocabulary; | ||
bool Valid = false; | ||
|
||
public: | ||
IR2VecVocabResult() = default; | ||
LLVM_ABI IR2VecVocabResult(ir2vec::Vocab &&Vocabulary); | ||
|
||
bool isValid() const { return Valid; } | ||
LLVM_ABI const ir2vec::Vocab &getVocabulary() const; | ||
LLVM_ABI unsigned getDimension() const; | ||
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA, | ||
ModuleAnalysisManager::Invalidator &Inv) const; | ||
}; | ||
|
||
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a | ||
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and | ||
/// its corresponding embedding. | ||
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> { | ||
ir2vec::Vocab Vocabulary; | ||
using VocabVector = std::vector<ir2vec::Embedding>; | ||
using VocabMap = std::map<std::string, ir2vec::Embedding>; | ||
VocabMap OpcVocab, TypeVocab, ArgVocab; | ||
VocabVector Vocab; | ||
|
||
unsigned Dim = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. needed? isn't it a property of |
||
Error readVocabulary(); | ||
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue, | ||
ir2vec::Vocab &TargetVocab, unsigned &Dim); | ||
VocabMap &TargetVocab, unsigned &Dim); | ||
void generateNumMappedVocab(); | ||
void emitError(Error Err, LLVMContext &Ctx); | ||
|
||
public: | ||
LLVM_ABI static AnalysisKey Key; | ||
IR2VecVocabAnalysis() = default; | ||
LLVM_ABI explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab); | ||
LLVM_ABI explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab); | ||
using Result = IR2VecVocabResult; | ||
LLVM_ABI explicit IR2VecVocabAnalysis(const VocabVector &Vocab); | ||
LLVM_ABI explicit IR2VecVocabAnalysis(VocabVector &&Vocab); | ||
using Result = ir2vec::Vocabulary; | ||
LLVM_ABI Result run(Module &M, ModuleAnalysisManager &MAM); | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can `const ir2vec::Vocabulary * const;` - the value is const, potentially `nullptr`, and also the field is set-once, right?