[IR2Vec] Add support for Cmp predicates in vocabulary and embeddings #156952
@llvm/pr-subscribers-mlgo
Author: S. VenkataKeerthy (svkeerthy)
Changes
Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space. Following are the changes:
(Tracking issues: #141817, #141833)
Patch is 25.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156952.diff
13 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index b7b881999241e..d49854e2d06a8 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#define LLVM_ANALYSIS_IR2VEC_H
#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
+ // Slot layout:
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) => Operands
+ // Within Operands: first OperandKind entries, followed by compare
+ // predicates
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
+
bool Valid = false;
+ static constexpr unsigned NumICmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+ static constexpr unsigned NumFCmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
public:
- // Slot layout:
- // [0 .. MaxOpcodes-1] => Instruction opcodes
- // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
- // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
-
/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
+ // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+ // empty slots.
+ static constexpr unsigned MaxPredicateKinds =
+ NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
LLVM_ABI Vocabulary(VocabVector &&Vocab);
LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
- /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ /// Total number of entries (opcodes + canonicalized types + operand kinds +
+ /// predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
/// Function to get vocabulary key for a given Opcode
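The predicate counts in this hunk can be checked with a small standalone sketch. It does not include any LLVM header; the four constants copy the numeric values of `CmpInst::FCMP_FALSE`, `FCMP_TRUE`, `ICMP_EQ` and `ICMP_SLE`, and the namespace `pred_count` and the helper `totalEntries` are names made up for this illustration. The value 84 passed in the usage note is taken from the `entities.ll` change in this patch (`CHECK: 84` becomes `CHECK: 110`), not from the header.

```cpp
#include <cassert>

namespace pred_count {
// Numeric values copied from llvm::CmpInst::Predicate (assumption: they
// match the LLVM revision this patch targets).
constexpr unsigned FirstFCmp = 0;  // CmpInst::FCMP_FALSE
constexpr unsigned LastFCmp = 15;  // CmpInst::FCMP_TRUE
constexpr unsigned FirstICmp = 32; // CmpInst::ICMP_EQ
constexpr unsigned LastICmp = 41;  // CmpInst::ICMP_SLE

constexpr unsigned NumFCmpPredicates = LastFCmp - FirstFCmp + 1;
constexpr unsigned NumICmpPredicates = LastICmp - FirstICmp + 1;
constexpr unsigned MaxPredicateKinds = NumICmpPredicates + NumFCmpPredicates;

// Enum values 16..31 are unused; a table indexed directly by the enum value
// would carry that many empty slots.
constexpr unsigned UnusedGap = FirstICmp - LastFCmp - 1;

static_assert(NumFCmpPredicates == 16, "16 floating-point predicates");
static_assert(NumICmpPredicates == 10, "10 integer predicates");
static_assert(MaxPredicateKinds == 26, "dense predicate section size");
static_assert(UnusedGap == 16, "gap between the FCmp and ICmp ranges");

// Hypothetical helper: size of the vocabulary once the dense predicate
// section is appended to the entries that existed before.
inline unsigned totalEntries(unsigned EntriesBeforePredicates) {
  assert(EntriesBeforePredicates > 0 && "vocabulary cannot be empty");
  return EntriesBeforePredicates + MaxPredicateKinds;
}
} // namespace pred_count
```

With 84 entries before the change, `pred_count::totalEntries(84)` gives 110, which is consistent with the updated entity count in `entities.ll`.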
@@ -229,16 +244,21 @@ class Vocabulary {
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
+ /// Function to get vocabulary key for a given predicate
+ LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+ LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
private:
constexpr static unsigned NumCanonicalEntries =
- MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+ // Base offsets for slot layout to simplify index computation
+ constexpr static unsigned OperandBaseOffset =
+ MaxOpcodes + MaxCanonicalTypeIDs;
+ constexpr static unsigned PredicateBaseOffset =
+ OperandBaseOffset + MaxOperandKinds;
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+ /// Function to get the predicate enum value for a given index
+ LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
};
/// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 98849fd922843..2981a6ab9382b 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ // Add compare predicate embedding as an additional operand if applicable
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
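Both embedders now build the instruction vector the same way: opcode plus type plus accumulated operands, with the predicate embedding added only when the instruction is a compare. A minimal standalone sketch of that composition follows, assuming C++17. `Embedding` here is a plain `std::vector<float>` standing in for `ir2vec::Embedding`, and `addInto` and `instructionVector` are hypothetical helpers, not functions from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

namespace inst_vec {
using Embedding = std::vector<float>;

// Element-wise accumulation, standing in for Embedding::operator+=.
inline void addInto(Embedding &Dst, const Embedding &Src) {
  assert(Dst.size() == Src.size() && "dimension mismatch");
  for (std::size_t I = 0; I < Dst.size(); ++I)
    Dst[I] += Src[I];
}

// PredEmb is engaged only for compare instructions, mirroring the
// dyn_cast<CmpInst> check in the hunks above.
inline Embedding instructionVector(const Embedding &OpcodeEmb,
                                   const Embedding &TypeEmb,
                                   const Embedding &ArgEmb,
                                   const std::optional<Embedding> &PredEmb) {
  Embedding Result = OpcodeEmb;
  addInto(Result, TypeEmb);
  addInto(Result, ArgEmb);
  if (PredEmb)
    addInto(Result, *PredEmb);
  return Result;
}
} // namespace inst_vec
```

Two compares that share opcode, type and operand kinds but differ in predicate now produce different vectors, which is the distinction the PR description says was missing.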
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
unsigned Vocabulary::getSlotIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+ return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+ unsigned PU = static_cast<unsigned>(P);
+ unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+ unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+ unsigned PredIdx =
+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+ return PredicateBaseOffset + PredIdx;
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
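The mapping in `getSlotIndex(CmpInst::Predicate)` packs the two enum ranges into one dense run: FCmp predicates first, ICmp predicates right after them. The sketch below reproduces the arithmetic on raw unsigned values so it can be compiled without LLVM. The namespace and function name are illustrative, `PredicateBaseOffset = 84` is an assumption taken from the `entities.ll` count before this patch, and the range assertion is an addition of the sketch (the patch does not assert on the input).

```cpp
#include <cassert>

namespace pred_slot {
constexpr unsigned FirstFCmp = 0;  // CmpInst::FIRST_FCMP_PREDICATE
constexpr unsigned LastFCmp = 15;  // CmpInst::LAST_FCMP_PREDICATE
constexpr unsigned FirstICmp = 32; // CmpInst::FIRST_ICMP_PREDICATE
constexpr unsigned LastICmp = 41;  // CmpInst::LAST_ICMP_PREDICATE
constexpr unsigned NumFCmpPredicates = LastFCmp - FirstFCmp + 1;

// Assumed snapshot value: opcodes + canonical types + operand kinds.
constexpr unsigned PredicateBaseOffset = 84;

// Raw predicate value -> slot index in the vocabulary.
inline unsigned predicateSlotIndex(unsigned PU) {
  assert((PU <= LastFCmp || (PU >= FirstICmp && PU <= LastICmp)) &&
         "not a valid predicate value");
  unsigned PredIdx = (PU >= FirstICmp)
                         ? NumFCmpPredicates + (PU - FirstICmp)
                         : PU - FirstFCmp;
  return PredicateBaseOffset + PredIdx;
}
} // namespace pred_slot
```

Under these assumptions `FCMP_FALSE` (0) lands in slot 84, `ICMP_EQ` (32) in slot 100 and `ICMP_SLE` (41) in slot 109, so the last predicate occupies the last of the 110 entries.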
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(Arg)];
}
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+ return Vocab[getSlotIndex(P)];
+}
+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+ assert(Index < MaxPredicateKinds && "Invalid predicate index");
+ unsigned PredEnumVal =
+ (Index < NumFCmpPredicates)
+ ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+ : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+ (Index - NumFCmpPredicates));
+ return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+ static SmallString<16> PredNameBuffer;
+ if (Pred < CmpInst::FIRST_ICMP_PREDICATE)
+ PredNameBuffer = "FCMP_";
+ else
+ PredNameBuffer = "ICMP_";
+ PredNameBuffer += CmpInst::getPredicateName(Pred);
+ return PredNameBuffer;
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ if (Pos < OperandBaseOffset)
return getVocabKeyForCanonicalTypeID(
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
- return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+ if (Pos < PredicateBaseOffset)
+ return getVocabKeyForOperandKind(
+ static_cast<OperandKind>(Pos - OperandBaseOffset));
+ // Predicates
+ return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
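`getPredicate` is the inverse of the dense index computation, and `getVocabKeyForPredicate` prefixes the predicate name with `FCMP_` or `ICMP_`. The sketch below mirrors both on raw values. Two deliberate differences from the patch: the name tables are hard-coded here (they match the keys in the test vocabularies of this PR) instead of calling `CmpInst::getPredicateName`, and the key is returned as a `std::string` by value, whereas the patch returns a `StringRef` into a function-local static buffer that is overwritten by the next call. All names in the namespace are hypothetical.

```cpp
#include <cassert>
#include <string>

namespace pred_key {
constexpr unsigned FirstFCmp = 0;  // CmpInst::FIRST_FCMP_PREDICATE
constexpr unsigned FirstICmp = 32; // CmpInst::FIRST_ICMP_PREDICATE
constexpr unsigned NumFCmp = 16;
constexpr unsigned NumICmp = 10;
constexpr unsigned MaxPredicateKinds = NumFCmp + NumICmp;

// Predicate names in enum order, as they appear in the vocabulary keys.
constexpr const char *FCmpNames[NumFCmp] = {
    "false", "oeq", "ogt", "oge", "olt", "ole", "one", "ord",
    "uno",   "ueq", "ugt", "uge", "ult", "ule", "une", "true"};
constexpr const char *ICmpNames[NumICmp] = {
    "eq", "ne", "ugt", "uge", "ult", "ule", "sgt", "sge", "slt", "sle"};

// Dense index (0..25) -> raw enum value.
inline unsigned predicateValue(unsigned Index) {
  assert(Index < MaxPredicateKinds && "Invalid predicate index");
  return Index < NumFCmp ? FirstFCmp + Index
                         : FirstICmp + (Index - NumFCmp);
}

// Raw enum value -> vocabulary key, returned by value.
inline std::string vocabKey(unsigned PU) {
  if (PU < FirstICmp) {
    assert(PU - FirstFCmp < NumFCmp && "not an FCmp predicate");
    return std::string("FCMP_") + FCmpNames[PU - FirstFCmp];
  }
  assert(PU - FirstICmp < NumICmp && "not an ICmp predicate");
  return std::string("ICMP_") + ICmpNames[PU - FirstICmp];
}
} // namespace pred_key
```

Walking the dense indices 0 through 25 with `vocabKey(predicateValue(I))` yields `FCMP_false` ... `FCMP_true` followed by `ICMP_eq` ... `ICMP_sle`, the same order as the reference vocabulary prints in this patch.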
@@ -370,11 +412,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
VocabVector DummyVocab;
DummyVocab.reserve(NumCanonicalEntries);
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operands
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
- Vocabulary::MaxOperandKinds)) {
+ // Create a dummy vocabulary with entries for all opcodes, types, operands
+ // and predicates
+ for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
DummyVocab.push_back(Embedding(Dim, DummyVal));
DummyVal += 0.1f;
}
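The dummy vocabulary loop now runs over `NumCanonicalEntries`, so the predicate slots are filled automatically whenever the total grows. A standalone sketch of the same fill pattern, with `Embedding` modelled as `std::vector<float>` and the entry count passed in as a parameter (both are simplifications of this sketch, not the patch's signature):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace dummy_vocab {
using Embedding = std::vector<float>;

// Every entry is a constant vector; the constant grows by 0.1 per entry,
// as in createDummyVocabForTest.
inline std::vector<Embedding> createDummyVocab(std::size_t NumEntries,
                                               std::size_t Dim) {
  assert(Dim > 0 && "embedding dimension must be positive");
  std::vector<Embedding> Vocab;
  Vocab.reserve(NumEntries);
  float DummyVal = 0.1f;
  for (std::size_t I = 0; I < NumEntries; ++I) {
    Vocab.push_back(Embedding(Dim, DummyVal));
    DummyVal += 0.1f;
  }
  return Vocab;
}
} // namespace dummy_vocab
```

Deriving the loop bound from a single total, rather than re-adding the section sizes at each use, means a new section cannot be forgotten here.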
@@ -517,6 +557,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
NumericArgEmbeddings.end());
+
+ // Handle Predicates: part of Operands section. We look up predicate keys
+ // in ArgVocab.
+ std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+ Embedding(Dim, 0));
+ NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
+ for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+ StringRef VocabKey =
+ Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+ auto It = ArgVocab.find(VocabKey.str());
+ if (It != ArgVocab.end()) {
+ NumericPredEmbeddings[PK] = It->second;
+ continue;
+ }
+ handleMissingEntity(VocabKey.str());
+ }
+ Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+ NumericPredEmbeddings.end());
}
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
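The predicate section is filled by looking up each key in the argument vocabulary and leaving a zero vector where the key is absent. The sketch below shows that lookup-with-fallback pattern in isolation. `std::map` stands in for the real `ArgVocab` container, and collecting absent keys in `Missing` is a stand-in for `handleMissingEntity`, whose behavior is not shown in this diff. All names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

namespace pred_lookup {
using Embedding = std::vector<float>;

struct LookupResult {
  std::vector<Embedding> Embeddings; // one entry per requested key
  std::vector<std::string> Missing;  // keys not found in the vocabulary
};

inline LookupResult mapKeys(const std::vector<std::string> &Keys,
                            const std::map<std::string, Embedding> &ArgVocab,
                            std::size_t Dim) {
  LookupResult R;
  // Start from all-zero embeddings so absent keys keep a neutral value.
  R.Embeddings.assign(Keys.size(), Embedding(Dim, 0.0f));
  for (std::size_t I = 0; I < Keys.size(); ++I) {
    auto It = ArgVocab.find(Keys[I]);
    if (It != ArgVocab.end()) {
      assert(It->second.size() == Dim && "dimension mismatch");
      R.Embeddings[I] = It->second;
      continue;
    }
    R.Missing.push_back(Keys[I]);
  }
  return R;
}
} // namespace pred_lookup
```

Because the vector is constructed at its final size, no separate `reserve` call is needed in the sketch; the `reserve` in the hunk above is likewise a no-op after the sized constructor.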
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84c1541b..ae36ff54686c5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
"Function": [1, 2],
"Pointer": [3, 4],
"Constant": [5, 6],
- "Variable": [7, 8]
+ "Variable": [7, 8],
+ "FCMP_false": [9, 10],
+ "FCMP_oeq": [11, 12],
+ "FCMP_ogt": [13, 14],
+ "FCMP_oge": [15, 16],
+ "FCMP_olt": [17, 18],
+ "FCMP_ole": [19, 20],
+ "FCMP_one": [21, 22],
+ "FCMP_ord": [23, 24],
+ "FCMP_uno": [25, 26],
+ "FCMP_ueq": [27, 28],
+ "FCMP_ugt": [29, 30],
+ "FCMP_uge": [31, 32],
+ "FCMP_ult": [33, 34],
+ "FCMP_ule": [35, 36],
+ "FCMP_une": [37, 38],
+ "FCMP_true": [39, 40],
+ "ICMP_eq": [41, 42],
+ "ICMP_ne": [43, 44],
+ "ICMP_ugt": [45, 46],
+ "ICMP_uge": [47, 48],
+ "ICMP_ult": [49, 50],
+ "ICMP_ule": [51, 52],
+ "ICMP_sgt": [53, 54],
+ "ICMP_sge": [55, 56],
+ "ICMP_slt": [57, 58],
+ "ICMP_sle": [59, 60]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a217b70c..9003dc73954aa 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
"Function": [1, 2, 3],
"Pointer": [4, 5, 6],
"Constant": [7, 8, 9],
- "Variable": [10, 11, 12]
+ "Variable": [10, 11, 12],
+ "FCMP_false": [13, 14, 15],
+ "FCMP_oeq": [16, 17, 18],
+ "FCMP_ogt": [19, 20, 21],
+ "FCMP_oge": [22, 23, 24],
+ "FCMP_olt": [25, 26, 27],
+ "FCMP_ole": [28, 29, 30],
+ "FCMP_one": [31, 32, 33],
+ "FCMP_ord": [34, 35, 36],
+ "FCMP_uno": [37, 38, 39],
+ "FCMP_ueq": [40, 41, 42],
+ "FCMP_ugt": [43, 44, 45],
+ "FCMP_uge": [46, 47, 48],
+ "FCMP_ult": [49, 50, 51],
+ "FCMP_ule": [52, 53, 54],
+ "FCMP_une": [55, 56, 57],
+ "FCMP_true": [58, 59, 60],
+ "ICMP_eq": [61, 62, 63],
+ "ICMP_ne": [64, 65, 66],
+ "ICMP_ugt": [67, 68, 69],
+ "ICMP_uge": [70, 71, 72],
+ "ICMP_ult": [73, 74, 75],
+ "ICMP_ule": [76, 77, 78],
+ "ICMP_sgt": [79, 80, 81],
+ "ICMP_sge": [82, 83, 84],
+ "ICMP_slt": [85, 86, 87],
+ "ICMP_sle": [88, 89, 90]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efee9f6a1..7ef85490b27df 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
"FPTrunc": [133, 134, 135],
"FPExt": [136, 137, 138],
"PtrToInt": [139, 140, 141],
+ "PtrToAddr": [202, 203, 204],
"IntToPtr": [142, 143, 144],
"BitCast": [145, 146, 147],
"AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
"Function": [0, 0, 0],
"Pointer": [0, 0, 0],
"Constant": [0, 0, 0],
- "Variable": [0, 0, 0]
+ "Variable": [0, 0, 0],
+ "FCMP_false": [0, 0, 0],
+ "FCMP_oeq": [0, 0, 0],
+ "FCMP_ogt": [0, 0, 0],
+ "FCMP_oge": [0, 0, 0],
+ "FCMP_olt": [0, 0, 0],
+ "FCMP_ole": [0, 0, 0],
+ "FCMP_one": [0, 0, 0],
+ "FCMP_ord": [0, 0, 0],
+ "FCMP_uno": [0, 0, 0],
+ "FCMP_ueq": [0, 0, 0],
+ "FCMP_ugt": [0, 0, 0],
+ "FCMP_uge": [0, 0, 0],
+ "FCMP_ult": [0, 0, 0],
+ "FCMP_ule": [0, 0, 0],
+ "FCMP_une": [0, 0, 0],
+ "FCMP_true": [0, 0, 0],
+ "ICMP_eq": [0, 0, 0],
+ "ICMP_ne": [0, 0, 0],
+ "ICMP_ugt": [0, 0, 0],
+ "ICMP_uge": [0, 0, 0],
+ "ICMP_ult": [0, 0, 0],
+ "ICMP_ule": [0, 0, 0],
+ "ICMP_sgt": [1, 1, 1],
+ "ICMP_sge": [0, 0, 0],
+ "ICMP_slt": [0, 0, 0],
+ "ICMP_sle": [0, 0, 0]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
Key: Constant: [ 1.00 1.20 ]
Key: Variable: [ 1.40 1.60 ]
+Key: FCMP_false: [ 1.80 2.00 ]
+Key: FCMP_oeq: [ 2.20 2.40 ]
+Key: FCMP_ogt: [ 2.60 2.80 ]
+Key: FCMP_oge: [ 3.00 3.20 ]
+Key: FCMP_olt: [ 3.40 3.60 ]
+Key: FCMP_ole: [ 3.80 4.00 ]
+Key: FCMP_one: [ 4.20 4.40 ]
+Key: FCMP_ord: [ 4.60 4.80 ]
+Key: FCMP_uno: [ 5.00 5.20 ]
+Key: FCMP_ueq: [ 5.40 5.60 ]
+Key: FCMP_ugt: [ 5.80 6.00 ]
+Key: FCMP_uge: [ 6.20 6.40 ]
+Key: FCMP_ult: [ 6.60 6.80 ]
+Key: FCMP_ule: [ 7.00 7.20 ]
+Key: FCMP_une: [ 7.40 7.60 ]
+Key: FCMP_true: [ 7.80 8.00 ]
+Key: ICMP_eq: [ 8.20 8.40 ]
+Key: ICMP_ne: [ 8.60 8.80 ]
+Key: ICMP_ugt: [ 9.00 9.20 ]
+Key: ICMP_uge: [ 9.40 9.60 ]
+Key: ICMP_ult: [ 9.80 10.00 ]
+Key: ICMP_ule: [ 10.20 10.40 ]
+Key: ICMP_sgt: [ 10.60 10.80 ]
+Key: ICMP_sge: [ 11.00 11.20 ]
+Key: ICMP_slt: [ 11.40 11.60 ]
+Key: ICMP_sle: [ 11.80 12.00 ]
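The numbers in this reference file line up with the raw entries of `dummy_2D_vocab.json` multiplied by 0.2: `"FCMP_false": [9, 10]` prints as `[ 1.80 2.00 ]`, just as `"Function": [1, 2]` prints as `[ 0.20 0.40 ]`. The sketch below reproduces that scaling. It assumes, without the RUN lines to confirm it, that the reference prints are generated from `dummy_2D_vocab.json` and that predicates are scaled with the same weight as the other operand entries; the factor 0.2 is inferred from the numbers alone. Function names are hypothetical and the bracket format only imitates the lines shown here.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <vector>

namespace weight_scale {
// Multiply every component of a raw vocabulary entry by a section weight.
inline std::vector<float> scale(const std::vector<float> &Raw, float Weight) {
  std::vector<float> Out;
  Out.reserve(Raw.size());
  for (float V : Raw)
    Out.push_back(V * Weight);
  return Out;
}

// Render an embedding with two decimals, e.g. "[ 1.80 2.00 ]".
inline std::string format(const std::vector<float> &E) {
  std::string Out = "[ ";
  for (float V : E) {
    char Buf[32];
    int N = std::snprintf(Buf, sizeof(Buf), "%.2f ", V);
    assert(N > 0 && N < static_cast<int>(sizeof(Buf)) && "formatting failed");
    (void)N; // silence unused-variable warnings when asserts are disabled
    Out += Buf;
  }
  Out += "]";
  return Out;
}
} // namespace weight_scale
```

The same raw entry scaled by 0.5 gives `[ 4.50 5.00 ]` and scaled by 0 gives `[ 0.00 0.00 ]`, matching the `FCMP_false` lines of the `wtd1` and `wtd2` reference files in this patch.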
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
Key: Constant: [ 2.50 3.00 ]
Key: Variable: [ 3.50 4.00 ]
+Key: FCMP_false: [ 4.50 5.00 ]
+Key: FCMP_oeq: [ 5.50 6.00 ]
+Key: FCMP_ogt: [ 6.50 7.00 ]
+Key: FCMP_oge: [ 7.50 8.00 ]
+Key: FCMP_olt: [ 8.50 9.00 ]
+Key: FCMP_ole: [ 9.50 10.00 ]
+Key: FCMP_one: [ 10.50 11.00 ]
+Key: FCMP_ord: [ 11.50 12.00 ]
+Key: FCMP_uno: [ 12.50 13.00 ]
+Key: FCMP_ueq: [ 13.50 14.00 ]
+Key: FCMP_ugt: [ 14.50 15.00 ]
+Key: FCMP_uge: [ 15.50 16.00 ]
+Key: FCMP_ult: [ 16.50 17.00 ]
+Key: FCMP_ule: [ 17.50 18.00 ]
+Key: FCMP_une: [ 18.50 19.00 ]
+Key: FCMP_true: [ 19.50 20.00 ]
+Key: ICMP_eq: [ 20.50 21.00 ]
+Key: ICMP_ne: [ 21.50 22.00 ]
+Key: ICMP_ugt: [ 22.50 23.00 ]
+Key: ICMP_uge: [ 23.50 24.00 ]
+Key: ICMP_ult: [ 24.50 25.00 ]
+Key: ICMP_ule: [ 25.50 26.00 ]
+Key: ICMP_sgt: [ 26.50 27.00 ]
+Key: ICMP_sge: [ 27.50 28.00 ]
+Key: ICMP_slt: [ 28.50 29.00 ]
+Key: ICMP_sle: [ 29.50 30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
Key: Constant: [ 0.00 0.00 ]
Key: Variable: [ 0.00 0.00 ]
+Key: FCMP_false: [ 0.00 0.00 ]
+Key: FCMP_oeq: [ 0.00 0.00 ]
+Key: FCMP_ogt: [ 0.00 0.00 ]
+Key: FCMP_oge: [ 0.00 0.00 ]
+Key: FCMP_olt: [ 0.00 0.00 ]
+Key: FCMP_ole: [ 0.00 0.00 ]
+Key: FCMP_one: [ 0.00 0.00 ]
+Key: FCMP_ord: [ 0.00 0.00 ]
+Key: FCMP_uno: [ 0.00 0.00 ]
+Key: FCMP_ueq: [ 0.00 0.00 ]
+Key: FCMP_ugt: [ 0.00 0.00 ]
+Key: FCMP_uge: [ 0.00 0.00 ]
+Key: FCMP_ult: [ 0.00 0.00 ]
+Key: FCMP_ule: [ 0.00 0.00 ]
+Key: FCMP_une: [ 0.00 0.00 ]
+Key: FCMP_true: [ 0.00 0.00 ]
+Key: ICMP_eq: [ 0.00 0.00 ]
+Key: ICMP_ne: [ 0.00 0.00 ]
+Key: ICMP_ugt: [ 0.00 0.00 ]
+Key: ICMP_uge: [ 0.00 0.00 ]
+Key: ICMP_ult: [ 0.00 0.00 ]
+Key: ICMP_ule: [ 0.00 0.00 ]
+Key: ICMP_sgt: [ 0.00 0.00 ]
+Key: ICMP_sge: [ 0.00 0.00 ]
+Key: ICMP_slt: [ 0.00 0.00 ]
+Key: ICMP_sle: [ 0.00 0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
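The entry-block change from `816.00` to `816.20` per component can be explained by one predicate contribution, under three assumptions that the visible diff does not confirm: the test loads `dummy_3D_nonzero_opc_vocab.json` (where `ICMP_sgt` is `[1, 1, 1]` and every other predicate is zero), the entry block contains a single `icmp sgt`, and the predicate embedding is scaled by 0.2. A small sketch of that arithmetic, with a hypothetical helper name:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace bb_delta {
// Add a weighted predicate embedding to an existing basic-block vector.
inline std::vector<float> addScaled(std::vector<float> Block,
                                    const std::vector<float> &Pred,
                                    float Weight) {
  assert(Block.size() == Pred.size() && "dimension mismatch");
  for (std::size_t I = 0; I < Block.size(); ++I)
    Block[I] += Weight * Pred[I];
  return Block;
}
} // namespace bb_delta
```

Applying `addScaled({816, 825, 834}, {1, 1, 1}, 0.2f)` gives approximately `816.2, 825.2, 834.2`. The unchanged `if.then` and `if.else` vectors fit the same reading, since a block without a compare, or with a predicate whose embedding is zero, gains nothing.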
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 84
+CHECK: 110
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function 80
CHECK-NEXT: Pointer ...
[truncated]
@llvm/pr-subscribers-llvm-analysis Author: S. VenkataKeerthy (svkeerthy)
(Same description and patch as the notification above.)
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
Key: Constant: [ 1.00 1.20 ]
Key: Variable: [ 1.40 1.60 ]
+Key: FCMP_false: [ 1.80 2.00 ]
+Key: FCMP_oeq: [ 2.20 2.40 ]
+Key: FCMP_ogt: [ 2.60 2.80 ]
+Key: FCMP_oge: [ 3.00 3.20 ]
+Key: FCMP_olt: [ 3.40 3.60 ]
+Key: FCMP_ole: [ 3.80 4.00 ]
+Key: FCMP_one: [ 4.20 4.40 ]
+Key: FCMP_ord: [ 4.60 4.80 ]
+Key: FCMP_uno: [ 5.00 5.20 ]
+Key: FCMP_ueq: [ 5.40 5.60 ]
+Key: FCMP_ugt: [ 5.80 6.00 ]
+Key: FCMP_uge: [ 6.20 6.40 ]
+Key: FCMP_ult: [ 6.60 6.80 ]
+Key: FCMP_ule: [ 7.00 7.20 ]
+Key: FCMP_une: [ 7.40 7.60 ]
+Key: FCMP_true: [ 7.80 8.00 ]
+Key: ICMP_eq: [ 8.20 8.40 ]
+Key: ICMP_ne: [ 8.60 8.80 ]
+Key: ICMP_ugt: [ 9.00 9.20 ]
+Key: ICMP_uge: [ 9.40 9.60 ]
+Key: ICMP_ult: [ 9.80 10.00 ]
+Key: ICMP_ule: [ 10.20 10.40 ]
+Key: ICMP_sgt: [ 10.60 10.80 ]
+Key: ICMP_sge: [ 11.00 11.20 ]
+Key: ICMP_slt: [ 11.40 11.60 ]
+Key: ICMP_sle: [ 11.80 12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
Key: Constant: [ 2.50 3.00 ]
Key: Variable: [ 3.50 4.00 ]
+Key: FCMP_false: [ 4.50 5.00 ]
+Key: FCMP_oeq: [ 5.50 6.00 ]
+Key: FCMP_ogt: [ 6.50 7.00 ]
+Key: FCMP_oge: [ 7.50 8.00 ]
+Key: FCMP_olt: [ 8.50 9.00 ]
+Key: FCMP_ole: [ 9.50 10.00 ]
+Key: FCMP_one: [ 10.50 11.00 ]
+Key: FCMP_ord: [ 11.50 12.00 ]
+Key: FCMP_uno: [ 12.50 13.00 ]
+Key: FCMP_ueq: [ 13.50 14.00 ]
+Key: FCMP_ugt: [ 14.50 15.00 ]
+Key: FCMP_uge: [ 15.50 16.00 ]
+Key: FCMP_ult: [ 16.50 17.00 ]
+Key: FCMP_ule: [ 17.50 18.00 ]
+Key: FCMP_une: [ 18.50 19.00 ]
+Key: FCMP_true: [ 19.50 20.00 ]
+Key: ICMP_eq: [ 20.50 21.00 ]
+Key: ICMP_ne: [ 21.50 22.00 ]
+Key: ICMP_ugt: [ 22.50 23.00 ]
+Key: ICMP_uge: [ 23.50 24.00 ]
+Key: ICMP_ult: [ 24.50 25.00 ]
+Key: ICMP_ule: [ 25.50 26.00 ]
+Key: ICMP_sgt: [ 26.50 27.00 ]
+Key: ICMP_sge: [ 27.50 28.00 ]
+Key: ICMP_slt: [ 28.50 29.00 ]
+Key: ICMP_sle: [ 29.50 30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
Key: Constant: [ 0.00 0.00 ]
Key: Variable: [ 0.00 0.00 ]
+Key: FCMP_false: [ 0.00 0.00 ]
+Key: FCMP_oeq: [ 0.00 0.00 ]
+Key: FCMP_ogt: [ 0.00 0.00 ]
+Key: FCMP_oge: [ 0.00 0.00 ]
+Key: FCMP_olt: [ 0.00 0.00 ]
+Key: FCMP_ole: [ 0.00 0.00 ]
+Key: FCMP_one: [ 0.00 0.00 ]
+Key: FCMP_ord: [ 0.00 0.00 ]
+Key: FCMP_uno: [ 0.00 0.00 ]
+Key: FCMP_ueq: [ 0.00 0.00 ]
+Key: FCMP_ugt: [ 0.00 0.00 ]
+Key: FCMP_uge: [ 0.00 0.00 ]
+Key: FCMP_ult: [ 0.00 0.00 ]
+Key: FCMP_ule: [ 0.00 0.00 ]
+Key: FCMP_une: [ 0.00 0.00 ]
+Key: FCMP_true: [ 0.00 0.00 ]
+Key: ICMP_eq: [ 0.00 0.00 ]
+Key: ICMP_ne: [ 0.00 0.00 ]
+Key: ICMP_ugt: [ 0.00 0.00 ]
+Key: ICMP_uge: [ 0.00 0.00 ]
+Key: ICMP_ult: [ 0.00 0.00 ]
+Key: ICMP_ule: [ 0.00 0.00 ]
+Key: ICMP_sgt: [ 0.00 0.00 ]
+Key: ICMP_sge: [ 0.00 0.00 ]
+Key: ICMP_slt: [ 0.00 0.00 ]
+Key: ICMP_sle: [ 0.00 0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 84
+CHECK: 110
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function 80
CHECK-NEXT: Pointer ...
[truncated]
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
// Slot layout:
IIUC each subrange is fixed length. Within a subrange, you have the embedding of whatever comes after =>, correct?
Can you clarify this upfront in the comment? One could imagine MaxOpcodes refers to an LLVM IR constant, for example.
Also split the Operands the same way, for consistency.
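For readers following this thread, the fixed-length subrange idea can be sketched as below. This is a minimal, self-contained illustration and not the code in the PR: `SlotLayout` and its member names are hypothetical, and the counts are plain data members rather than the real `Vocabulary` constants.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the vocabulary's slot layout. Each subrange has a
// fixed length and starts where the previous one ends.
struct SlotLayout {
  std::size_t NumOpcodes;      // length of the opcode subrange
  std::size_t NumTypes;        // length of the canonicalized-type subrange
  std::size_t NumOperandKinds; // length of the operand-kind subrange
  std::size_t NumPredicates;   // length of the predicate subrange

  std::size_t opcodeBegin() const { return 0; }
  std::size_t typeBegin() const { return NumOpcodes; }
  std::size_t operandBegin() const { return typeBegin() + NumTypes; }
  std::size_t predicateBegin() const {
    return operandBegin() + NumOperandKinds;
  }
  std::size_t size() const { return predicateBegin() + NumPredicates; }

  // Translate an index relative to the predicate subrange into a flat slot.
  std::size_t predicateSlot(std::size_t RelIdx) const {
    assert(RelIdx < NumPredicates && "predicate index out of range");
    return predicateBegin() + RelIdx;
  }
};
```

With illustrative counts such as `SlotLayout L{68, 12, 4, 26}`, the operand subrange starts at 80 and the total is 110, which lines up with the `Function 80` and `CHECK: 110` lines in the entities test. The 68/12 split between opcodes and types is made up for the example; only their sum is taken from the test.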
// Slot layout:
// [0 .. MaxOpcodes-1] => Instruction
// opcodes [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] =>
// Canonicalized types [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) =>
can you explain what canonicalized means here? (like, in the comment)
@@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}

CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
assert(Index < MaxPredicateKinds && "Invalid predicate index");
so Index is relative to the predicates subrange, i.e. Index == 0 is the first predicate?
Yes, correct.
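To make the relative indexing concrete, here is a small sketch of one way such a mapping could look. It assumes the predicate subrange is ordered as in the test vocabularies above (16 FCMP entries followed by 10 ICMP entries) and uses a local enum holding only the four range boundaries of `CmpInst::Predicate`. The helper names `predicateFromIndex` and `keyForIndex` are hypothetical and are not the functions added by this PR.

```cpp
#include <cassert>
#include <string>

// Local mirror of the range boundaries of the comparison predicates.
// Only the first/last values of each family matter for the mapping.
enum Predicate : unsigned {
  FCMP_FALSE = 0,
  FCMP_TRUE = 15,
  ICMP_EQ = 32,
  ICMP_SLE = 41
};

constexpr unsigned NumFCmp = FCMP_TRUE - FCMP_FALSE + 1;  // 16
constexpr unsigned NumICmp = ICMP_SLE - ICMP_EQ + 1;      // 10
constexpr unsigned NumPredicateKinds = NumFCmp + NumICmp; // 26

// Index is relative to the predicate subrange: 0 is the first FCmp predicate,
// NumFCmp is the first ICmp predicate.
unsigned predicateFromIndex(unsigned Index) {
  assert(Index < NumPredicateKinds && "Invalid predicate index");
  if (Index < NumFCmp)
    return FCMP_FALSE + Index;
  return ICMP_EQ + (Index - NumFCmp);
}

// Vocabulary key for a relative index, in the "FCMP_oeq" / "ICMP_sgt" style
// used by the JSON test inputs.
std::string keyForIndex(unsigned Index) {
  static const char *const FCmpNames[NumFCmp] = {
      "false", "oeq", "ogt", "oge", "olt", "ole", "one", "ord",
      "uno",   "ueq", "ugt", "uge", "ult", "ule", "une", "true"};
  static const char *const ICmpNames[NumICmp] = {
      "eq", "ne", "ugt", "uge", "ult", "ule", "sgt", "sge", "slt", "sle"};
  assert(Index < NumPredicateKinds && "Invalid predicate index");
  if (Index < NumFCmp)
    return std::string("FCMP_") + FCmpNames[Index];
  return std::string("ICMP_") + ICmpNames[Index - NumFCmp];
}
```

The gap between the FCmp and ICmp enum values is why a plain cast from index to predicate would not work; the sketch closes the gap by offsetting ICmp indices by the number of FCmp predicates.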