Skip to content

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Sep 4, 2025

Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space.

Following are the changes:

  1. Expand the vocabulary slot layout to include predicate entries after opcodes, types, and operands
  2. Add methods to handle predicate embedding lookups and conversions
  3. Update the embedder implementations to include predicate information when processing CmpInst instructions
  4. Update test files to include the new predicate entries in the vocabulary

(Tracking issues: #141817, #141833)

@svkeerthy svkeerthy force-pushed the users/svkeerthy/08-27-default_constructor branch from 8c8500c to fd4e1df Compare September 4, 2025 19:43
@svkeerthy svkeerthy force-pushed the users/svkeerthy/09-03-support_predicates branch from 6185e40 to cfb9293 Compare September 4, 2025 19:44
Base automatically changed from users/svkeerthy/08-27-default_constructor to main September 4, 2025 20:31
@svkeerthy svkeerthy force-pushed the users/svkeerthy/09-03-support_predicates branch from cfb9293 to f81b699 Compare September 4, 2025 20:32
@svkeerthy svkeerthy force-pushed the users/svkeerthy/09-03-support_predicates branch from f81b699 to 5adbeae Compare September 4, 2025 22:46
@svkeerthy svkeerthy changed the title Support predicates [IR2Vec] Add support for compare predicates in embeddings Sep 4, 2025
@svkeerthy svkeerthy changed the title [IR2Vec] Add support for compare predicates in embeddings [IR2Vec] Add support for Cmp predicates in vocabulary and embeddings Sep 4, 2025
@svkeerthy svkeerthy marked this pull request as ready for review September 4, 2025 22:58
@llvmbot llvmbot added mlgo llvm:analysis Includes value tracking, cost tables and constant folding labels Sep 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space.

Following are the changes:

  1. Expand the vocabulary slot layout to include predicate entries after opcodes, types, and operands
  2. Add methods to handle predicate embedding lookups and conversions
  3. Update the embedder implementations to include predicate information when processing CmpInst instructions
  4. Update test files to include the new predicate entries in the vocabulary

(Tracking issues: #141817, #141833)


Patch is 25.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156952.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+36-7)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+67-9)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json (+27-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json (+27-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json (+28-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/if-else.ll (+1-1)
  • (modified) llvm/test/Analysis/IR2Vec/unreachable.ll (+1-1)
  • (modified) llvm/test/tools/llvm-ir2vec/entities.ll (+27-1)
  • (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+1-1)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+46-1)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index b7b881999241e..d49854e2d06a8 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #define LLVM_ANALYSIS_IR2VEC_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 /// embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
+  // Slot layout:
+  // [0 .. MaxOpcodes-1]                                   => Instruction
+  // opcodes [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1]      =>
+  // Canonicalized types [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) =>
+  // Operands
+  //   Within Operands: first OperandKind entries, followed by compare
+  //   predicates
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
+
   bool Valid = false;
+  static constexpr unsigned NumICmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+  static constexpr unsigned NumFCmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
 
 public:
-  // Slot layout:
-  // [0 .. MaxOpcodes-1]               => Instruction opcodes
-  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
-  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
-
   /// Canonical type IDs supported by IR2Vec Vocabulary
   enum class CanonicalTypeID : unsigned {
     FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
       static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
+  // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+  // empty slots.
+  static constexpr unsigned MaxPredicateKinds =
+      NumICmpPredicates + NumFCmpPredicates;
 
   Vocabulary() = default;
   LLVM_ABI Vocabulary(VocabVector &&Vocab);
 
   LLVM_ABI bool isValid() const;
   LLVM_ABI unsigned getDimension() const;
-  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  /// Total number of entries (opcodes + canonicalized types + operand kinds +
+  /// predicates)
   static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
   /// Function to get vocabulary key for a given Opcode
@@ -229,16 +244,21 @@ class Vocabulary {
   /// Function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
+  /// Function to get vocabulary key for a given predicate
+  LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
   /// Functions to return the slot index or position of a given Opcode, TypeID,
   /// or OperandKind in the vocabulary.
   LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
   LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
   LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+  LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
   LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
   LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+  LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
 
   /// Const Iterator type aliases
   using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
 
 private:
   constexpr static unsigned NumCanonicalEntries =
-      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+  // Base offsets for slot layout to simplify index computation
+  constexpr static unsigned OperandBaseOffset =
+      MaxOpcodes + MaxCanonicalTypeIDs;
+  constexpr static unsigned PredicateBaseOffset =
+      OperandBaseOffset + MaxOperandKinds;
 
   /// String mappings for CanonicalTypeID values
   static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
 
   /// Function to convert TypeID to CanonicalTypeID
   LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+  /// Function to get the predicate enum value for a given index
+  LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 98849fd922843..2981a6ab9382b 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       ArgEmb += Vocab[*Op];
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     // embeddings
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    // Add compare predicate embedding as an additional operand if applicable
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
 unsigned Vocabulary::getSlotIndex(const Value &Op) {
   unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
   assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+  return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+  unsigned PU = static_cast<unsigned>(P);
+  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+  unsigned PredIdx =
+      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+  return PredicateBaseOffset + PredIdx;
 }
 
 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
   return Vocab[getSlotIndex(Arg)];
 }
 
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+  return Vocab[getSlotIndex(P)];
+}
+
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
@@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+  assert(Index < MaxPredicateKinds && "Invalid predicate index");
+  unsigned PredEnumVal =
+      (Index < NumFCmpPredicates)
+          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+             (Index - NumFCmpPredicates));
+  return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+  static SmallString<16> PredNameBuffer;
+  if (Pred < CmpInst::FIRST_ICMP_PREDICATE)
+    PredNameBuffer = "FCMP_";
+  else
+    PredNameBuffer = "ICMP_";
+  PredNameBuffer += CmpInst::getPredicateName(Pred);
+  return PredNameBuffer;
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+  if (Pos < OperandBaseOffset)
     return getVocabKeyForCanonicalTypeID(
         static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
-  return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+  if (Pos < PredicateBaseOffset)
+    return getVocabKeyForOperandKind(
+        static_cast<OperandKind>(Pos - OperandBaseOffset));
+  // Predicates
+  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +412,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
   VocabVector DummyVocab;
   DummyVocab.reserve(NumCanonicalEntries);
   float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operands
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
+  // Create a dummy vocabulary with entries for all opcodes, types, operands
+  // and predicates
+  for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
     DummyVocab.push_back(Embedding(Dim, DummyVal));
     DummyVal += 0.1f;
   }
@@ -517,6 +557,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   }
   Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
                NumericArgEmbeddings.end());
+
+  // Handle Predicates: part of Operands section. We look up predicate keys
+  // in ArgVocab.
+  std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+                                               Embedding(Dim, 0));
+  NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
+  for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+    StringRef VocabKey =
+        Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+    auto It = ArgVocab.find(VocabKey.str());
+    if (It != ArgVocab.end()) {
+      NumericPredEmbeddings[PK] = It->second;
+      continue;
+    }
+    handleMissingEntity(VocabKey.str());
+  }
+  Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+               NumericPredEmbeddings.end());
 }
 
 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84c1541b..ae36ff54686c5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
         "Function": [1, 2],
         "Pointer": [3, 4],
         "Constant": [5, 6],
-        "Variable": [7, 8]
+        "Variable": [7, 8],
+        "FCMP_false": [9, 10],
+        "FCMP_oeq": [11, 12], 
+        "FCMP_ogt": [13, 14], 
+        "FCMP_oge": [15, 16], 
+        "FCMP_olt": [17, 18], 
+        "FCMP_ole": [19, 20], 
+        "FCMP_one": [21, 22], 
+        "FCMP_ord": [23, 24], 
+        "FCMP_uno": [25, 26], 
+        "FCMP_ueq": [27, 28], 
+        "FCMP_ugt": [29, 30], 
+        "FCMP_uge": [31, 32], 
+        "FCMP_ult": [33, 34], 
+        "FCMP_ule": [35, 36], 
+        "FCMP_une": [37, 38], 
+        "FCMP_true": [39, 40], 
+        "ICMP_eq": [41, 42], 
+        "ICMP_ne": [43, 44], 
+        "ICMP_ugt": [45, 46], 
+        "ICMP_uge": [47, 48], 
+        "ICMP_ult": [49, 50], 
+        "ICMP_ule": [51, 52], 
+        "ICMP_sgt": [53, 54], 
+        "ICMP_sge": [55, 56], 
+        "ICMP_slt": [57, 58], 
+        "ICMP_sle": [59, 60]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a217b70c..9003dc73954aa 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
         "Function": [1, 2, 3],
         "Pointer": [4, 5, 6],
         "Constant": [7, 8, 9],
-        "Variable": [10, 11, 12]
+        "Variable": [10, 11, 12],
+        "FCMP_false": [13, 14, 15],
+        "FCMP_oeq": [16, 17, 18],
+        "FCMP_ogt": [19, 20, 21],
+        "FCMP_oge": [22, 23, 24],
+        "FCMP_olt": [25, 26, 27],
+        "FCMP_ole": [28, 29, 30],
+        "FCMP_one": [31, 32, 33],
+        "FCMP_ord": [34, 35, 36],
+        "FCMP_uno": [37, 38, 39],
+        "FCMP_ueq": [40, 41, 42],
+        "FCMP_ugt": [43, 44, 45],
+        "FCMP_uge": [46, 47, 48],
+        "FCMP_ult": [49, 50, 51],
+        "FCMP_ule": [52, 53, 54],
+        "FCMP_une": [55, 56, 57],
+        "FCMP_true": [58, 59, 60],        
+        "ICMP_eq": [61, 62, 63],
+        "ICMP_ne": [64, 65, 66],
+        "ICMP_ugt": [67, 68, 69],
+        "ICMP_uge": [70, 71, 72],
+        "ICMP_ult": [73, 74, 75],
+        "ICMP_ule": [76, 77, 78],
+        "ICMP_sgt": [79, 80, 81],
+        "ICMP_sge": [82, 83, 84],
+        "ICMP_slt": [85, 86, 87],
+        "ICMP_sle": [88, 89, 90]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efee9f6a1..7ef85490b27df 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
         "FPTrunc": [133, 134, 135],
         "FPExt": [136, 137, 138],
         "PtrToInt": [139, 140, 141],
+        "PtrToAddr": [202, 203, 204],
         "IntToPtr": [142, 143, 144],
         "BitCast": [145, 146, 147],
         "AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
         "Function": [0, 0, 0],
         "Pointer": [0, 0, 0],
         "Constant": [0, 0, 0],
-        "Variable": [0, 0, 0]
+        "Variable": [0, 0, 0],
+        "FCMP_false": [0, 0, 0],
+        "FCMP_oeq": [0, 0, 0],
+        "FCMP_ogt": [0, 0, 0],
+        "FCMP_oge": [0, 0, 0],
+        "FCMP_olt": [0, 0, 0],
+        "FCMP_ole": [0, 0, 0],
+        "FCMP_one": [0, 0, 0],
+        "FCMP_ord": [0, 0, 0],
+        "FCMP_uno": [0, 0, 0],
+        "FCMP_ueq": [0, 0, 0],
+        "FCMP_ugt": [0, 0, 0],
+        "FCMP_uge": [0, 0, 0],
+        "FCMP_ult": [0, 0, 0],
+        "FCMP_ule": [0, 0, 0],
+        "FCMP_une": [0, 0, 0],
+        "FCMP_true": [0, 0, 0],
+        "ICMP_eq": [0, 0, 0],
+        "ICMP_ne": [0, 0, 0],
+        "ICMP_ugt": [0, 0, 0],
+        "ICMP_uge": [0, 0, 0],
+        "ICMP_ult": [0, 0, 0],
+        "ICMP_ule": [0, 0, 0],
+        "ICMP_sgt": [1, 1, 1],
+        "ICMP_sge": [0, 0, 0],
+        "ICMP_slt": [0, 0, 0],
+        "ICMP_sle": [0, 0, 0]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.20  0.40 ]
 Key: Pointer:  [ 0.60  0.80 ]
 Key: Constant:  [ 1.00  1.20 ]
 Key: Variable:  [ 1.40  1.60 ]
+Key: FCMP_false:  [ 1.80  2.00 ]
+Key: FCMP_oeq:  [ 2.20  2.40 ]
+Key: FCMP_ogt:  [ 2.60  2.80 ]
+Key: FCMP_oge:  [ 3.00  3.20 ]
+Key: FCMP_olt:  [ 3.40  3.60 ]
+Key: FCMP_ole:  [ 3.80  4.00 ]
+Key: FCMP_one:  [ 4.20  4.40 ]
+Key: FCMP_ord:  [ 4.60  4.80 ]
+Key: FCMP_uno:  [ 5.00  5.20 ]
+Key: FCMP_ueq:  [ 5.40  5.60 ]
+Key: FCMP_ugt:  [ 5.80  6.00 ]
+Key: FCMP_uge:  [ 6.20  6.40 ]
+Key: FCMP_ult:  [ 6.60  6.80 ]
+Key: FCMP_ule:  [ 7.00  7.20 ]
+Key: FCMP_une:  [ 7.40  7.60 ]
+Key: FCMP_true:  [ 7.80  8.00 ]
+Key: ICMP_eq:  [ 8.20  8.40 ]
+Key: ICMP_ne:  [ 8.60  8.80 ]
+Key: ICMP_ugt:  [ 9.00  9.20 ]
+Key: ICMP_uge:  [ 9.40  9.60 ]
+Key: ICMP_ult:  [ 9.80  10.00 ]
+Key: ICMP_ule:  [ 10.20  10.40 ]
+Key: ICMP_sgt:  [ 10.60  10.80 ]
+Key: ICMP_sge:  [ 11.00  11.20 ]
+Key: ICMP_slt:  [ 11.40  11.60 ]
+Key: ICMP_sle:  [ 11.80  12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.50  1.00 ]
 Key: Pointer:  [ 1.50  2.00 ]
 Key: Constant:  [ 2.50  3.00 ]
 Key: Variable:  [ 3.50  4.00 ]
+Key: FCMP_false:  [ 4.50  5.00 ]
+Key: FCMP_oeq:  [ 5.50  6.00 ]
+Key: FCMP_ogt:  [ 6.50  7.00 ]
+Key: FCMP_oge:  [ 7.50  8.00 ]
+Key: FCMP_olt:  [ 8.50  9.00 ]
+Key: FCMP_ole:  [ 9.50  10.00 ]
+Key: FCMP_one:  [ 10.50  11.00 ]
+Key: FCMP_ord:  [ 11.50  12.00 ]
+Key: FCMP_uno:  [ 12.50  13.00 ]
+Key: FCMP_ueq:  [ 13.50  14.00 ]
+Key: FCMP_ugt:  [ 14.50  15.00 ]
+Key: FCMP_uge:  [ 15.50  16.00 ]
+Key: FCMP_ult:  [ 16.50  17.00 ]
+Key: FCMP_ule:  [ 17.50  18.00 ]
+Key: FCMP_une:  [ 18.50  19.00 ]
+Key: FCMP_true:  [ 19.50  20.00 ]
+Key: ICMP_eq:  [ 20.50  21.00 ]
+Key: ICMP_ne:  [ 21.50  22.00 ]
+Key: ICMP_ugt:  [ 22.50  23.00 ]
+Key: ICMP_uge:  [ 23.50  24.00 ]
+Key: ICMP_ult:  [ 24.50  25.00 ]
+Key: ICMP_ule:  [ 25.50  26.00 ]
+Key: ICMP_sgt:  [ 26.50  27.00 ]
+Key: ICMP_sge:  [ 27.50  28.00 ]
+Key: ICMP_slt:  [ 28.50  29.00 ]
+Key: ICMP_sle:  [ 29.50  30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.00  0.00 ]
 Key: Pointer:  [ 0.00  0.00 ]
 Key: Constant:  [ 0.00  0.00 ]
 Key: Variable:  [ 0.00  0.00 ]
+Key: FCMP_false:  [ 0.00  0.00 ]
+Key: FCMP_oeq:  [ 0.00  0.00 ]
+Key: FCMP_ogt:  [ 0.00  0.00 ]
+Key: FCMP_oge:  [ 0.00  0.00 ]
+Key: FCMP_olt:  [ 0.00  0.00 ]
+Key: FCMP_ole:  [ 0.00  0.00 ]
+Key: FCMP_one:  [ 0.00  0.00 ]
+Key: FCMP_ord:  [ 0.00  0.00 ]
+Key: FCMP_uno:  [ 0.00  0.00 ]
+Key: FCMP_ueq:  [ 0.00  0.00 ]
+Key: FCMP_ugt:  [ 0.00  0.00 ]
+Key: FCMP_uge:  [ 0.00  0.00 ]
+Key: FCMP_ult:  [ 0.00  0.00 ]
+Key: FCMP_ule:  [ 0.00  0.00 ]
+Key: FCMP_une:  [ 0.00  0.00 ]
+Key: FCMP_true:  [ 0.00  0.00 ]
+Key: ICMP_eq:  [ 0.00  0.00 ]
+Key: ICMP_ne:  [ 0.00  0.00 ]
+Key: ICMP_ugt:  [ 0.00  0.00 ]
+Key: ICMP_uge:  [ 0.00  0.00 ]
+Key: ICMP_ult:  [ 0.00  0.00 ]
+Key: ICMP_ule:  [ 0.00  0.00 ]
+Key: ICMP_sgt:  [ 0.00  0.00 ]
+Key: ICMP_sge:  [ 0.00  0.00 ]
+Key: ICMP_slt:  [ 0.00  0.00 ]
+Key: ICMP_sle:  [ 0.00  0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
 ; RUN: llvm-ir2vec entities | FileCheck %s
 
-CHECK: 84
+CHECK: 110
 CHECK-NEXT: Ret     0
 CHECK-NEXT: Br      1
 CHECK-NEXT: Switch  2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function        80
 CHECK-NEXT: Pointer ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space.

Following are the changes:

  1. Expand the vocabulary slot layout to include predicate entries after opcodes, types, and operands
  2. Add methods to handle predicate embedding lookups and conversions
  3. Update the embedder implementations to include predicate information when processing CmpInst instructions
  4. Update test files to include the new predicate entries in the vocabulary

(Tracking issues: #141817, #141833)


Patch is 25.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156952.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+36-7)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+67-9)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json (+27-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json (+27-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json (+28-1)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt (+26)
  • (modified) llvm/test/Analysis/IR2Vec/if-else.ll (+1-1)
  • (modified) llvm/test/Analysis/IR2Vec/unreachable.ll (+1-1)
  • (modified) llvm/test/tools/llvm-ir2vec/entities.ll (+27-1)
  • (modified) llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp (+1-1)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+46-1)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index b7b881999241e..d49854e2d06a8 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #define LLVM_ANALYSIS_IR2VEC_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 /// embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
+  // Slot layout:
+  // [0 .. MaxOpcodes-1]                                   => Instruction
+  // opcodes [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1]      =>
+  // Canonicalized types [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) =>
+  // Operands
+  //   Within Operands: first OperandKind entries, followed by compare
+  //   predicates
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
+
   bool Valid = false;
+  static constexpr unsigned NumICmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+  static constexpr unsigned NumFCmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
 
 public:
-  // Slot layout:
-  // [0 .. MaxOpcodes-1]               => Instruction opcodes
-  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
-  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
-
   /// Canonical type IDs supported by IR2Vec Vocabulary
   enum class CanonicalTypeID : unsigned {
     FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
       static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
+  // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+  // empty slots.
+  static constexpr unsigned MaxPredicateKinds =
+      NumICmpPredicates + NumFCmpPredicates;
 
   Vocabulary() = default;
   LLVM_ABI Vocabulary(VocabVector &&Vocab);
 
   LLVM_ABI bool isValid() const;
   LLVM_ABI unsigned getDimension() const;
-  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  /// Total number of entries (opcodes + canonicalized types + operand kinds +
+  /// predicates)
   static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
   /// Function to get vocabulary key for a given Opcode
@@ -229,16 +244,21 @@ class Vocabulary {
   /// Function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
+  /// Function to get vocabulary key for a given predicate
+  LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
   /// Functions to return the slot index or position of a given Opcode, TypeID,
   /// or OperandKind in the vocabulary.
   LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
   LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
   LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+  LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
   LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
   LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+  LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
 
   /// Const Iterator type aliases
   using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
 
 private:
   constexpr static unsigned NumCanonicalEntries =
-      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+  // Base offsets for slot layout to simplify index computation
+  constexpr static unsigned OperandBaseOffset =
+      MaxOpcodes + MaxCanonicalTypeIDs;
+  constexpr static unsigned PredicateBaseOffset =
+      OperandBaseOffset + MaxOperandKinds;
 
   /// String mappings for CanonicalTypeID values
   static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
 
   /// Function to convert TypeID to CanonicalTypeID
   LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+  /// Function to get the predicate enum value for a given index
+  LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 98849fd922843..2981a6ab9382b 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       ArgEmb += Vocab[*Op];
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     // embeddings
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    // Add compare predicate embedding as an additional operand if applicable
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -285,7 +290,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
 unsigned Vocabulary::getSlotIndex(const Value &Op) {
   unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
   assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+  return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+  unsigned PU = static_cast<unsigned>(P);
+  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+  unsigned PredIdx =
+      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+  return PredicateBaseOffset + PredIdx;
 }
 
 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
   return Vocab[getSlotIndex(Arg)];
 }
 
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+  return Vocab[getSlotIndex(P)];
+}
+
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
@@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+  assert(Index < MaxPredicateKinds && "Invalid predicate index");
+  unsigned PredEnumVal =
+      (Index < NumFCmpPredicates)
+          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+             (Index - NumFCmpPredicates));
+  return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+  static SmallString<16> PredNameBuffer;
+  if (Pred < CmpInst::FIRST_ICMP_PREDICATE)
+    PredNameBuffer = "FCMP_";
+  else
+    PredNameBuffer = "ICMP_";
+  PredNameBuffer += CmpInst::getPredicateName(Pred);
+  return PredNameBuffer;
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+  if (Pos < OperandBaseOffset)
     return getVocabKeyForCanonicalTypeID(
         static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
-  return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+  if (Pos < PredicateBaseOffset)
+    return getVocabKeyForOperandKind(
+        static_cast<OperandKind>(Pos - OperandBaseOffset));
+  // Predicates
+  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +412,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
   VocabVector DummyVocab;
   DummyVocab.reserve(NumCanonicalEntries);
   float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operands
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
+  // Create a dummy vocabulary with entries for all opcodes, types, operands
+  // and predicates
+  for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
     DummyVocab.push_back(Embedding(Dim, DummyVal));
     DummyVal += 0.1f;
   }
@@ -517,6 +557,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   }
   Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
                NumericArgEmbeddings.end());
+
+  // Handle Predicates: part of Operands section. We look up predicate keys
+  // in ArgVocab.
+  std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+                                               Embedding(Dim, 0));
+  NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
+  for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+    StringRef VocabKey =
+        Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+    auto It = ArgVocab.find(VocabKey.str());
+    if (It != ArgVocab.end()) {
+      NumericPredEmbeddings[PK] = It->second;
+      continue;
+    }
+    handleMissingEntity(VocabKey.str());
+  }
+  Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+               NumericPredEmbeddings.end());
 }
 
 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84c1541b..ae36ff54686c5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
         "Function": [1, 2],
         "Pointer": [3, 4],
         "Constant": [5, 6],
-        "Variable": [7, 8]
+        "Variable": [7, 8],
+        "FCMP_false": [9, 10],
+        "FCMP_oeq": [11, 12], 
+        "FCMP_ogt": [13, 14], 
+        "FCMP_oge": [15, 16], 
+        "FCMP_olt": [17, 18], 
+        "FCMP_ole": [19, 20], 
+        "FCMP_one": [21, 22], 
+        "FCMP_ord": [23, 24], 
+        "FCMP_uno": [25, 26], 
+        "FCMP_ueq": [27, 28], 
+        "FCMP_ugt": [29, 30], 
+        "FCMP_uge": [31, 32], 
+        "FCMP_ult": [33, 34], 
+        "FCMP_ule": [35, 36], 
+        "FCMP_une": [37, 38], 
+        "FCMP_true": [39, 40], 
+        "ICMP_eq": [41, 42], 
+        "ICMP_ne": [43, 44], 
+        "ICMP_ugt": [45, 46], 
+        "ICMP_uge": [47, 48], 
+        "ICMP_ult": [49, 50], 
+        "ICMP_ule": [51, 52], 
+        "ICMP_sgt": [53, 54], 
+        "ICMP_sge": [55, 56], 
+        "ICMP_slt": [57, 58], 
+        "ICMP_sle": [59, 60]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a217b70c..9003dc73954aa 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
         "Function": [1, 2, 3],
         "Pointer": [4, 5, 6],
         "Constant": [7, 8, 9],
-        "Variable": [10, 11, 12]
+        "Variable": [10, 11, 12],
+        "FCMP_false": [13, 14, 15],
+        "FCMP_oeq": [16, 17, 18],
+        "FCMP_ogt": [19, 20, 21],
+        "FCMP_oge": [22, 23, 24],
+        "FCMP_olt": [25, 26, 27],
+        "FCMP_ole": [28, 29, 30],
+        "FCMP_one": [31, 32, 33],
+        "FCMP_ord": [34, 35, 36],
+        "FCMP_uno": [37, 38, 39],
+        "FCMP_ueq": [40, 41, 42],
+        "FCMP_ugt": [43, 44, 45],
+        "FCMP_uge": [46, 47, 48],
+        "FCMP_ult": [49, 50, 51],
+        "FCMP_ule": [52, 53, 54],
+        "FCMP_une": [55, 56, 57],
+        "FCMP_true": [58, 59, 60],        
+        "ICMP_eq": [61, 62, 63],
+        "ICMP_ne": [64, 65, 66],
+        "ICMP_ugt": [67, 68, 69],
+        "ICMP_uge": [70, 71, 72],
+        "ICMP_ult": [73, 74, 75],
+        "ICMP_ule": [76, 77, 78],
+        "ICMP_sgt": [79, 80, 81],
+        "ICMP_sge": [82, 83, 84],
+        "ICMP_slt": [85, 86, 87],
+        "ICMP_sle": [88, 89, 90]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efee9f6a1..7ef85490b27df 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
         "FPTrunc": [133, 134, 135],
         "FPExt": [136, 137, 138],
         "PtrToInt": [139, 140, 141],
+        "PtrToAddr": [202, 203, 204],
         "IntToPtr": [142, 143, 144],
         "BitCast": [145, 146, 147],
         "AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
         "Function": [0, 0, 0],
         "Pointer": [0, 0, 0],
         "Constant": [0, 0, 0],
-        "Variable": [0, 0, 0]
+        "Variable": [0, 0, 0],
+        "FCMP_false": [0, 0, 0],
+        "FCMP_oeq": [0, 0, 0],
+        "FCMP_ogt": [0, 0, 0],
+        "FCMP_oge": [0, 0, 0],
+        "FCMP_olt": [0, 0, 0],
+        "FCMP_ole": [0, 0, 0],
+        "FCMP_one": [0, 0, 0],
+        "FCMP_ord": [0, 0, 0],
+        "FCMP_uno": [0, 0, 0],
+        "FCMP_ueq": [0, 0, 0],
+        "FCMP_ugt": [0, 0, 0],
+        "FCMP_uge": [0, 0, 0],
+        "FCMP_ult": [0, 0, 0],
+        "FCMP_ule": [0, 0, 0],
+        "FCMP_une": [0, 0, 0],
+        "FCMP_true": [0, 0, 0],
+        "ICMP_eq": [0, 0, 0],
+        "ICMP_ne": [0, 0, 0],
+        "ICMP_ugt": [0, 0, 0],
+        "ICMP_uge": [0, 0, 0],
+        "ICMP_ult": [0, 0, 0],
+        "ICMP_ule": [0, 0, 0],
+        "ICMP_sgt": [1, 1, 1],
+        "ICMP_sge": [0, 0, 0],
+        "ICMP_slt": [0, 0, 0],
+        "ICMP_sle": [0, 0, 0]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.20  0.40 ]
 Key: Pointer:  [ 0.60  0.80 ]
 Key: Constant:  [ 1.00  1.20 ]
 Key: Variable:  [ 1.40  1.60 ]
+Key: FCMP_false:  [ 1.80  2.00 ]
+Key: FCMP_oeq:  [ 2.20  2.40 ]
+Key: FCMP_ogt:  [ 2.60  2.80 ]
+Key: FCMP_oge:  [ 3.00  3.20 ]
+Key: FCMP_olt:  [ 3.40  3.60 ]
+Key: FCMP_ole:  [ 3.80  4.00 ]
+Key: FCMP_one:  [ 4.20  4.40 ]
+Key: FCMP_ord:  [ 4.60  4.80 ]
+Key: FCMP_uno:  [ 5.00  5.20 ]
+Key: FCMP_ueq:  [ 5.40  5.60 ]
+Key: FCMP_ugt:  [ 5.80  6.00 ]
+Key: FCMP_uge:  [ 6.20  6.40 ]
+Key: FCMP_ult:  [ 6.60  6.80 ]
+Key: FCMP_ule:  [ 7.00  7.20 ]
+Key: FCMP_une:  [ 7.40  7.60 ]
+Key: FCMP_true:  [ 7.80  8.00 ]
+Key: ICMP_eq:  [ 8.20  8.40 ]
+Key: ICMP_ne:  [ 8.60  8.80 ]
+Key: ICMP_ugt:  [ 9.00  9.20 ]
+Key: ICMP_uge:  [ 9.40  9.60 ]
+Key: ICMP_ult:  [ 9.80  10.00 ]
+Key: ICMP_ule:  [ 10.20  10.40 ]
+Key: ICMP_sgt:  [ 10.60  10.80 ]
+Key: ICMP_sge:  [ 11.00  11.20 ]
+Key: ICMP_slt:  [ 11.40  11.60 ]
+Key: ICMP_sle:  [ 11.80  12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.50  1.00 ]
 Key: Pointer:  [ 1.50  2.00 ]
 Key: Constant:  [ 2.50  3.00 ]
 Key: Variable:  [ 3.50  4.00 ]
+Key: FCMP_false:  [ 4.50  5.00 ]
+Key: FCMP_oeq:  [ 5.50  6.00 ]
+Key: FCMP_ogt:  [ 6.50  7.00 ]
+Key: FCMP_oge:  [ 7.50  8.00 ]
+Key: FCMP_olt:  [ 8.50  9.00 ]
+Key: FCMP_ole:  [ 9.50  10.00 ]
+Key: FCMP_one:  [ 10.50  11.00 ]
+Key: FCMP_ord:  [ 11.50  12.00 ]
+Key: FCMP_uno:  [ 12.50  13.00 ]
+Key: FCMP_ueq:  [ 13.50  14.00 ]
+Key: FCMP_ugt:  [ 14.50  15.00 ]
+Key: FCMP_uge:  [ 15.50  16.00 ]
+Key: FCMP_ult:  [ 16.50  17.00 ]
+Key: FCMP_ule:  [ 17.50  18.00 ]
+Key: FCMP_une:  [ 18.50  19.00 ]
+Key: FCMP_true:  [ 19.50  20.00 ]
+Key: ICMP_eq:  [ 20.50  21.00 ]
+Key: ICMP_ne:  [ 21.50  22.00 ]
+Key: ICMP_ugt:  [ 22.50  23.00 ]
+Key: ICMP_uge:  [ 23.50  24.00 ]
+Key: ICMP_ult:  [ 24.50  25.00 ]
+Key: ICMP_ule:  [ 25.50  26.00 ]
+Key: ICMP_sgt:  [ 26.50  27.00 ]
+Key: ICMP_sge:  [ 27.50  28.00 ]
+Key: ICMP_slt:  [ 28.50  29.00 ]
+Key: ICMP_sle:  [ 29.50  30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.00  0.00 ]
 Key: Pointer:  [ 0.00  0.00 ]
 Key: Constant:  [ 0.00  0.00 ]
 Key: Variable:  [ 0.00  0.00 ]
+Key: FCMP_false:  [ 0.00  0.00 ]
+Key: FCMP_oeq:  [ 0.00  0.00 ]
+Key: FCMP_ogt:  [ 0.00  0.00 ]
+Key: FCMP_oge:  [ 0.00  0.00 ]
+Key: FCMP_olt:  [ 0.00  0.00 ]
+Key: FCMP_ole:  [ 0.00  0.00 ]
+Key: FCMP_one:  [ 0.00  0.00 ]
+Key: FCMP_ord:  [ 0.00  0.00 ]
+Key: FCMP_uno:  [ 0.00  0.00 ]
+Key: FCMP_ueq:  [ 0.00  0.00 ]
+Key: FCMP_ugt:  [ 0.00  0.00 ]
+Key: FCMP_uge:  [ 0.00  0.00 ]
+Key: FCMP_ult:  [ 0.00  0.00 ]
+Key: FCMP_ule:  [ 0.00  0.00 ]
+Key: FCMP_une:  [ 0.00  0.00 ]
+Key: FCMP_true:  [ 0.00  0.00 ]
+Key: ICMP_eq:  [ 0.00  0.00 ]
+Key: ICMP_ne:  [ 0.00  0.00 ]
+Key: ICMP_ugt:  [ 0.00  0.00 ]
+Key: ICMP_uge:  [ 0.00  0.00 ]
+Key: ICMP_ult:  [ 0.00  0.00 ]
+Key: ICMP_ule:  [ 0.00  0.00 ]
+Key: ICMP_sgt:  [ 0.00  0.00 ]
+Key: ICMP_sge:  [ 0.00  0.00 ]
+Key: ICMP_slt:  [ 0.00  0.00 ]
+Key: ICMP_sle:  [ 0.00  0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
 ; RUN: llvm-ir2vec entities | FileCheck %s
 
-CHECK: 84
+CHECK: 110
 CHECK-NEXT: Ret     0
 CHECK-NEXT: Br      1
 CHECK-NEXT: Switch  2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function        80
 CHECK-NEXT: Pointer ...
[truncated]

@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
// Slot layout:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC each subrange is fixed length. Within a subrange, you have the embedding of whatever comes after => - correct?

Can you clarify this upfront in the comment? One could imagine MaxOpcodes refers to a LLVM IR constant, for example.

Also split the Operands the same way, for consistency.

// Slot layout:
// [0 .. MaxOpcodes-1] => Instruction
// opcodes [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] =>
// Canonicalized types [MaxOpcodes+MaxCanonicalTypeIDs .. end of operands) =>
Copy link
Member

@mtrofin mtrofin Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain what canonicalized means here? (like, in the comment)

@@ -345,18 +364,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}

CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
assert(Index < MaxPredicateKinds && "Invalid predicate index");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so Index is relative to the predicates subrange, i.e. Index == 0 is the first predicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, correct.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding mlgo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants