31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/Compiler.h"
36
37
#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
43
44
class BasicBlock ;
44
45
class Instruction ;
45
46
class Function ;
46
- class Type ;
47
47
class Value ;
48
48
class raw_ostream ;
49
49
class LLVMContext ;
50
+ class IR2VecVocabAnalysis ;
50
51
51
52
// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
52
53
// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -126,9 +127,73 @@ struct Embedding {
126
127
127
128
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
128
129
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
129
- // FIXME: Current the keys are strings. This can be changed to
130
- // use integers for cheaper lookups.
131
- using Vocab = std::map<std::string, Embedding>;
130
+
131
+ // / Class for storing and accessing the IR2Vec vocabulary.
132
+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
133
+ class Vocabulary {
134
+ friend class llvm ::IR2VecVocabAnalysis;
135
+ using VocabVector = std::vector<ir2vec::Embedding>;
136
+ VocabVector Vocab;
137
+ bool Valid = false ;
138
+
139
+ // / Operand kinds supported by IR2Vec Vocabulary
140
+ #define OPERAND_KINDS \
141
+ OPERAND_KIND (FunctionID, " Function" ) \
142
+ OPERAND_KIND (PointerID, " Pointer" ) \
143
+ OPERAND_KIND (ConstantID, " Constant" ) \
144
+ OPERAND_KIND (VariableID, " Variable" )
145
+
146
+ enum class OperandKind : unsigned {
147
+ #define OPERAND_KIND (Name, Str ) Name,
148
+ OPERAND_KINDS
149
+ #undef OPERAND_KIND
150
+ MaxOperandKind
151
+ };
152
+
153
+ #undef OPERAND_KINDS
154
+
155
+ // / Vocabulary layout constants
156
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
157
+ #include " llvm/IR/Instruction.def"
158
+ #undef LAST_OTHER_INST
159
+
160
+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
161
+ static constexpr unsigned MaxOperandKinds =
162
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
163
+
164
+ // / Helper function to get vocabulary key for a given OperandKind
165
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
166
+
167
+ // / Helper function to classify an operand into OperandKind
168
+ static OperandKind getOperandKind (const Value *Op);
169
+
170
+ // / Helper function to get vocabulary key for a given TypeID
171
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
172
+
173
+ public:
174
+ Vocabulary () = default ;
175
+ Vocabulary (VocabVector &&Vocab);
176
+
177
+ bool isValid () const ;
178
+ unsigned getDimension () const ;
179
+ unsigned size () const ;
180
+
181
+ const ir2vec::Embedding &at (unsigned Position) const ;
182
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
183
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
184
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
185
+
186
+ // / Returns the string key for a given index position in the vocabulary.
187
+ // / This is useful for debugging or printing the vocabulary. Do not use this
188
+ // / for embedding generation as string based lookups are inefficient.
189
+ static StringRef getStringKey (unsigned Pos);
190
+
191
+ // / Create a dummy vocabulary for testing purposes.
192
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
193
+
194
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
195
+ ModuleAnalysisManager::Invalidator &Inv) const ;
196
+ };
132
197
133
198
// / Embedder provides the interface to generate embeddings (vector
134
199
// / representations) for instructions, basic blocks, and functions. The
@@ -139,7 +204,7 @@ using Vocab = std::map<std::string, Embedding>;
139
204
class Embedder {
140
205
protected:
141
206
const Function &F;
142
- const Vocab &Vocabulary ;
207
+ const Vocabulary &Vocab ;
143
208
144
209
// / Dimension of the vector representation; captured from the input vocabulary
145
210
const unsigned Dimension;
@@ -154,7 +219,7 @@ class Embedder {
154
219
mutable BBEmbeddingsMap BBVecMap;
155
220
mutable InstEmbeddingsMap InstVecMap;
156
221
157
- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
222
+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
158
223
159
224
// / Helper function to compute embeddings. It generates embeddings for all
160
225
// / the instructions and basic blocks in the function F. Logic of computing
@@ -165,16 +230,12 @@ class Embedder {
165
230
// / Specific to the kind of embeddings being computed.
166
231
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
167
232
168
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
169
- // / zero vector.
170
- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
171
-
172
233
public:
173
234
virtual ~Embedder () = default ;
174
235
175
236
// / Factory method to create an Embedder object.
176
237
LLVM_ABI static std::unique_ptr<Embedder>
177
- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
238
+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
178
239
179
240
// / Returns a map containing instructions and the corresponding embeddings for
180
241
// / the function F if it has been computed. If not, it computes the embeddings
@@ -200,56 +261,40 @@ class Embedder {
200
261
// / representations obtained from the Vocabulary.
201
262
class LLVM_ABI SymbolicEmbedder : public Embedder {
202
263
private:
203
- // / Utility function to compute the embedding for a given type.
204
- Embedding getTypeEmbedding (const Type *Ty) const ;
205
-
206
- // / Utility function to compute the embedding for a given operand.
207
- Embedding getOperandEmbedding (const Value *Op) const ;
208
-
209
264
void computeEmbeddings () const override ;
210
265
void computeEmbeddings (const BasicBlock &BB) const override ;
211
266
212
267
public:
213
- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
214
- : Embedder(F, Vocabulary ) {
268
+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
269
+ : Embedder(F, Vocab ) {
215
270
FuncVector = Embedding (Dimension, 0 );
216
271
}
217
272
};
218
273
219
274
} // namespace ir2vec
220
275
221
- // / Class for storing the result of the IR2VecVocabAnalysis.
222
- class IR2VecVocabResult {
223
- ir2vec::Vocab Vocabulary;
224
- bool Valid = false ;
225
-
226
- public:
227
- IR2VecVocabResult () = default ;
228
- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
229
-
230
- bool isValid () const { return Valid; }
231
- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
232
- LLVM_ABI unsigned getDimension () const ;
233
- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
234
- ModuleAnalysisManager::Invalidator &Inv) const ;
235
- };
236
-
237
276
// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
238
277
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
239
278
// / its corresponding embedding.
240
279
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
241
- ir2vec::Vocab Vocabulary;
280
+ using VocabVector = std::vector<ir2vec::Embedding>;
281
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
282
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
283
+ VocabVector Vocab;
284
+
285
+ unsigned Dim = 0 ;
242
286
Error readVocabulary ();
243
287
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
244
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
288
+ VocabMap &TargetVocab, unsigned &Dim);
289
+ void generateNumMappedVocab ();
245
290
void emitError (Error Err, LLVMContext &Ctx);
246
291
247
292
public:
248
293
LLVM_ABI static AnalysisKey Key;
249
294
IR2VecVocabAnalysis () = default ;
250
- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
251
- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
252
- using Result = IR2VecVocabResult ;
295
+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
296
+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
297
+ using Result = ir2vec::Vocabulary ;
253
298
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
254
299
};
255
300
0 commit comments