Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions llvm/include/llvm/Analysis/IR2Vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#define LLVM_ANALYSIS_IR2VEC_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
Expand Down Expand Up @@ -162,16 +163,30 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;

// Vocabulary Slot Layout:
// +----------------+------------------------------------------------------+
// | Entity Type | Index Range |
// +----------------+------------------------------------------------------+
// | Opcodes | [0 .. (MaxOpcodes-1)] |
// | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] |
// | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. (NumCanonicalEntries-1)] |
// +----------------+------------------------------------------------------+
// Note: "Similar" LLVM Types are grouped/canonicalized together.
// Operands include Comparison predicates (ICmp/FCmp).
// This can be extended to include other specializations in future.
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;

bool Valid = false;
static constexpr unsigned NumICmpPredicates =
static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
static constexpr unsigned NumFCmpPredicates =
static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;

public:
// Slot layout:
// [0 .. MaxOpcodes-1] => Instruction opcodes
// [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
// [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds

/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
Expand Down Expand Up @@ -208,13 +223,18 @@ class Vocabulary {
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
// CmpInst::Predicate has gaps. We want the vocabulary to be dense without
// empty slots.
static constexpr unsigned MaxPredicateKinds =
NumICmpPredicates + NumFCmpPredicates;

Vocabulary() = default;
LLVM_ABI Vocabulary(VocabVector &&Vocab);

LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
/// Total number of entries (opcodes + canonicalized types + operand kinds)
/// Total number of entries (opcodes + canonicalized types + operand kinds +
/// predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }

/// Function to get vocabulary key for a given Opcode
Expand All @@ -229,16 +249,21 @@ class Vocabulary {
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);

/// Function to get vocabulary key for a given predicate
LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);

/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);

/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;

/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
Expand Down Expand Up @@ -275,7 +300,13 @@ class Vocabulary {

private:
constexpr static unsigned NumCanonicalEntries =
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;

// Base offsets for slot layout to simplify index computation
constexpr static unsigned OperandBaseOffset =
MaxOpcodes + MaxCanonicalTypeIDs;
constexpr static unsigned PredicateBaseOffset =
OperandBaseOffset + MaxOperandKinds;

/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
Expand Down Expand Up @@ -327,6 +358,9 @@ class Vocabulary {

/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);

/// Function to get the predicate enum value for a given index
LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
};

/// Embedder provides the interface to generate embeddings (vector
Expand Down
76 changes: 67 additions & 9 deletions llvm/lib/Analysis/IR2Vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
Expand Down Expand Up @@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
// Add compare predicate embedding as an additional operand if applicable
if (const auto *IC = dyn_cast<CmpInst>(&I))
InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
Expand Down Expand Up @@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
/// Returns the vocabulary slot for the operand \p Op.
///
/// The operand is first classified with getOperandKind(); the resulting kind
/// is then offset by OperandBaseOffset (MaxOpcodes + MaxCanonicalTypeIDs),
/// i.e. the start of the operand section of the vocabulary.
unsigned Vocabulary::getSlotIndex(const Value &Op) {
  unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
  assert(Index < MaxOperandKinds && "Invalid OperandKind");
  // Single return: a second, unreachable return computing the same value via
  // the expanded expression used to follow this one.
  return OperandBaseOffset + Index;
}

/// Returns the vocabulary slot for the compare predicate \p P.
///
/// CmpInst::Predicate has gaps between the FCmp and ICmp ranges, so the
/// predicate is remapped to a dense index: FCmp predicates occupy
/// [0, NumFCmpPredicates) and ICmp predicates follow immediately after. The
/// dense index is then offset by PredicateBaseOffset.
///
/// \p P must be a valid FCmp or ICmp predicate. Values outside both ranges
/// would otherwise alias another predicate's slot or index past the end of the
/// vocabulary, so they are rejected by assertion, matching the other
/// getSlotIndex overloads.
unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
  assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
         "Invalid predicate");

  unsigned PU = static_cast<unsigned>(P);
  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);

  // ICmp predicates are placed after all FCmp predicates in the dense layout.
  unsigned PredIdx =
      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
  assert(PredIdx < MaxPredicateKinds && "Invalid predicate index");
  return PredicateBaseOffset + PredIdx;
}

const Embedding &Vocabulary::operator[](unsigned Opcode) const {
Expand All @@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(Arg)];
}

/// Looks up the embedding stored for the compare predicate \p P, using the
/// slot computed by getSlotIndex(P).
const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
  const unsigned Slot = getSlotIndex(P);
  return Vocab[Slot];
}

StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
#define HANDLE_INST(NUM, OPCODE, CLASS) \
Expand Down Expand Up @@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}

/// Maps a dense predicate index back to its CmpInst::Predicate value.
///
/// \p Index is relative to the predicate sub-range of the vocabulary (index 0
/// is the first predicate); it is not an absolute vocabulary slot. Indices in
/// [0, NumFCmpPredicates) map to FCmp predicates starting at
/// FIRST_FCMP_PREDICATE; the remaining indices map to ICmp predicates starting
/// at FIRST_ICMP_PREDICATE.
CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
  assert(Index < MaxPredicateKinds && "Invalid predicate index");
  unsigned PredEnumVal =
      (Index < NumFCmpPredicates)
          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
             (Index - NumFCmpPredicates));
  return static_cast<CmpInst::Predicate>(PredEnumVal);
}

/// Returns the vocabulary key for the compare predicate \p Pred: the name
/// reported by CmpInst::getPredicateName(), prefixed with "FCMP_" for
/// predicates below CmpInst::FIRST_ICMP_PREDICATE and with "ICMP_" otherwise.
///
/// The keys are built once and stored in a function-local table that is never
/// modified afterwards, so a returned StringRef stays valid across later
/// calls. A single shared static buffer would instead be overwritten by the
/// next call, silently changing the text behind every StringRef handed out
/// earlier.
///
/// \p Pred must be a valid FCmp or ICmp predicate.
StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
  // One entry per dense predicate index, in the same order as getPredicate().
  static const std::vector<std::string> PredKeys = [] {
    std::vector<std::string> Keys;
    Keys.reserve(MaxPredicateKinds);
    for (unsigned Idx = 0; Idx < MaxPredicateKinds; ++Idx) {
      CmpInst::Predicate P = getPredicate(Idx);
      std::string Key = (P < CmpInst::FIRST_ICMP_PREDICATE) ? "FCMP_" : "ICMP_";
      Key += CmpInst::getPredicateName(P).str();
      Keys.push_back(std::move(Key));
    }
    return Keys;
  }();

  assert((CmpInst::isFPPredicate(Pred) || CmpInst::isIntPredicate(Pred)) &&
         "Invalid predicate");
  // getSlotIndex() yields an absolute slot; strip the base to get the dense
  // predicate index used by the table.
  return PredKeys[getSlotIndex(Pred) - PredicateBaseOffset];
}

/// Returns the vocabulary key for the entry at absolute position \p Pos.
///
/// The position is resolved against the slot layout in order: opcodes, then
/// canonical types (up to OperandBaseOffset), then operand kinds (up to
/// PredicateBaseOffset), and finally compare predicates. Each section is
/// guarded by its own upper bound so that positions in the predicate section
/// actually reach the predicate lookup rather than being cast to an
/// out-of-range OperandKind.
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
  // Opcode (opcodes are 1-based, slots are 0-based)
  if (Pos < MaxOpcodes)
    return getVocabKeyForOpcode(Pos + 1);
  // Type
  if (Pos < OperandBaseOffset)
    return getVocabKeyForCanonicalTypeID(
        static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
  // Operand
  if (Pos < PredicateBaseOffset)
    return getVocabKeyForOperandKind(
        static_cast<OperandKind>(Pos - OperandBaseOffset));
  // Predicates
  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}

// For now, assume vocabulary is stable unless explicitly invalidated.
Expand All @@ -370,11 +412,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
VocabVector DummyVocab;
DummyVocab.reserve(NumCanonicalEntries);
float DummyVal = 0.1f;
// Create a dummy vocabulary with entries for all opcodes, types, and
// operands
for ([[maybe_unused]] unsigned _ :
seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
Vocabulary::MaxOperandKinds)) {
// Create a dummy vocabulary with entries for all opcodes, types, operands
// and predicates
for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
DummyVocab.push_back(Embedding(Dim, DummyVal));
DummyVal += 0.1f;
}
Expand Down Expand Up @@ -517,6 +557,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
NumericArgEmbeddings.end());

// Handle Predicates: part of Operands section. We look up predicate keys
// in ArgVocab.
std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
Embedding(Dim, 0));
NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
StringRef VocabKey =
Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
auto It = ArgVocab.find(VocabKey.str());
if (It != ArgVocab.end()) {
NumericPredEmbeddings[PK] = It->second;
continue;
}
handleMissingEntity(VocabKey.str());
}
Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
NumericPredEmbeddings.end());
}

IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
Expand Down
28 changes: 27 additions & 1 deletion llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@
"Function": [1, 2],
"Pointer": [3, 4],
"Constant": [5, 6],
"Variable": [7, 8]
"Variable": [7, 8],
"FCMP_false": [9, 10],
"FCMP_oeq": [11, 12],
"FCMP_ogt": [13, 14],
"FCMP_oge": [15, 16],
"FCMP_olt": [17, 18],
"FCMP_ole": [19, 20],
"FCMP_one": [21, 22],
"FCMP_ord": [23, 24],
"FCMP_uno": [25, 26],
"FCMP_ueq": [27, 28],
"FCMP_ugt": [29, 30],
"FCMP_uge": [31, 32],
"FCMP_ult": [33, 34],
"FCMP_ule": [35, 36],
"FCMP_une": [37, 38],
"FCMP_true": [39, 40],
"ICMP_eq": [41, 42],
"ICMP_ne": [43, 44],
"ICMP_ugt": [45, 46],
"ICMP_uge": [47, 48],
"ICMP_ult": [49, 50],
"ICMP_ule": [51, 52],
"ICMP_sgt": [53, 54],
"ICMP_sge": [55, 56],
"ICMP_slt": [57, 58],
"ICMP_sle": [59, 60]
}
}
28 changes: 27 additions & 1 deletion llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,32 @@
"Function": [1, 2, 3],
"Pointer": [4, 5, 6],
"Constant": [7, 8, 9],
"Variable": [10, 11, 12]
"Variable": [10, 11, 12],
"FCMP_false": [13, 14, 15],
"FCMP_oeq": [16, 17, 18],
"FCMP_ogt": [19, 20, 21],
"FCMP_oge": [22, 23, 24],
"FCMP_olt": [25, 26, 27],
"FCMP_ole": [28, 29, 30],
"FCMP_one": [31, 32, 33],
"FCMP_ord": [34, 35, 36],
"FCMP_uno": [37, 38, 39],
"FCMP_ueq": [40, 41, 42],
"FCMP_ugt": [43, 44, 45],
"FCMP_uge": [46, 47, 48],
"FCMP_ult": [49, 50, 51],
"FCMP_ule": [52, 53, 54],
"FCMP_une": [55, 56, 57],
"FCMP_true": [58, 59, 60],
"ICMP_eq": [61, 62, 63],
"ICMP_ne": [64, 65, 66],
"ICMP_ugt": [67, 68, 69],
"ICMP_uge": [70, 71, 72],
"ICMP_ult": [73, 74, 75],
"ICMP_ule": [76, 77, 78],
"ICMP_sgt": [79, 80, 81],
"ICMP_sge": [82, 83, 84],
"ICMP_slt": [85, 86, 87],
"ICMP_sle": [88, 89, 90]
}
}
29 changes: 28 additions & 1 deletion llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"FPTrunc": [133, 134, 135],
"FPExt": [136, 137, 138],
"PtrToInt": [139, 140, 141],
"PtrToAddr": [202, 203, 204],
"IntToPtr": [142, 143, 144],
"BitCast": [145, 146, 147],
"AddrSpaceCast": [148, 149, 150],
Expand Down Expand Up @@ -86,6 +87,32 @@
"Function": [0, 0, 0],
"Pointer": [0, 0, 0],
"Constant": [0, 0, 0],
"Variable": [0, 0, 0]
"Variable": [0, 0, 0],
"FCMP_false": [0, 0, 0],
"FCMP_oeq": [0, 0, 0],
"FCMP_ogt": [0, 0, 0],
"FCMP_oge": [0, 0, 0],
"FCMP_olt": [0, 0, 0],
"FCMP_ole": [0, 0, 0],
"FCMP_one": [0, 0, 0],
"FCMP_ord": [0, 0, 0],
"FCMP_uno": [0, 0, 0],
"FCMP_ueq": [0, 0, 0],
"FCMP_ugt": [0, 0, 0],
"FCMP_uge": [0, 0, 0],
"FCMP_ult": [0, 0, 0],
"FCMP_ule": [0, 0, 0],
"FCMP_une": [0, 0, 0],
"FCMP_true": [0, 0, 0],
"ICMP_eq": [0, 0, 0],
"ICMP_ne": [0, 0, 0],
"ICMP_ugt": [0, 0, 0],
"ICMP_uge": [0, 0, 0],
"ICMP_ult": [0, 0, 0],
"ICMP_ule": [0, 0, 0],
"ICMP_sgt": [1, 1, 1],
"ICMP_sge": [0, 0, 0],
"ICMP_slt": [0, 0, 0],
"ICMP_sle": [0, 0, 0]
}
}
26 changes: 26 additions & 0 deletions llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
Key: Constant: [ 1.00 1.20 ]
Key: Variable: [ 1.40 1.60 ]
Key: FCMP_false: [ 1.80 2.00 ]
Key: FCMP_oeq: [ 2.20 2.40 ]
Key: FCMP_ogt: [ 2.60 2.80 ]
Key: FCMP_oge: [ 3.00 3.20 ]
Key: FCMP_olt: [ 3.40 3.60 ]
Key: FCMP_ole: [ 3.80 4.00 ]
Key: FCMP_one: [ 4.20 4.40 ]
Key: FCMP_ord: [ 4.60 4.80 ]
Key: FCMP_uno: [ 5.00 5.20 ]
Key: FCMP_ueq: [ 5.40 5.60 ]
Key: FCMP_ugt: [ 5.80 6.00 ]
Key: FCMP_uge: [ 6.20 6.40 ]
Key: FCMP_ult: [ 6.60 6.80 ]
Key: FCMP_ule: [ 7.00 7.20 ]
Key: FCMP_une: [ 7.40 7.60 ]
Key: FCMP_true: [ 7.80 8.00 ]
Key: ICMP_eq: [ 8.20 8.40 ]
Key: ICMP_ne: [ 8.60 8.80 ]
Key: ICMP_ugt: [ 9.00 9.20 ]
Key: ICMP_uge: [ 9.40 9.60 ]
Key: ICMP_ult: [ 9.80 10.00 ]
Key: ICMP_ule: [ 10.20 10.40 ]
Key: ICMP_sgt: [ 10.60 10.80 ]
Key: ICMP_sge: [ 11.00 11.20 ]
Key: ICMP_slt: [ 11.40 11.60 ]
Key: ICMP_sle: [ 11.80 12.00 ]
Loading