Skip to content

Commit cb881b6

Browse files
committed
Canonicalized type
1 parent e317c7e commit cb881b6

File tree

9 files changed

+350
-308
lines changed

9 files changed

+350
-308
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 111 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/Support/Compiler.h"
3737
#include "llvm/Support/ErrorOr.h"
3838
#include "llvm/Support/JSON.h"
39+
#include <array>
3940
#include <map>
4041

4142
namespace llvm {
@@ -137,13 +138,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
137138
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
138139

139140
/// Class for storing and accessing the IR2Vec vocabulary.
140-
/// Encapsulates all vocabulary-related constants, logic, and access methods.
141+
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
142+
/// seed embeddings are the initial learned representations of the entities
143+
/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
144+
/// seed embeddings.
145+
///
146+
/// The vocabulary contains the seed embeddings for three types of entities:
147+
/// instruction opcodes, types, and operands. Types are grouped/canonicalized
148+
/// for better learning (e.g., all float variants map to FloatTy). The
149+
/// vocabulary abstracts away the canonicalization effectively, the exposed APIs
150+
/// handle all the known LLVM IR opcodes, types and operands.
151+
///
152+
/// This class helps populate the seed embeddings in an internal vector-based
153+
/// ADT. It provides logic to map every IR entity to a specific slot index or
154+
/// position in this vector, enabling O(1) embedding lookup while avoiding
155+
/// unnecessary computations involving string based lookups while generating the
156+
/// embeddings.
141157
class Vocabulary {
142158
friend class llvm::IR2VecVocabAnalysis;
143159
using VocabVector = std::vector<ir2vec::Embedding>;
144160
VocabVector Vocab;
145161
bool Valid = false;
146162

163+
public:
164+
// Slot layout:
165+
// [0 .. MaxOpcodes-1] => Instruction opcodes
166+
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
167+
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
168+
169+
/// Canonical type IDs supported by IR2Vec Vocabulary
170+
enum class CanonicalTypeID : unsigned {
171+
FloatTy,
172+
VoidTy,
173+
LabelTy,
174+
MetadataTy,
175+
VectorTy,
176+
TokenTy,
177+
IntegerTy,
178+
FunctionTy,
179+
PointerTy,
180+
StructTy,
181+
ArrayTy,
182+
UnknownTy,
183+
MaxCanonicalType
184+
};
185+
147186
/// Operand kinds supported by IR2Vec Vocabulary
148187
enum class OperandKind : unsigned {
149188
FunctionID,
@@ -152,20 +191,15 @@ class Vocabulary {
152191
VariableID,
153192
MaxOperandKind
154193
};
155-
/// String mappings for OperandKind values
156-
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
157-
"Constant", "Variable"};
158-
static_assert(std::size(OperandKindNames) ==
159-
static_cast<unsigned>(OperandKind::MaxOperandKind),
160-
"OperandKindNames array size must match MaxOperandKind");
161194

162-
public:
163195
/// Vocabulary layout constants
164196
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
165197
#include "llvm/IR/Instruction.def"
166198
#undef LAST_OTHER_INST
167199

168200
static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
201+
static constexpr unsigned MaxCanonicalTypeIDs =
202+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
169203
static constexpr unsigned MaxOperandKinds =
170204
static_cast<unsigned>(OperandKind::MaxOperandKind);
171205

@@ -174,33 +208,31 @@ class Vocabulary {
174208

175209
LLVM_ABI bool isValid() const;
176210
LLVM_ABI unsigned getDimension() const;
177-
LLVM_ABI size_t size() const;
211+
/// Total number of entries (opcodes + canonicalized types + operand kinds)
212+
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
178213

179-
static size_t expectedSize() {
180-
return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
181-
}
182-
183-
/// Helper function to get vocabulary key for a given Opcode
214+
/// Function to get vocabulary key for a given Opcode
184215
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);
185216

186-
/// Helper function to get vocabulary key for a given TypeID
217+
/// Function to get vocabulary key for a given TypeID
187218
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
188219

189-
/// Helper function to get vocabulary key for a given OperandKind
220+
/// Function to get vocabulary key for a given OperandKind
190221
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);
191222

192-
/// Helper function to classify an operand into OperandKind
223+
/// Function to classify an operand into OperandKind
193224
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
194225

195-
/// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
196-
LLVM_ABI static unsigned getNumericID(unsigned Opcode);
197-
LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
198-
LLVM_ABI static unsigned getNumericID(const Value *Op);
226+
/// Functions to return the slot index or position of a given Opcode, TypeID,
227+
/// or OperandKind in the vocabulary.
228+
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
229+
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
230+
LLVM_ABI static unsigned getSlotIndex(const Value *Op);
199231

200232
/// Accessors to get the embedding for a given entity.
201233
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
202234
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
203-
LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
235+
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
204236

205237
/// Const Iterator type aliases
206238
using const_iterator = VocabVector::const_iterator;
@@ -234,6 +266,61 @@ class Vocabulary {
234266

235267
LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
236268
ModuleAnalysisManager::Invalidator &Inv) const;
269+
270+
private:
271+
constexpr static unsigned NumCanonicalEntries =
272+
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
273+
274+
/// String mappings for CanonicalTypeID values
275+
static constexpr StringLiteral CanonicalTypeNames[] = {
276+
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
277+
"VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
278+
"PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
279+
static_assert(std::size(CanonicalTypeNames) ==
280+
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
281+
"CanonicalTypeNames array size must match MaxCanonicalType");
282+
283+
/// String mappings for OperandKind values
284+
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
285+
"Constant", "Variable"};
286+
static_assert(std::size(OperandKindNames) ==
287+
static_cast<unsigned>(OperandKind::MaxOperandKind),
288+
"OperandKindNames array size must match MaxOperandKind");
289+
290+
/// Every known TypeID defined in llvm/IR/Type.h is expected to have a
291+
/// corresponding mapping here in the same order as enum Type::TypeID.
292+
static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
293+
CanonicalTypeID::FloatTy, // HalfTyID = 0
294+
CanonicalTypeID::FloatTy, // BFloatTyID
295+
CanonicalTypeID::FloatTy, // FloatTyID
296+
CanonicalTypeID::FloatTy, // DoubleTyID
297+
CanonicalTypeID::FloatTy, // X86_FP80TyID
298+
CanonicalTypeID::FloatTy, // FP128TyID
299+
CanonicalTypeID::FloatTy, // PPC_FP128TyID
300+
CanonicalTypeID::VoidTy, // VoidTyID
301+
CanonicalTypeID::LabelTy, // LabelTyID
302+
CanonicalTypeID::MetadataTy, // MetadataTyID
303+
CanonicalTypeID::VectorTy, // X86_AMXTyID
304+
CanonicalTypeID::TokenTy, // TokenTyID
305+
CanonicalTypeID::IntegerTy, // IntegerTyID
306+
CanonicalTypeID::FunctionTy, // FunctionTyID
307+
CanonicalTypeID::PointerTy, // PointerTyID
308+
CanonicalTypeID::StructTy, // StructTyID
309+
CanonicalTypeID::ArrayTy, // ArrayTyID
310+
CanonicalTypeID::VectorTy, // FixedVectorTyID
311+
CanonicalTypeID::VectorTy, // ScalableVectorTyID
312+
CanonicalTypeID::PointerTy, // TypedPointerTyID
313+
CanonicalTypeID::UnknownTy // TargetExtTyID
314+
}};
315+
static_assert(TypeIDMapping.size() == MaxTypeIDs,
316+
"TypeIDMapping must cover all Type::TypeID values");
317+
318+
/// Function to get vocabulary key for canonical type by enum
319+
LLVM_ABI static StringRef
320+
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);
321+
322+
/// Function to convert TypeID to CanonicalTypeID
323+
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
237324
};
238325

239326
/// Embedder provides the interface to generate embeddings (vector
@@ -262,11 +349,11 @@ class Embedder {
262349

263350
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
264351

265-
/// Helper function to compute embeddings. It generates embeddings for all
352+
/// Function to compute embeddings. It generates embeddings for all
266353
/// the instructions and basic blocks in the function F.
267354
void computeEmbeddings() const;
268355

269-
/// Helper function to compute the embedding for a given basic block.
356+
/// Function to compute the embedding for a given basic block.
270357
/// Specific to the kind of embeddings being computed.
271358
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
272359

0 commit comments

Comments
 (0)