36
36
#include " llvm/Support/Compiler.h"
37
37
#include " llvm/Support/ErrorOr.h"
38
38
#include " llvm/Support/JSON.h"
39
+ #include < array>
39
40
#include < map>
40
41
41
42
namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137
138
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138
139
139
140
// / Class for storing and accessing the IR2Vec vocabulary.
140
- // / Encapsulates all vocabulary-related constants, logic, and access methods.
141
+ // / The Vocabulary class manages seed embeddings for LLVM IR entities. The
142
+ // / seed embeddings are the initial learned representations of the entities
143
+ // / of LLVM IR. The IR2Vec representation for a given IR is derived from these
144
+ // / seed embeddings.
145
+ // /
146
+ // / The vocabulary contains the seed embeddings for three types of entities:
147
+ // / instruction opcodes, types, and operands. Types are grouped/canonicalized
148
+ // / for better learning (e.g., all float variants map to FloatTy). The
149
+ // / vocabulary abstracts away the canonicalization effectively; the exposed APIs
150
+ // / handle all the known LLVM IR opcodes, types, and operands.
151
+ // /
152
+ // / This class helps populate the seed embeddings in an internal vector-based
153
+ // / ADT. It provides logic to map every IR entity to a specific slot index or
154
+ // / position in this vector, enabling O(1) embedding lookup while avoiding
155
+ // / unnecessary computations involving string-based lookups while generating the
156
+ // / embeddings.
141
157
class Vocabulary {
142
158
friend class llvm ::IR2VecVocabAnalysis;
143
159
using VocabVector = std::vector<ir2vec::Embedding>;
144
160
VocabVector Vocab;
145
161
bool Valid = false ;
146
162
163
+ public:
164
+ // Slot layout:
165
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
166
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
167
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
168
+
169
+ // / Canonical type IDs supported by IR2Vec Vocabulary.
+ // / Several Type::TypeID values share a single canonical ID; the trailing
+ // / comments below list the TypeIDs that the TypeIDMapping table assigns to
+ // / each enumerator.
170
+ enum class CanonicalTypeID : unsigned {
171
+ FloatTy, // HalfTyID, BFloatTyID, FloatTyID, DoubleTyID, X86_FP80TyID, FP128TyID, PPC_FP128TyID
172
+ VoidTy, // VoidTyID
173
+ LabelTy, // LabelTyID
174
+ MetadataTy, // MetadataTyID
175
+ VectorTy, // X86_AMXTyID, FixedVectorTyID, ScalableVectorTyID
176
+ TokenTy, // TokenTyID
177
+ IntegerTy, // IntegerTyID
178
+ FunctionTy, // FunctionTyID
179
+ PointerTy, // PointerTyID, TypedPointerTyID
180
+ StructTy, // StructTyID
181
+ ArrayTy, // ArrayTyID
182
+ UnknownTy, // TargetExtTyID
183
+ MaxCanonicalType // Sentinel, not a real type: its value is the enumerator count (used for MaxCanonicalTypeIDs)
184
+ };
185
+
147
186
// / Operand kinds supported by IR2Vec Vocabulary
148
187
enum class OperandKind : unsigned {
149
188
FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
152
191
VariableID,
153
192
MaxOperandKind
154
193
};
155
- // / String mappings for OperandKind values
156
- static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
157
- " Constant" , " Variable" };
158
- static_assert (std::size(OperandKindNames) ==
159
- static_cast <unsigned >(OperandKind::MaxOperandKind),
160
- " OperandKindNames array size must match MaxOperandKind" );
161
194
162
- public:
163
195
// / Vocabulary layout constants
164
196
#define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
165
197
#include " llvm/IR/Instruction.def"
166
198
#undef LAST_OTHER_INST
167
199
168
200
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
201
+ static constexpr unsigned MaxCanonicalTypeIDs =
202
+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType);
169
203
static constexpr unsigned MaxOperandKinds =
170
204
static_cast <unsigned >(OperandKind::MaxOperandKind);
171
205
@@ -174,33 +208,31 @@ class Vocabulary {
174
208
175
209
LLVM_ABI bool isValid () const ;
176
210
LLVM_ABI unsigned getDimension () const ;
177
- LLVM_ABI size_t size () const ;
211
+ // / Total number of entries (opcodes + canonicalized types + operand kinds)
212
+ static constexpr size_t getCanonicalSize () { return NumCanonicalEntries; }
178
213
179
- static size_t expectedSize () {
180
- return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181
- }
182
-
183
- // / Helper function to get vocabulary key for a given Opcode
214
+ // / Function to get vocabulary key for a given Opcode
184
215
LLVM_ABI static StringRef getVocabKeyForOpcode (unsigned Opcode);
185
216
186
- // / Helper function to get vocabulary key for a given TypeID
217
+ // / Function to get vocabulary key for a given TypeID
187
218
LLVM_ABI static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
188
219
189
- // / Helper function to get vocabulary key for a given OperandKind
220
+ // / Function to get vocabulary key for a given OperandKind
190
221
LLVM_ABI static StringRef getVocabKeyForOperandKind (OperandKind Kind);
191
222
192
- // / Helper function to classify an operand into OperandKind
223
+ // / Function to classify an operand into OperandKind
193
224
LLVM_ABI static OperandKind getOperandKind (const Value *Op);
194
225
195
- // / Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196
- LLVM_ABI static unsigned getNumericID (unsigned Opcode);
197
- LLVM_ABI static unsigned getNumericID (Type::TypeID TypeID);
198
- LLVM_ABI static unsigned getNumericID (const Value *Op);
226
+ // / Functions to return the slot index or position of a given Opcode, TypeID,
227
+ // / or OperandKind in the vocabulary.
228
+ LLVM_ABI static unsigned getSlotIndex (unsigned Opcode);
229
+ LLVM_ABI static unsigned getSlotIndex (Type::TypeID TypeID);
230
+ LLVM_ABI static unsigned getSlotIndex (const Value *Op);
199
231
200
232
// / Accessors to get the embedding for a given entity.
201
233
LLVM_ABI const ir2vec::Embedding &operator [](unsigned Opcode) const ;
202
234
LLVM_ABI const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
203
- LLVM_ABI const ir2vec::Embedding &operator [](const Value * Arg) const ;
235
+ LLVM_ABI const ir2vec::Embedding &operator [](const Value & Arg) const ;
204
236
205
237
// / Const Iterator type aliases
206
238
using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
234
266
235
267
LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236
268
ModuleAnalysisManager::Invalidator &Inv) const ;
269
+
270
+ private:
271
+ constexpr static unsigned NumCanonicalEntries =
272
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
273
+
274
+ // / String mappings for CanonicalTypeID values
275
+ static constexpr StringLiteral CanonicalTypeNames[] = {
276
+ " FloatTy" , " VoidTy" , " LabelTy" , " MetadataTy" ,
277
+ " VectorTy" , " TokenTy" , " IntegerTy" , " FunctionTy" ,
278
+ " PointerTy" , " StructTy" , " ArrayTy" , " UnknownTy" };
279
+ static_assert (std::size(CanonicalTypeNames) ==
280
+ static_cast <unsigned >(CanonicalTypeID::MaxCanonicalType),
281
+ " CanonicalTypeNames array size must match MaxCanonicalType" );
282
+
283
+ // / String mappings for OperandKind values
284
+ static constexpr StringLiteral OperandKindNames[] = {" Function" , " Pointer" ,
285
+ " Constant" , " Variable" };
286
+ static_assert (std::size(OperandKindNames) ==
287
+ static_cast <unsigned >(OperandKind::MaxOperandKind),
288
+ " OperandKindNames array size must match MaxOperandKind" );
289
+
290
+ // / Every known TypeID defined in llvm/IR/Type.h is expected to have a
291
+ // / corresponding mapping here in the same order as enum Type::TypeID.
292
+ static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
293
+ CanonicalTypeID::FloatTy, // HalfTyID = 0
294
+ CanonicalTypeID::FloatTy, // BFloatTyID
295
+ CanonicalTypeID::FloatTy, // FloatTyID
296
+ CanonicalTypeID::FloatTy, // DoubleTyID
297
+ CanonicalTypeID::FloatTy, // X86_FP80TyID
298
+ CanonicalTypeID::FloatTy, // FP128TyID
299
+ CanonicalTypeID::FloatTy, // PPC_FP128TyID
300
+ CanonicalTypeID::VoidTy, // VoidTyID
301
+ CanonicalTypeID::LabelTy, // LabelTyID
302
+ CanonicalTypeID::MetadataTy, // MetadataTyID
303
+ CanonicalTypeID::VectorTy, // X86_AMXTyID
304
+ CanonicalTypeID::TokenTy, // TokenTyID
305
+ CanonicalTypeID::IntegerTy, // IntegerTyID
306
+ CanonicalTypeID::FunctionTy, // FunctionTyID
307
+ CanonicalTypeID::PointerTy, // PointerTyID
308
+ CanonicalTypeID::StructTy, // StructTyID
309
+ CanonicalTypeID::ArrayTy, // ArrayTyID
310
+ CanonicalTypeID::VectorTy, // FixedVectorTyID
311
+ CanonicalTypeID::VectorTy, // ScalableVectorTyID
312
+ CanonicalTypeID::PointerTy, // TypedPointerTyID
313
+ CanonicalTypeID::UnknownTy // TargetExtTyID
314
+ }};
315
+ static_assert (TypeIDMapping.size() == MaxTypeIDs,
316
+ " TypeIDMapping must cover all Type::TypeID values" );
317
+
318
+ // / Function to get vocabulary key for canonical type by enum
319
+ LLVM_ABI static StringRef
320
+ getVocabKeyForCanonicalTypeID (CanonicalTypeID CType);
321
+
322
+ // / Function to convert TypeID to CanonicalTypeID
323
+ LLVM_ABI static CanonicalTypeID getCanonicalTypeID (Type::TypeID TypeID);
237
324
};
238
325
239
326
// / Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
262
349
263
350
LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab);
264
351
265
- // / Helper function to compute embeddings. It generates embeddings for all
352
+ // / Function to compute embeddings. It generates embeddings for all
266
353
// / the instructions and basic blocks in the function F.
267
354
void computeEmbeddings () const ;
268
355
269
- // / Helper function to compute the embedding for a given basic block.
356
+ // / Function to compute the embedding for a given basic block.
270
357
// / Specific to the kind of embeddings being computed.
271
358
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
272
359
0 commit comments