31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/Compiler.h"
36
37
#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
43
44
class BasicBlock ;
44
45
class Instruction ;
45
46
class Function ;
46
- class Type ;
47
47
class Value ;
48
48
class raw_ostream ;
49
49
class LLVMContext ;
50
+ class IR2VecVocabAnalysis ;
50
51
51
52
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -125,9 +126,73 @@ struct Embedding {
125
126
126
127
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
127
128
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
128
- // FIXME: Current the keys are strings. This can be changed to
129
- // use integers for cheaper lookups.
130
- using Vocab = std::map<std::string, Embedding>;
129
+
130
+ // / Class for storing and accessing the IR2Vec vocabulary.
131
+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
132
+ class Vocabulary {
133
+ friend class llvm ::IR2VecVocabAnalysis;
134
+ using VocabVector = std::vector<ir2vec::Embedding>;
135
+ VocabVector Vocab;
136
+ bool Valid = false ;
137
+
138
+ // / Operand kinds supported by IR2Vec Vocabulary
139
+ #define OPERAND_KINDS \
140
+ OPERAND_KIND (FunctionID, " Function" ) \
141
+ OPERAND_KIND (PointerID, " Pointer" ) \
142
+ OPERAND_KIND (ConstantID, " Constant" ) \
143
+ OPERAND_KIND (VariableID, " Variable" )
144
+
145
+ enum class OperandKind : unsigned {
146
+ #define OPERAND_KIND (Name, Str ) Name,
147
+ OPERAND_KINDS
148
+ #undef OPERAND_KIND
149
+ MaxOperandKind
150
+ };
151
+
152
+ #undef OPERAND_KINDS
153
+
154
+ // / Vocabulary layout constants
155
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
156
+ #include " llvm/IR/Instruction.def"
157
+ #undef LAST_OTHER_INST
158
+
159
+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
160
+ static constexpr unsigned MaxOperandKinds =
161
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
162
+
163
+ // / Helper function to get vocabulary key for a given OperandKind
164
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
165
+
166
+ // / Helper function to classify an operand into OperandKind
167
+ static OperandKind getOperandKind (const Value *Op);
168
+
169
+ // / Helper function to get vocabulary key for a given TypeID
170
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
171
+
172
+ public:
173
+ Vocabulary () = default ;
174
+ Vocabulary (VocabVector &&Vocab);
175
+
176
+ bool isValid () const ;
177
+ unsigned getDimension () const ;
178
+ unsigned size () const ;
179
+
180
+ const ir2vec::Embedding &at (unsigned Position) const ;
181
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
182
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
183
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
184
+
185
+ // / Returns the string key for a given index position in the vocabulary.
186
+ // / This is useful for debugging or printing the vocabulary. Do not use this
187
+ // / for embedding generation as string based lookups are inefficient.
188
+ static StringRef getStringKey (unsigned Pos);
189
+
190
+ // / Create a dummy vocabulary for testing purposes.
191
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
192
+
193
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
194
+ ModuleAnalysisManager::Invalidator &Inv) const ;
195
+ };
131
196
132
197
// / Embedder provides the interface to generate embeddings (vector
133
198
// / representations) for instructions, basic blocks, and functions. The
@@ -138,7 +203,7 @@ using Vocab = std::map<std::string, Embedding>;
138
203
class Embedder {
139
204
protected:
140
205
const Function &F;
141
- const Vocab &Vocabulary ;
206
+ const Vocabulary &Vocab ;
142
207
143
208
// / Dimension of the vector representation; captured from the input vocabulary
144
209
const unsigned Dimension;
@@ -153,7 +218,7 @@ class Embedder {
153
218
mutable BBEmbeddingsMap BBVecMap;
154
219
mutable InstEmbeddingsMap InstVecMap;
155
220
156
- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
221
+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
157
222
158
223
// / Helper function to compute embeddings. It generates embeddings for all
159
224
// / the instructions and basic blocks in the function F. Logic of computing
@@ -164,16 +229,12 @@ class Embedder {
164
229
// / Specific to the kind of embeddings being computed.
165
230
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
166
231
167
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
168
- // / zero vector.
169
- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
170
-
171
232
public:
172
233
virtual ~Embedder () = default ;
173
234
174
235
// / Factory method to create an Embedder object.
175
236
LLVM_ABI static std::unique_ptr<Embedder>
176
- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
237
+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
177
238
178
239
// / Returns a map containing instructions and the corresponding embeddings for
179
240
// / the function F if it has been computed. If not, it computes the embeddings
@@ -199,56 +260,40 @@ class Embedder {
199
260
// / representations obtained from the Vocabulary.
200
261
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
  // Overrides of the Embedder hooks: one for the whole function, one for a
  // single basic block. Per-type and per-operand lookups are no longer
  // declared here; the Vocabulary object is indexed directly instead.
  void computeEmbeddings() const override;
  void computeEmbeddings(const BasicBlock &BB) const override;

public:
  /// Builds a symbolic embedder for \p F backed by \p Vocab.
  /// The base class captures both; this constructor only resets the
  /// function-level vector.
  SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
      : Embedder(F, Vocab) {
    // Presumably a Dimension-sized vector of zeros; Embedding's constructor
    // is declared outside this view.
    FuncVector = Embedding(Dimension, 0);
  }
};
217
272
218
273
} // namespace ir2vec
219
274
220
- // / Class for storing the result of the IR2VecVocabAnalysis.
221
- class IR2VecVocabResult {
222
- ir2vec::Vocab Vocabulary;
223
- bool Valid = false ;
224
-
225
- public:
226
- IR2VecVocabResult () = default ;
227
- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
228
-
229
- bool isValid () const { return Valid; }
230
- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
231
- LLVM_ABI unsigned getDimension () const ;
232
- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
233
- ModuleAnalysisManager::Invalidator &Inv) const ;
234
- };
235
-
236
275
// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
237
276
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
238
277
// / its corresponding embedding.
239
278
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
240
- ir2vec::Vocab Vocabulary;
279
+ using VocabVector = std::vector<ir2vec::Embedding>;
280
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
281
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
282
+ VocabVector Vocab;
283
+
284
+ unsigned Dim = 0 ;
241
285
Error readVocabulary ();
242
286
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
243
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
287
+ VocabMap &TargetVocab, unsigned &Dim);
288
+ void generateNumMappedVocab ();
244
289
void emitError (Error Err, LLVMContext &Ctx);
245
290
246
291
public:
247
292
LLVM_ABI static AnalysisKey Key;
248
293
IR2VecVocabAnalysis () = default ;
249
- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
250
- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
251
- using Result = IR2VecVocabResult ;
294
+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
295
+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
296
+ using Result = ir2vec::Vocabulary ;
252
297
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
253
298
};
254
299
0 commit comments