@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
52
52
cl::opt<float > ArgWeight (" ir2vec-arg-weight" , cl::Optional, cl::init(0.2 ),
53
53
cl::desc(" Weight for argument embeddings" ),
54
54
cl::cat(IR2VecCategory));
55
+ cl::opt<IR2VecKind> IR2VecEmbeddingKind (
56
+ " ir2vec-kind" , cl::Optional,
57
+ cl::values (clEnumValN(IR2VecKind::Symbolic, " symbolic" ,
58
+ " Generate symbolic embeddings" ),
59
+ clEnumValN(IR2VecKind::FlowAware, " flow-aware" ,
60
+ " Generate flow-aware embeddings" )),
61
+ cl::init(IR2VecKind::Symbolic), cl::desc(" IR2Vec embedding kind" ),
62
+ cl::cat(IR2VecCategory));
63
+
55
64
} // namespace ir2vec
56
65
} // namespace llvm
57
66
@@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
123
132
double Tolerance) const {
124
133
assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
125
134
for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
126
- if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance)
135
+ if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance) {
136
+ LLVM_DEBUG (errs () << " Embedding mismatch at index " << Itr << " : "
137
+ << (*this )[Itr] << " vs " << RHS[Itr]
138
+ << " ; Tolerance: " << Tolerance << " \n " );
127
139
return false ;
140
+ }
128
141
return true ;
129
142
}
130
143
@@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const {
141
154
142
155
Embedder::Embedder (const Function &F, const Vocabulary &Vocab)
143
156
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
144
- OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
145
- }
157
+ OpcWeight (::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
158
+ FuncVector(Embedding(Dimension, 0 )) { }
146
159
147
160
std::unique_ptr<Embedder> Embedder::create (IR2VecKind Mode, const Function &F,
148
161
const Vocabulary &Vocab) {
149
162
switch (Mode) {
150
163
case IR2VecKind::Symbolic:
151
164
return std::make_unique<SymbolicEmbedder>(F, Vocab);
165
+ case IR2VecKind::FlowAware:
166
+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
152
167
}
153
168
return nullptr ;
154
169
}
@@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
180
195
return FuncVector;
181
196
}
182
197
198
+ void Embedder::computeEmbeddings () const {
199
+ if (F.isDeclaration ())
200
+ return ;
201
+
202
+ // Consider only the basic blocks that are reachable from entry
203
+ for (const BasicBlock *BB : depth_first (&F)) {
204
+ computeEmbeddings (*BB);
205
+ FuncVector += BBVecMap[BB];
206
+ }
207
+ }
208
+
183
209
void SymbolicEmbedder::computeEmbeddings (const BasicBlock &BB) const {
184
210
Embedding BBVector (Dimension, 0 );
185
211
@@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
196
222
BBVecMap[&BB] = BBVector;
197
223
}
198
224
199
- void SymbolicEmbedder::computeEmbeddings () const {
200
- if (F.isDeclaration ())
201
- return ;
225
+ void FlowAwareEmbedder::computeEmbeddings (const BasicBlock &BB) const {
226
+ Embedding BBVector (Dimension, 0 );
202
227
203
- // Consider only the basic blocks that are reachable from entry
204
- for (const BasicBlock *BB : depth_first (&F)) {
205
- computeEmbeddings (*BB);
206
- FuncVector += BBVecMap[BB];
228
+ // We consider only the non-debug and non-pseudo instructions
229
+ for (const auto &I : BB.instructionsWithoutDebug ()) {
230
+ // TODO: Handle call instructions differently.
231
+ // For now, we treat them like other instructions
232
+ Embedding ArgEmb (Dimension, 0 );
233
+ for (const auto &Op : I.operands ()) {
234
+ // If the operand is defined elsewhere, we use its embedding
235
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
236
+ auto DefIt = InstVecMap.find (DefInst);
237
+ assert (DefIt != InstVecMap.end () &&
238
+ " Instruction should have been processed before its operands" );
239
+ ArgEmb += DefIt->second ;
240
+ continue ;
241
+ }
242
+ // If the operand is not defined by an instruction, we use the vocabulary
243
+ else {
244
+ LLVM_DEBUG (errs () << " Using embedding from vocabulary for operand: "
245
+ << *Op << " =" << Vocab[Op][0 ] << " \n " );
246
+ ArgEmb += Vocab[Op];
247
+ }
248
+ }
249
+ // Create the instruction vector by combining opcode, type, and arguments
250
+ // embeddings
251
+ auto InstVector =
252
+ Vocab[I.getOpcode ()] + Vocab[I.getType ()->getTypeID ()] + ArgEmb;
253
+ InstVecMap[&I] = InstVector;
254
+ BBVector += InstVector;
207
255
}
256
+ BBVecMap[&BB] = BBVector;
208
257
}
209
258
210
259
// ==----------------------------------------------------------------------===//
@@ -552,8 +601,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
552
601
assert (Vocabulary.isValid () && " IR2Vec Vocabulary is invalid" );
553
602
554
603
for (Function &F : M) {
555
- std::unique_ptr<Embedder> Emb =
556
- Embedder::create (IR2VecKind::Symbolic, F, Vocabulary);
604
+ std::unique_ptr<Embedder> Emb;
605
+ switch (IR2VecEmbeddingKind) {
606
+ case IR2VecKind::Symbolic:
607
+ Emb = std::make_unique<SymbolicEmbedder>(F, Vocabulary);
608
+ break ;
609
+ case IR2VecKind::FlowAware:
610
+ Emb = std::make_unique<FlowAwareEmbedder>(F, Vocabulary);
611
+ break ;
612
+ default :
613
+ llvm_unreachable (" Unknown IR2Vec embedding kind" );
614
+ }
557
615
if (!Emb) {
558
616
OS << " Error creating IR2Vec embeddings \n " ;
559
617
continue ;
0 commit comments