Skip to content

[IR2Vec] Scale embeddings once in vocab analysis instead of repetitive scaling #143986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion llvm/docs/MLGO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,16 @@ downstream tasks, including ML-guided compiler optimizations.

The core components are:
- **Vocabulary**: A mapping from IR entities (opcodes, types, etc.) to their
vector representations. This is managed by ``IR2VecVocabAnalysis``.
vector representations. This is managed by ``IR2VecVocabAnalysis``. The
vocabulary (.json file) contains three sections -- Opcodes, Types, and
Arguments, each containing the representations of the corresponding
entities.

.. note::

It is mandatory to have these three sections present in the vocabulary file
for it to be valid; order in which they appear does not matter.

- **Embedder**: A class (``ir2vec::Embedder``) that uses the vocabulary to
compute embeddings for instructions, basic blocks, and functions.

Expand Down
23 changes: 18 additions & 5 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,19 @@ struct Embedding {
const std::vector<double> &getData() const { return Data; }

/// Arithmetic operators
LLVM_ABI Embedding &operator+=(const Embedding &RHS);
LLVM_ABI Embedding &operator-=(const Embedding &RHS);
Embedding &operator+=(const Embedding &RHS);
Embedding &operator-=(const Embedding &RHS);
Embedding &operator*=(double Factor);

/// Adds Src Embedding scaled by Factor with the called Embedding.
/// Called_Embedding += Src * Factor
LLVM_ABI Embedding &scaleAndAdd(const Embedding &Src, float Factor);

/// Returns true if the embedding is approximately equal to the RHS embedding
/// within the specified tolerance.
LLVM_ABI bool approximatelyEquals(const Embedding &RHS,
double Tolerance = 1e-6) const;
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;

void print(raw_ostream &OS) const;
};

using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
Expand Down Expand Up @@ -236,6 +238,8 @@ class IR2VecVocabResult {
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
ir2vec::Vocab Vocabulary;
Error readVocabulary();
Error parseVocabSection(StringRef Key, const json::Value &ParsedVocabValue,
ir2vec::Vocab &TargetVocab, unsigned &Dim);
void emitError(Error Err, LLVMContext &Ctx);

public:
Expand All @@ -251,14 +255,23 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
/// functions.
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
raw_ostream &OS;
void printVector(const ir2vec::Embedding &Vec) const;

public:
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
LLVM_ABI PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
};

/// This pass prints the embeddings in the vocabulary
class IR2VecVocabPrinterPass : public PassInfoMixin<IR2VecVocabPrinterPass> {
raw_ostream &OS;

public:
explicit IR2VecVocabPrinterPass(raw_ostream &OS) : OS(OS) {}
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
static bool isRequired() { return true; }
};

} // namespace llvm

#endif // LLVM_ANALYSIS_IR2VEC_H
136 changes: 97 additions & 39 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
return *this;
}

Embedding &Embedding::operator*=(double Factor) {
std::transform(this->begin(), this->end(), this->begin(),
[Factor](double Elem) { return Elem * Factor; });
return *this;
}

Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
assert(this->size() == Src.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
Expand All @@ -101,6 +107,13 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
return true;
}

void Embedding::print(raw_ostream &OS) const {
OS << " [";
for (const auto &Elem : Data)
OS << " " << format("%.2f", Elem) << " ";
OS << "]\n";
}

// ==----------------------------------------------------------------------===//
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -196,18 +209,12 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding InstVector(Dimension, 0);

const auto OpcVec = lookupVocab(I.getOpcodeName());
InstVector.scaleAndAdd(OpcVec, OpcWeight);

// FIXME: Currently lookups are string based. Use numeric Keys
// for efficiency.
const auto Type = I.getType();
const auto TypeVec = getTypeEmbedding(Type);
InstVector.scaleAndAdd(TypeVec, TypeWeight);

InstVector += lookupVocab(I.getOpcodeName());
InstVector += getTypeEmbedding(I.getType());
for (const auto &Op : I.operands()) {
const auto OperandVec = getOperandEmbedding(Op.get());
InstVector.scaleAndAdd(OperandVec, ArgWeight);
InstVector += getOperandEmbedding(Op.get());
}
InstVecMap[&I] = InstVector;
BBVector += InstVector;
Expand Down Expand Up @@ -251,6 +258,43 @@ bool IR2VecVocabResult::invalidate(
return !(PAC.preservedWhenStateless());
}

Error IR2VecVocabAnalysis::parseVocabSection(
StringRef Key, const json::Value &ParsedVocabValue,
ir2vec::Vocab &TargetVocab, unsigned &Dim) {
json::Path::Root Path("");
const json::Object *RootObj = ParsedVocabValue.getAsObject();
if (!RootObj)
return createStringError(errc::invalid_argument,
"JSON root is not an object");

const json::Value *SectionValue = RootObj->get(Key);
if (!SectionValue)
return createStringError(errc::invalid_argument,
"Missing '" + std::string(Key) +
"' section in vocabulary file");
if (!json::fromJSON(*SectionValue, TargetVocab, Path))
return createStringError(errc::illegal_byte_sequence,
"Unable to parse '" + std::string(Key) +
"' section from vocabulary");

Dim = TargetVocab.begin()->second.size();
if (Dim == 0)
return createStringError(errc::illegal_byte_sequence,
"Dimension of '" + std::string(Key) +
"' section of the vocabulary is zero");

if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
[Dim](const std::pair<StringRef, Embedding> &Entry) {
return Entry.second.size() == Dim;
}))
return createStringError(
errc::illegal_byte_sequence,
"All vectors in the '" + std::string(Key) +
"' section of the vocabulary are not of the same dimension");

return Error::success();
};

// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
Error IR2VecVocabAnalysis::readVocabulary() {
Expand All @@ -259,32 +303,40 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return createFileError(VocabFile, BufOrError.getError());

auto Content = BufOrError.get()->getBuffer();
json::Path::Root Path("");

Expected<json::Value> ParsedVocabValue = json::parse(Content);
if (!ParsedVocabValue)
return ParsedVocabValue.takeError();

bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
if (!Res)
return createStringError(errc::illegal_byte_sequence,
"Unable to parse the vocabulary");
ir2vec::Vocab OpcodeVocab, TypeVocab, ArgVocab;
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
if (auto Err = parseVocabSection("Opcodes", *ParsedVocabValue, OpcodeVocab,
OpcodeDim))
return Err;

if (Vocabulary.empty())
return createStringError(errc::illegal_byte_sequence,
"Vocabulary is empty");
if (auto Err =
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
return Err;

unsigned Dim = Vocabulary.begin()->second.size();
if (Dim == 0)
if (auto Err =
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
return Err;

if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
return createStringError(errc::illegal_byte_sequence,
"Dimension of vocabulary is zero");
"Vocabulary sections have different dimensions");

if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
[Dim](const std::pair<StringRef, Embedding> &Entry) {
return Entry.second.size() == Dim;
}))
return createStringError(
errc::illegal_byte_sequence,
"All vectors in the vocabulary are not of the same dimension");
auto scaleVocabSection = [](ir2vec::Vocab &Vocab, double Weight) {
for (auto &Entry : Vocab)
Entry.second *= Weight;
};
scaleVocabSection(OpcodeVocab, OpcWeight);
scaleVocabSection(TypeVocab, TypeWeight);
scaleVocabSection(ArgVocab, ArgWeight);

Vocabulary.insert(OpcodeVocab.begin(), OpcodeVocab.end());
Vocabulary.insert(TypeVocab.begin(), TypeVocab.end());
Vocabulary.insert(ArgVocab.begin(), ArgVocab.end());

return Error::success();
}
Expand All @@ -304,7 +356,6 @@ void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
// If vocabulary is already populated by the constructor, use it.
if (!Vocabulary.empty())
return IR2VecVocabResult(std::move(Vocabulary));
Expand All @@ -323,16 +374,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
}

// ==----------------------------------------------------------------------===//
// IR2VecPrinterPass
// Printer Passes
//===----------------------------------------------------------------------===//

void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
OS << " [";
for (const auto &Elem : Vec)
OS << " " << format("%.2f", Elem) << " ";
OS << "]\n";
}

PreservedAnalyses IR2VecPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
Expand All @@ -353,15 +397,15 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,

OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
OS << "Function vector: ";
printVector(Emb->getFunctionVector());
Emb->getFunctionVector().print(OS);

OS << "Basic block vectors:\n";
const auto &BBMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
auto It = BBMap.find(&BB);
if (It != BBMap.end()) {
OS << "Basic block: " << BB.getName() << ":\n";
printVector(It->second);
It->second.print(OS);
}
}

Expand All @@ -373,10 +417,24 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
if (It != InstMap.end()) {
OS << "Instruction: ";
I.print(OS);
printVector(It->second);
It->second.print(OS);
}
}
}
}
return PreservedAnalyses::all();
}

PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
ModuleAnalysisManager &MAM) {
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");

auto Vocab = IR2VecVocabResult.getVocabulary();
for (const auto &Entry : Vocab) {
OS << "Key: " << Entry.first << ": ";
Entry.second.print(OS);
}

return PreservedAnalyses::all();
}
Loading
Loading