Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 117 additions & 24 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
/// https://arxiv.org/abs/1909.06228
///
/// To obtain embeddings:
/// First run IR2VecVocabAnalysis to populate the vocabulary.
/// Then, use the Embedder interface to generate embeddings for the desired IR
/// entities. See the documentation for more details -
/// https://llvm.org/docs/MLGO.html#ir2vec-embeddings
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_IR2VEC_H
Expand All @@ -36,6 +42,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/JSON.h"
#include <array>
#include <map>

namespace llvm {
Expand Down Expand Up @@ -137,13 +144,51 @@ using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;

/// Class for storing and accessing the IR2Vec vocabulary.
/// Encapsulates all vocabulary-related constants, logic, and access methods.
/// The Vocabulary class manages seed embeddings for LLVM IR entities. The
/// seed embeddings are the initial learned representations of the entities
/// of LLVM IR. The IR2Vec representation for a given IR is derived from these
/// seed embeddings.
///
/// The vocabulary contains the seed embeddings for three types of entities:
/// instruction opcodes, types, and operands. Types are grouped/canonicalized
/// for better learning (e.g., all float variants map to FloatTy). The
/// vocabulary abstracts away the canonicalization effectively, the exposed APIs
/// handle all the known LLVM IR opcodes, types and operands.
///
/// This class helps populate the seed embeddings in an internal vector-based
/// ADT. It provides logic to map every IR entity to a specific slot index or
/// position in this vector, enabling O(1) embedding lookup while avoiding
/// unnecessary computations involving string based lookups while generating the
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
bool Valid = false;

public:
// Slot layout:
// [0 .. MaxOpcodes-1] => Instruction opcodes
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds

/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
VoidTy,
LabelTy,
MetadataTy,
VectorTy,
TokenTy,
IntegerTy,
FunctionTy,
PointerTy,
StructTy,
ArrayTy,
UnknownTy,
MaxCanonicalType
};

/// Operand kinds supported by IR2Vec Vocabulary
enum class OperandKind : unsigned {
FunctionID,
Expand All @@ -152,20 +197,15 @@ class Vocabulary {
VariableID,
MaxOperandKind
};
/// String mappings for OperandKind values
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
"Constant", "Variable"};
static_assert(std::size(OperandKindNames) ==
static_cast<unsigned>(OperandKind::MaxOperandKind),
"OperandKindNames array size must match MaxOperandKind");

public:
/// Vocabulary layout constants
#define LAST_OTHER_INST(NUM) static constexpr unsigned MaxOpcodes = NUM;
#include "llvm/IR/Instruction.def"
#undef LAST_OTHER_INST

static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1;
static constexpr unsigned MaxCanonicalTypeIDs =
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);

Expand All @@ -174,33 +214,31 @@ class Vocabulary {

LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
LLVM_ABI size_t size() const;
/// Total number of entries (opcodes + canonicalized types + operand kinds)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }

static size_t expectedSize() {
return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
}

/// Helper function to get vocabulary key for a given Opcode
/// Function to get vocabulary key for a given Opcode
LLVM_ABI static StringRef getVocabKeyForOpcode(unsigned Opcode);

/// Helper function to get vocabulary key for a given TypeID
/// Function to get vocabulary key for a given TypeID
LLVM_ABI static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);

/// Helper function to get vocabulary key for a given OperandKind
/// Function to get vocabulary key for a given OperandKind
LLVM_ABI static StringRef getVocabKeyForOperandKind(OperandKind Kind);

/// Helper function to classify an operand into OperandKind
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);

/// Helpers to return the IDs of a given Opcode, TypeID, or OperandKind
LLVM_ABI static unsigned getNumericID(unsigned Opcode);
LLVM_ABI static unsigned getNumericID(Type::TypeID TypeID);
LLVM_ABI static unsigned getNumericID(const Value *Op);
/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value *Op);

/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value *Arg) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;

/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
Expand Down Expand Up @@ -234,6 +272,61 @@ class Vocabulary {

LLVM_ABI bool invalidate(Module &M, const PreservedAnalyses &PA,
ModuleAnalysisManager::Invalidator &Inv) const;

private:
constexpr static unsigned NumCanonicalEntries =
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;

/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
"FloatTy", "VoidTy", "LabelTy", "MetadataTy",
"VectorTy", "TokenTy", "IntegerTy", "FunctionTy",
"PointerTy", "StructTy", "ArrayTy", "UnknownTy"};
static_assert(std::size(CanonicalTypeNames) ==
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType),
"CanonicalTypeNames array size must match MaxCanonicalType");

/// String mappings for OperandKind values
static constexpr StringLiteral OperandKindNames[] = {"Function", "Pointer",
"Constant", "Variable"};
static_assert(std::size(OperandKindNames) ==
static_cast<unsigned>(OperandKind::MaxOperandKind),
"OperandKindNames array size must match MaxOperandKind");

/// Every known TypeID defined in llvm/IR/Type.h is expected to have a
/// corresponding mapping here in the same order as enum Type::TypeID.
static constexpr std::array<CanonicalTypeID, MaxTypeIDs> TypeIDMapping = {{
CanonicalTypeID::FloatTy, // HalfTyID = 0
CanonicalTypeID::FloatTy, // BFloatTyID
CanonicalTypeID::FloatTy, // FloatTyID
CanonicalTypeID::FloatTy, // DoubleTyID
CanonicalTypeID::FloatTy, // X86_FP80TyID
CanonicalTypeID::FloatTy, // FP128TyID
CanonicalTypeID::FloatTy, // PPC_FP128TyID
CanonicalTypeID::VoidTy, // VoidTyID
CanonicalTypeID::LabelTy, // LabelTyID
CanonicalTypeID::MetadataTy, // MetadataTyID
CanonicalTypeID::VectorTy, // X86_AMXTyID
CanonicalTypeID::TokenTy, // TokenTyID
CanonicalTypeID::IntegerTy, // IntegerTyID
CanonicalTypeID::FunctionTy, // FunctionTyID
CanonicalTypeID::PointerTy, // PointerTyID
CanonicalTypeID::StructTy, // StructTyID
CanonicalTypeID::ArrayTy, // ArrayTyID
CanonicalTypeID::VectorTy, // FixedVectorTyID
CanonicalTypeID::VectorTy, // ScalableVectorTyID
CanonicalTypeID::PointerTy, // TypedPointerTyID
CanonicalTypeID::UnknownTy // TargetExtTyID
}};
static_assert(TypeIDMapping.size() == MaxTypeIDs,
"TypeIDMapping must cover all Type::TypeID values");

/// Function to get vocabulary key for canonical type by enum
LLVM_ABI static StringRef
getVocabKeyForCanonicalTypeID(CanonicalTypeID CType);

/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
};

/// Embedder provides the interface to generate embeddings (vector
Expand Down Expand Up @@ -262,11 +355,11 @@ class Embedder {

LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);

/// Helper function to compute embeddings. It generates embeddings for all
/// Function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F.
void computeEmbeddings() const;

/// Helper function to compute the embedding for a given basic block.
/// Function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;

Expand Down
Loading