@@ -80,12 +80,17 @@ class Embedder {
80
80
81
81
// Utility maps - these are used to store the vector representations of
82
82
// instructions, basic blocks and functions.
83
- Embedding FuncVector;
84
- BBEmbeddingsMap BBVecMap;
85
- InstEmbeddingsMap InstVecMap;
83
+ mutable Embedding FuncVector;
84
+ mutable BBEmbeddingsMap BBVecMap;
85
+ mutable InstEmbeddingsMap InstVecMap;
86
86
87
87
Embedder (const Function &F, const Vocab &Vocabulary, unsigned Dimension);
88
88
89
+ // / Helper function to compute embeddings. It generates embeddings for all
90
+ // / the instructions and basic blocks in the function F. Logic of computing
91
+ // / the embeddings is specific to the kind of embeddings being computed.
92
+ virtual void computeEmbeddings () const = 0;
93
+
89
94
// / Lookup vocabulary for a given Key. If the key is not found, it returns a
90
95
// / zero vector.
91
96
Embedding lookupVocab (const std::string &Key) const ;
@@ -100,25 +105,24 @@ class Embedder {
100
105
public:
101
106
virtual ~Embedder () = default ;
102
107
103
- // / Top level function to compute embeddings. It generates embeddings for all
104
- // / the instructions and basic blocks in the function F. Logic of computing
105
- // / the embeddings is specific to the kind of embeddings being computed.
106
- virtual void computeEmbeddings () = 0;
107
-
108
108
// / Factory method to create an Embedder object.
109
109
static Expected<std::unique_ptr<Embedder>> create (IR2VecKind Mode,
110
110
const Function &F,
111
111
const Vocab &Vocabulary,
112
112
unsigned Dimension);
113
113
114
- // / Returns a map containing instructions and the corresponding embeddings.
115
- const InstEmbeddingsMap &getInstVecMap () const { return InstVecMap; }
114
+ // / Returns a map containing instructions and the corresponding embeddings for
115
+ // / the function F if it has been computed. If not, it computes the embeddings
116
+ // / for the function and returns the map.
117
+ const InstEmbeddingsMap &getInstVecMap () const ;
116
118
117
- // / Returns a map containing basic block and the corresponding embeddings.
118
- const BBEmbeddingsMap &getBBVecMap () const { return BBVecMap; }
119
+ // / Returns a map containing basic block and the corresponding embeddings for
120
+ // / the function F if it has been computed. If not, it computes the embeddings
121
+ // / for the function and returns the map.
122
+ const BBEmbeddingsMap &getBBVecMap () const ;
119
123
120
- // / Returns the embedding for the current function.
121
- const Embedding &getFunctionVector () const { return FuncVector; }
124
+ // / Computes and returns the embedding for the current function.
125
+ const Embedding &getFunctionVector () const ;
122
126
};
123
127
124
128
// / Class for computing the Symbolic embeddings of IR2Vec.
@@ -127,21 +131,22 @@ class Embedder {
127
131
class SymbolicEmbedder : public Embedder {
128
132
private:
129
133
// / Utility function to compute the embedding for a given basic block.
130
- Embedding computeBB2Vec (const BasicBlock &BB);
134
+ Embedding computeBB2Vec (const BasicBlock &BB) const ;
131
135
132
136
// / Utility function to compute the embedding for a given type.
133
137
Embedding getTypeEmbedding (const Type *Ty) const ;
134
138
135
139
// / Utility function to compute the embedding for a given operand.
136
140
Embedding getOperandEmbedding (const Value *Op) const ;
137
141
142
+ void computeEmbeddings () const override ;
143
+
138
144
public:
139
145
SymbolicEmbedder (const Function &F, const Vocab &Vocabulary,
140
146
unsigned Dimension)
141
147
: Embedder(F, Vocabulary, Dimension) {
142
148
FuncVector = Embedding (Dimension, 0 );
143
149
}
144
- void computeEmbeddings () override ;
145
150
};
146
151
147
152
} // namespace ir2vec
0 commit comments