Skip to content

Conversation

mtrofin
Copy link
Member

@mtrofin mtrofin commented Aug 27, 2025

We can compute the entry count of branch funnel functions, and potentially avoid them being deemed cold (also, keeping profile information coherent is always good for performance)

Issue #147390

Copy link
Member Author

mtrofin commented Aug 27, 2025

Copy link

github-actions bot commented Aug 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 59546ec to 4fe3e48 Compare August 28, 2025 02:32
@mtrofin mtrofin marked this pull request as ready for review August 28, 2025 02:32
@mtrofin mtrofin requested review from pcc and teresajohnson August 28, 2025 02:33
@llvmbot
Copy link
Member

llvmbot commented Aug 28, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Mircea Trofin (mtrofin)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155657.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+5-1)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+10-2)
  • (modified) llvm/lib/IR/Verifier.cpp (+4-5)
  • (modified) llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp (+40-9)
  • (added) llvm/test/Transforms/WholeProgramDevirt/branch-funnel-profile.ll (+203)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..b8386ddc86ca8 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -180,7 +180,11 @@ inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
 /// info.
 LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I);
 
-LLVM_ABI bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD);
+/// Analogous to setExplicitlyUnknownBranchWeights, but for functions and their
+/// entry counts.
+LLVM_ABI void setExplicitlyUnknownFunctionEntryCount(Function &F);
+
+LLVM_ABI bool isExplicitlyUnknownProfileMetadata(const MDNode &MD);
 LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index d24263f8b3bda..b41256f599096 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -250,7 +250,15 @@ void setExplicitlyUnknownBranchWeights(Instruction &I) {
                   MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
 }
 
-bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) {
+void setExplicitlyUnknownFunctionEntryCount(Function &F) {
+  MDBuilder MDB(F.getContext());
+  F.setMetadata(
+      LLVMContext::MD_prof,
+      MDNode::get(F.getContext(),
+                  MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
+}
+
+bool isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
   if (MD.getNumOperands() != 1)
     return false;
   return MD.getOperand(0).equalsStr(MDProfLabels::UnknownBranchWeightsMarker);
@@ -260,7 +268,7 @@ bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
   auto *MD = I.getMetadata(LLVMContext::MD_prof);
   if (!MD)
     return false;
-  return isExplicitlyUnknownBranchWeightsMetadata(*MD);
+  return isExplicitlyUnknownProfileMetadata(*MD);
 }
 
 void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 9fda08645e118..afeeecb39ddc8 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2526,12 +2526,11 @@ void Verifier::verifyFunctionMetadata(
   for (const auto &Pair : MDs) {
     if (Pair.first == LLVMContext::MD_prof) {
       MDNode *MD = Pair.second;
-      if (isExplicitlyUnknownBranchWeightsMetadata(*MD)) {
-        CheckFailed("'unknown' !prof metadata should appear only on "
-                    "instructions supporting the 'branch_weights' metadata",
-                    MD);
+      // We may have functions that are synthesized by the compiler, e.g. in
+      // WPD, that we can't currently determine the entry count.
+      if (isExplicitlyUnknownProfileMetadata(*MD))
         continue;
-      }
+
       Check(MD->getNumOperands() >= 2,
             "!prof annotations should have no less than 2 operands", MD);
 
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index cb98ed838f5d7..59f59265e83f9 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -60,6 +60,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/Bitcode/BitcodeReader.h"
@@ -84,6 +85,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/ModuleSummaryIndexYAML.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Errc.h"
@@ -97,6 +99,7 @@
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/Transforms/Utils/Evaluator.h"
 #include <algorithm>
+#include <cmath>
 #include <cstddef>
 #include <map>
 #include <set>
@@ -169,6 +172,8 @@ static cl::list<std::string>
                       cl::desc("Prevent function(s) from being devirtualized"),
                       cl::Hidden, cl::CommaSeparated);
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 /// With Clang, a pure virtual class's deleting destructor is emitted as a
 /// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
 /// context of whole program devirtualization, the deleting destructor of a pure
@@ -656,7 +661,7 @@ struct DevirtModule {
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res);
 
-  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
+  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Function &JT,
                               bool &IsExported);
   void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                             VTableSlotInfo &SlotInfo,
@@ -1453,7 +1458,7 @@ void DevirtModule::tryICallBranchFunnel(
 
   FunctionType *FT =
       FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
-  Function *JT;
+  Function *JT = nullptr;
   if (isa<MDString>(Slot.TypeID)) {
     JT = Function::Create(FT, Function::ExternalLinkage,
                           M.getDataLayout().getProgramAddressSpace(),
@@ -1482,13 +1487,19 @@ void DevirtModule::tryICallBranchFunnel(
   ReturnInst::Create(M.getContext(), nullptr, BB);
 
   bool IsExported = false;
-  applyICallBranchFunnel(SlotInfo, JT, IsExported);
+  applyICallBranchFunnel(SlotInfo, *JT, IsExported);
   if (IsExported)
     Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
+
+  if (!JT->getEntryCount().has_value()) {
+    // FIXME: we could pass through thinlto the necessary information.
+    setExplicitlyUnknownFunctionEntryCount(*JT);
+  }
 }
 
 void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
-                                          Constant *JT, bool &IsExported) {
+                                          Function &JT, bool &IsExported) {
+  DenseMap<Function *, double> FunctionEntryCounts;
   auto Apply = [&](CallSiteInfo &CSInfo) {
     if (CSInfo.isExported())
       IsExported = true;
@@ -1517,7 +1528,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
       NumBranchFunnel++;
       if (RemarksEnabled)
         VCallSite.emitRemark("branch-funnel",
-                             JT->stripPointerCasts()->getName(), OREGetter);
+                             JT.stripPointerCasts()->getName(), OREGetter);
 
       // Pass the address of the vtable in the nest register, which is r10 on
       // x86_64.
@@ -1533,11 +1544,26 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
       llvm::append_range(Args, CB.args());
 
       CallBase *NewCS = nullptr;
+      if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
+        auto &F = *CB.getCaller();
+        auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+        auto EC = BFI.getBlockFreq(&F.getEntryBlock());
+        auto CC = F.getEntryCount(/*AllowSynthetic=*/true);
+        double CallCount = 0.0;
+        if (EC.getFrequency() != 0 && CC && CC->getCount() != 0) {
+          double CallFreq =
+              static_cast<double>(
+                  BFI.getBlockFreq(CB.getParent()).getFrequency()) /
+              EC.getFrequency();
+          CallCount = CallFreq * CC->getCount();
+        }
+        FunctionEntryCounts[&JT] += CallCount;
+      }
       if (isa<CallInst>(CB))
-        NewCS = IRB.CreateCall(NewFT, JT, Args);
+        NewCS = IRB.CreateCall(NewFT, &JT, Args);
       else
         NewCS =
-            IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(),
+            IRB.CreateInvoke(NewFT, &JT, cast<InvokeInst>(CB).getNormalDest(),
                              cast<InvokeInst>(CB).getUnwindDest(), Args);
       NewCS->setCallingConv(CB.getCallingConv());
 
@@ -1571,6 +1597,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
   Apply(SlotInfo.CSInfo);
   for (auto &P : SlotInfo.ConstCSInfo)
     Apply(P.second);
+  for (auto &[F, C] : FunctionEntryCounts) {
+    assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
+           "Unexpected entry count for funnel that was freshly synthesized");
+    F->setEntryCount(static_cast<uint64_t>(std::round(C)));
+  }
 }
 
 bool DevirtModule::tryEvaluateFunctionsWithArgs(
@@ -2244,12 +2275,12 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
   if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
     // The type of the function is irrelevant, because it's bitcast at calls
     // anyhow.
-    Constant *JT = cast<Constant>(
+    auto *JT = cast<Function>(
         M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
                               Type::getVoidTy(M.getContext()))
             .getCallee());
     bool IsExported = false;
-    applyICallBranchFunnel(SlotInfo, JT, IsExported);
+    applyICallBranchFunnel(SlotInfo, *JT, IsExported);
     assert(!IsExported);
   }
 }
diff --git a/llvm/test/Transforms/WholeProgramDevirt/branch-funnel-profile.ll b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel-profile.ll
new file mode 100644
index 0000000000000..dd1aa926de8a5
--- /dev/null
+++ b/llvm/test/Transforms/WholeProgramDevirt/branch-funnel-profile.ll
@@ -0,0 +1,203 @@
+; A variant of branch-funnel.ll where we just check that the funnels' entry counts
+; are correctly set.
+;
+; RUN: opt -S -passes=wholeprogramdevirt -whole-program-visibility %s | FileCheck --check-prefixes=RETP %s
+; RUN: sed -e 's,+retpoline,-retpoline,g' %s | opt -S -passes=wholeprogramdevirt -whole-program-visibility | FileCheck --check-prefixes=NORETP %s
+; RUN: opt -passes=wholeprogramdevirt -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=RETP %s
+; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t  -S -o - %s | FileCheck --check-prefixes=O3 %s
+
+; RETP: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
+; RETP: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
+; RETP: define internal void @branch_funnel(ptr nest %0, ...) !prof !10
+; RETP: define internal void @branch_funnel.1(ptr nest %0, ...) !prof !10 
+; RETP: !10 = !{!"function_entry_count", i64 1000}
+; RETP: !11 = !{!"function_entry_count", i64 3000}
+
+; NORETP: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
+; NORETP: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
+; NORETP: define internal void @branch_funnel(ptr nest %0, ...) !prof !11
+; NORETP: define internal void @branch_funnel.1(ptr nest %0, ...) !prof !11
+; NORETP: !11 = !{!"unknown"}
+
+; O3: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !11
+; O3: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !11
+; O3: define internal void @branch_funnel(ptr nest %0, ...) unnamed_addr #5 !prof !10
+; O3: define internal void @branch_funnel.1(ptr nest %0, ...) unnamed_addr #5 !prof !10
+; O3: define hidden void @__typeid_typeid3_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !12
+; O3: define hidden void @__typeid_typeid3_rv_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !12
+; O3: !10 = !{!"function_entry_count", i64 1000}
+; O3: !11 = !{!"function_entry_count", i64 3000}
+; O3: !12 = !{!"unknown"}
+
+target datalayout = "e-p:64:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+@vt1_1 = constant [1 x ptr] [ptr @vf1_1], !type !0
+@vt1_2 = constant [1 x ptr] [ptr @vf1_2], !type !0
+
+declare i32 @vf1_1(ptr %this, i32 %arg)
+declare i32 @vf1_2(ptr %this, i32 %arg)
+
+@vt2_1 = constant [1 x ptr] [ptr @vf2_1], !type !1
+@vt2_2 = constant [1 x ptr] [ptr @vf2_2], !type !1
+@vt2_3 = constant [1 x ptr] [ptr @vf2_3], !type !1
+@vt2_4 = constant [1 x ptr] [ptr @vf2_4], !type !1
+@vt2_5 = constant [1 x ptr] [ptr @vf2_5], !type !1
+@vt2_6 = constant [1 x ptr] [ptr @vf2_6], !type !1
+@vt2_7 = constant [1 x ptr] [ptr @vf2_7], !type !1
+@vt2_8 = constant [1 x ptr] [ptr @vf2_8], !type !1
+@vt2_9 = constant [1 x ptr] [ptr @vf2_9], !type !1
+@vt2_10 = constant [1 x ptr] [ptr @vf2_10], !type !1
+@vt2_11 = constant [1 x ptr] [ptr @vf2_11], !type !1
+
+declare i32 @vf2_1(ptr %this, i32 %arg)
+declare i32 @vf2_2(ptr %this, i32 %arg)
+declare i32 @vf2_3(ptr %this, i32 %arg)
+declare i32 @vf2_4(ptr %this, i32 %arg)
+declare i32 @vf2_5(ptr %this, i32 %arg)
+declare i32 @vf2_6(ptr %this, i32 %arg)
+declare i32 @vf2_7(ptr %this, i32 %arg)
+declare i32 @vf2_8(ptr %this, i32 %arg)
+declare i32 @vf2_9(ptr %this, i32 %arg)
+declare i32 @vf2_10(ptr %this, i32 %arg)
+declare i32 @vf2_11(ptr %this, i32 %arg)
+
+@vt3_1 = constant [1 x ptr] [ptr @vf3_1], !type !2
+@vt3_2 = constant [1 x ptr] [ptr @vf3_2], !type !2
+
+declare i32 @vf3_1(ptr %this, i32 %arg)
+declare i32 @vf3_2(ptr %this, i32 %arg)
+
+@vt4_1 = constant [1 x ptr] [ptr @vf4_1], !type !3
+@vt4_2 = constant [1 x ptr] [ptr @vf4_2], !type !3
+
+declare i32 @vf4_1(ptr %this, i32 %arg)
+declare i32 @vf4_2(ptr %this, i32 %arg)
+
+declare ptr @llvm.load.relative.i32(ptr, i32)
+
+;; These are relative vtables equivalent to the ones above.
+@vt1_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf1_1 to i64), i64 ptrtoint (ptr @vt1_1_rv to i64)) to i32)], !type !5
+@vt1_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf1_2 to i64), i64 ptrtoint (ptr @vt1_2_rv to i64)) to i32)], !type !5
+
+@vt2_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_1 to i64), i64 ptrtoint (ptr @vt2_1_rv to i64)) to i32)], !type !6
+@vt2_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_2 to i64), i64 ptrtoint (ptr @vt2_2_rv to i64)) to i32)], !type !6
+@vt2_3_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_3 to i64), i64 ptrtoint (ptr @vt2_3_rv to i64)) to i32)], !type !6
+@vt2_4_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_4 to i64), i64 ptrtoint (ptr @vt2_4_rv to i64)) to i32)], !type !6
+@vt2_5_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_5 to i64), i64 ptrtoint (ptr @vt2_5_rv to i64)) to i32)], !type !6
+@vt2_6_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_6 to i64), i64 ptrtoint (ptr @vt2_6_rv to i64)) to i32)], !type !6
+@vt2_7_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_7 to i64), i64 ptrtoint (ptr @vt2_7_rv to i64)) to i32)], !type !6
+@vt2_8_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_8 to i64), i64 ptrtoint (ptr @vt2_8_rv to i64)) to i32)], !type !6
+@vt2_9_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_9 to i64), i64 ptrtoint (ptr @vt2_9_rv to i64)) to i32)], !type !6
+@vt2_10_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_10 to i64), i64 ptrtoint (ptr @vt2_10_rv to i64)) to i32)], !type !6
+@vt2_11_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_11 to i64), i64 ptrtoint (ptr @vt2_11_rv to i64)) to i32)], !type !6
+
+@vt3_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf3_1 to i64), i64 ptrtoint (ptr @vt3_1_rv to i64)) to i32)], !type !7
+@vt3_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf3_2 to i64), i64 ptrtoint (ptr @vt3_2_rv to i64)) to i32)], !type !7
+
+@vt4_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf4_1 to i64), i64 ptrtoint (ptr @vt4_1_rv to i64)) to i32)], !type !8
+@vt4_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf4_2 to i64), i64 ptrtoint (ptr @vt4_2_rv to i64)) to i32)], !type !8
+
+
+define i32 @fn1(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr %vtable
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn1_rv(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1_rv")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn2(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr %vtable
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn2_rv(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2_rv")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn3(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !4)
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr %vtable
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn3_rv(ptr %obj) #0 !prof !10 {
+  %vtable = load ptr, ptr %obj
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !9)
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn4(ptr %obj) #0 !prof !10 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr @vt1_1
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn4_cpy(ptr %obj) #0 !prof !10 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
+  call void @llvm.assume(i1 %p)
+  %fptr = load ptr, ptr @vt1_1
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn4_rv(ptr %obj) #0 !prof !10 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+define i32 @fn4_rv_cpy(ptr %obj) #0 !prof !10 {
+  %p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
+  call void @llvm.assume(i1 %p)
+  %fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
+  %result = call i32 %fptr(ptr %obj, i32 1)
+  ret i32 %result
+}
+
+declare i1 @llvm.type.test(ptr, metadata)
+declare void @llvm.assume(i1)
+
+!0 = !{i32 0, !"typeid1"}
+!1 = !{i32 0, !"typeid2"}
+!2 = !{i32 0, !"typeid3"}
+!3 = !{i32 0, !4}
+!4 = distinct !{}
+!5 = !{i32 0, !"typeid1_rv"}
+!6 = !{i32 0, !"typeid2_rv"}
+!7 = !{i32 0, !"typeid3_rv"}
+!8 = !{i32 0, !9}
+!9 = distinct !{}
+!10 = !{!"function_entry_count", i64 1000}
+
+attributes #0 = { "target-features"="+retpoline" }

@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 4fe3e48 to 12d4eeb Compare August 28, 2025 02:35
@mtrofin mtrofin changed the title [WPD] set the function entry count [WPD] set the branch funnel function entry count Aug 28, 2025
Copy link
Contributor

@teresajohnson teresajohnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple comments/questions, but hopefully @pcc can review too since he's more familiar with this transformation.

if (IsExported)
Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;

if (!JT->getEntryCount().has_value()) {
// FIXME: we could pass through thinlto the necessary information.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What info would we need?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling callsites' hotness

@@ -1533,11 +1544,26 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
llvm::append_range(Args, CB.args());

CallBase *NewCS = nullptr;
if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments about the approach used below to compute the entry counts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ptal

@@ -1571,6 +1597,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
Apply(P.second);
for (auto &[F, C] : FunctionEntryCounts) {
assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is false because this is a func that was initially set to have unknown earlier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's set to unknown after this (==applyICallBranchFunnel) is called, if we didn't have module-internal callsites to use.

but to the original question, the functions appearing here are funnels we just generated, so should have no entry count.

@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 12d4eeb to 038e91b Compare August 28, 2025 16:47
@mtrofin mtrofin changed the base branch from main to users/mtrofin/08-28-_profcheck_allow_unknown_function_entry_count August 28, 2025 20:37
@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 038e91b to 7d1a2bf Compare August 28, 2025 20:38
@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 7d1a2bf to ce1f1b2 Compare August 28, 2025 20:55
@mtrofin mtrofin force-pushed the users/mtrofin/08-28-_profcheck_allow_unknown_function_entry_count branch from d52142d to 0388e7f Compare August 28, 2025 20:55
@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from ce1f1b2 to 8b547b3 Compare September 4, 2025 17:30
@mtrofin mtrofin force-pushed the users/mtrofin/08-28-_profcheck_allow_unknown_function_entry_count branch from 0388e7f to e738ffe Compare September 4, 2025 17:30
Base automatically changed from users/mtrofin/08-28-_profcheck_allow_unknown_function_entry_count to main September 4, 2025 20:15
@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 8b547b3 to 28a2d35 Compare September 5, 2025 15:10
@pcc
Copy link
Contributor

pcc commented Sep 5, 2025

Branch funnel functions can't be inlined (see https://reviews.llvm.org/D45116) so I'm not sure that it's worth going to the effort of making the entry count accurate. Could we do something simpler here where we mark the calls with an unknown entry count?

Copy link
Member Author

mtrofin commented Sep 5, 2025

Inlining isn't the only scenario. For example, for ml-guided optimizations, we want to compute estimates for performance without running the program, and (reasonably) accurate entry counts play a role in that. Basically I'd prefer capturing, when possible, this information, rather than not.

There's an argument to be made for future optimizations - can we rely on !prof info being as accurate as possible. I.e. I wouldn't base calculating or not on today's scenarios, and rather keep it grounded on what's possible or not (i.e. if it's not possible, then, well... we did our best 😄 )

@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 28a2d35 to 5676021 Compare September 5, 2025 23:34
Copy link
Contributor

@pcc pcc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough, I guess you could use this to decide whether it goes to .text or .text.hot or something like that but there's not much else you can do with these functions.

@@ -1453,7 +1458,7 @@ void DevirtModule::tryICallBranchFunnel(

FunctionType *FT =
FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
Function *JT;
Function *JT = nullptr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is a drive-by. Right now, no, because JT gets initialized below, but should that change later, initializing at declaration avoids risk of UB.

@mtrofin mtrofin force-pushed the users/mtrofin/08-26-_wpd_set_the_function_entry_count branch from 5676021 to be3646a Compare September 6, 2025 02:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants