@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834
834
835
835
static std::unordered_map<PyObject*, uint64_t > dict_version_map;
836
836
static int dict_version_watcher_id;
837
+ static int dict_recursive_tag_watcher_id;
837
838
static uint64_t global_dict_version_id = 1 ;
838
839
static int dict_version_watch_callback (
839
840
PyDict_WatchEvent event,
@@ -1545,6 +1546,27 @@ class GuardManager;
1545
1546
class RootGuardManager ;
1546
1547
class DictGuardManager ;
1547
1548
1549
+ // Global registry used by the *recursive-dict-tag* optimisation.
1550
+ //
1551
+ // Key : `PyObject*` pointing to a watched `dict`
1552
+ // Value : list of `GuardManager*` instances that have recorded that dict
1553
+ //
1554
+ // Why is this global?
1555
+ // -------------------
1556
+ // * CPython allows only a small, fixed number of dict-watcher IDs (8 in 3.12).
1557
+ // All `GuardManager`s therefore share a single watcher callback.
1558
+ // * A “tag-safe root” in one compilation frame may observe the **same** object
1559
+ // as another frame. Consequently multiple `GuardManager`s can end up
1560
+ // watching the exact same dictionary pointer.
1561
+ //
1562
+ // Expected size
1563
+ // -------------
1564
+ // Every compilation frame contributes its tag-safe dicts to this registry, so
1565
+ // the container can grow large over the lifetime of the process. That’s
1566
+ // acceptable: lookup is by pointer (hash/equals = identity) and each entry
1567
+ // stores only lightweight pointers.
1568
+ std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1569
+
1548
1570
/* *
1549
1571
* Base class for the leaf guard in the GuardManager hierarchy.
1550
1572
*/
@@ -2613,6 +2635,7 @@ class GuardManager {
2613
2635
2614
2636
// Destructor. Runs cleanup_tag_safe_entries() and then
// unwatch_dict_pointers(). The global dict_to_guard_manager registry stores
// raw GuardManager* values that the shared dict-watcher callback dereferences,
// so this manager has to drop out of that registry before its storage goes
// away. Note that unwatch_dict_pointers() does nothing if it already ran or if
// _disable_dict_tag_matching is set, and compiles to an empty body before
// Python 3.12.
virtual ~GuardManager () {
2615
2637
cleanup_tag_safe_entries ();
2638
+ unwatch_dict_pointers ();
2616
2639
}
2617
2640
2618
2641
void cleanup_tag_safe_entries () {
@@ -2715,6 +2738,10 @@ class GuardManager {
2715
2738
_tensor_pointers[value] = tensor_pointers;
2716
2739
}
2717
2740
2741
// Sets _disable_dict_tag_matching so that later guard evaluations skip the
// recursive-dict-tag optimization. Invoked by the shared dict-watcher
// callback when a watched dict reports an event.
// Ordering matters: unwatch_dict_pointers() skips its work once this flag is
// set, so a caller that wants both must call unwatch_dict_pointers() first.
+ void disable_recursive_dict_tag_optimization () {
2742
+ _disable_dict_tag_matching = true ;
2743
+ }
2744
+
2718
2745
public:
2719
2746
// For cloning
2720
2747
GuardManager (
@@ -2821,6 +2848,10 @@ class GuardManager {
2821
2848
}
2822
2849
2823
2850
bool check_dict_pointer_tags (PyObject* value) {
2851
+ if (_dict_callback_installed) {
2852
+ // This means that for 3.12+, there are callbacks watching dict pointers.
2853
+ return true ;
2854
+ }
2824
2855
for (auto & kv : _dict_pointers[value]) {
2825
2856
PyObject* dict_pointer = kv.first ;
2826
2857
uint64_t old_tag = kv.second ;
@@ -2951,6 +2982,11 @@ class GuardManager {
2951
2982
throw std::runtime_error (
2952
2983
" Could not register a callback for recursive dict tag optimization" );
2953
2984
}
2985
+ #if IS_PYTHON_3_12_PLUS
2986
+ // Ideally we don't need to even register a weakref callback for value.
2987
+ // But it does not hurt to be more cautious
2988
+ _dict_callback_installed = watch_dict_pointers (value);
2989
+ #endif
2954
2990
}
2955
2991
}
2956
2992
if (!result) {
@@ -2967,8 +3003,16 @@ class GuardManager {
2967
3003
}
2968
3004
GuardManager* guard_manager = static_cast <GuardManager*>(
2969
3005
PyCapsule_GetPointer (self_capsule, " GuardManager*" ));
2970
- if (guard_manager)
3006
+ if (guard_manager) {
2971
3007
guard_manager->_disable_dict_tag_matching = true ;
3008
+ // When the entry for a given value in _dict_pointers becomes invalid, it
3009
+ // can only be because that value itself has been garbage-collected. At
3010
+ // that moment every dictionary it referenced has also been reclaimed, and
3011
+ // the dict-watch callback triggered during their finalisation has already
3012
+ // invoked unwatch_dict_pointers. In short, by the time we reach this
3013
+ // code the dict pointers have been unwatched automatically, so no further
3014
+ // action is needed.
3015
+ }
2972
3016
Py_RETURN_NONE;
2973
3017
}
2974
3018
@@ -3019,6 +3063,59 @@ class GuardManager {
3019
3063
return true ;
3020
3064
}
3021
3065
3066
+ bool watch_dict_pointers (PyObject* value) {
3067
+ #if IS_PYTHON_3_12_PLUS
3068
+ // -----------------------------------------------------------------------------
3069
+ // CPython 3.12 dict-watcher integration
3070
+ // -----------------------------------------------------------------------------
3071
+ //
3072
+ // We register a single watcher on all every dictionary pointer recorded by
3073
+ // a tag-safe root. The watcher callback fires *once* for any structural
3074
+ // change to those dictionaries
3075
+ //
3076
+ // Fast-path benefit
3077
+ // -----------------
3078
+ // In steady state we no longer need to iterate over the recorded
3079
+ // dictionaries and compare their `ma_version_tag`s (the
3080
+ // “are-tags-unchanged” loop that used to dominate the fast-path guard
3081
+ // evaluation). The presence of an *active watcher* is itself a guarantee
3082
+ // that none of the dicts has mutated; if one **does** mutate, the callback
3083
+ // simply flips `_disable_dict_tag_matching = true`, causing the next guard
3084
+ // evaluation to skip the recursive-dict-tag optimisation entirely.
3085
+ for (auto & kv : _dict_pointers[value]) {
3086
+ PyObject* dict_pointer = kv.first ;
3087
+ int rc = PyDict_Watch (dict_recursive_tag_watcher_id, dict_pointer);
3088
+ if (rc != 0 ) {
3089
+ PyErr_Clear ();
3090
+ return false ;
3091
+ }
3092
+ dict_to_guard_manager[dict_pointer].push_back (this );
3093
+ }
3094
+ #endif
3095
+ return true ;
3096
+ }
3097
+
3098
// Deregisters this manager from the global dict_to_guard_manager registry for
// every dict pointer it recorded, and stops watching a dict once no manager
// is registered for it any more.
//
// The work is skipped when it already ran (_dict_callbacks_unwatched) or when
// _disable_dict_tag_matching is set; in either case the function marks itself
// as done. Before Python 3.12 the body is empty.
void unwatch_dict_pointers() {
#if IS_PYTHON_3_12_PLUS
  if (!_dict_callbacks_unwatched && !_disable_dict_tag_matching) {
    for (const auto& value_stashed_pointers : _dict_pointers) {
      // Iterate the recorded pointers in place; nothing below mutates
      // _dict_pointers, so no copy is needed.
      for (const auto& stashed_pointer : value_stashed_pointers.second) {
        PyObject* dict_pointer = stashed_pointer.first;

        // Use find() rather than operator[] so that dicts which were never
        // registered (e.g. watch_dict_pointers bailed out early) do not get
        // an empty entry created for them.
        auto registry_it = dict_to_guard_manager.find(dict_pointer);
        if (registry_it == dict_to_guard_manager.end()) {
          continue;
        }

        auto& watchers = registry_it->second;
        auto self_it = std::find(watchers.begin(), watchers.end(), this);
        // Erasing end() is undefined behaviour, so only erase on a real hit.
        if (self_it != watchers.end()) {
          watchers.erase(self_it);
        }

        // The watcher ID is shared by all managers: only unwatch the dict
        // when this was the last manager interested in it, otherwise the
        // remaining managers would silently stop receiving notifications.
        if (watchers.empty()) {
          PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
          // Drop the now-empty entry so the registry does not accumulate
          // stale keys.
          dict_to_guard_manager.erase(registry_it);
        }
      }
    }
  }
  _dict_callbacks_unwatched = true;
#endif
}
3118
+
3022
3119
virtual bool check_nopybind (FrameLocalsMapping* value) {
3023
3120
return check_nopybind_template (value);
3024
3121
}
@@ -3258,6 +3355,12 @@ class GuardManager {
3258
3355
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
3259
3356
std::vector<WeakEntry> _tag_safe_entries;
3260
3357
3358
+ // 3.12+ related helper
3359
+ bool _dict_callback_installed = false ;
3360
+ #if IS_PYTHON_3_12_PLUS
3361
+ bool _dict_callbacks_unwatched = false ;
3362
+ #endif
3363
+
3261
3364
protected:
3262
3365
// weakref to the type of guarded value
3263
3366
// protected because it is used for cloning by DictGuardManager
@@ -3945,6 +4048,24 @@ void add_relational_guard_resetter_to_cloned_root(
3945
4048
root->add_relational_guard_resetter (std::move (guard));
3946
4049
}
3947
4050
4051
+ #if IS_PYTHON_3_12_PLUS
4052
+ static int dict_recursive_tag_watch_callback (
4053
+ PyDict_WatchEvent event,
4054
+ PyObject* dict,
4055
+ PyObject* key,
4056
+ PyObject* new_value) noexcept {
4057
+ auto it = dict_to_guard_manager.find (dict);
4058
+ if (it != dict_to_guard_manager.end ()) {
4059
+ auto guard_managers = it->second ;
4060
+ for (auto & guard_manager : guard_managers) {
4061
+ guard_manager->unwatch_dict_pointers ();
4062
+ guard_manager->disable_recursive_dict_tag_optimization ();
4063
+ }
4064
+ }
4065
+ return 0 ; // keep watching
4066
+ }
4067
+ #endif
4068
+
3948
4069
std::unique_ptr<GuardManager> make_guard_manager (
3949
4070
RootGuardManager* root,
3950
4071
std::string source,
@@ -7546,6 +7667,13 @@ PyObject* torch_c_dynamo_guards_init() {
7546
7667
throw std::runtime_error (" Failed to install dict_version_watch_callback" );
7547
7668
}
7548
7669
7670
+ dict_recursive_tag_watcher_id =
7671
+ PyDict_AddWatcher (dict_recursive_tag_watch_callback);
7672
+ if (dict_recursive_tag_watcher_id == -1 ) {
7673
+ throw std::runtime_error (
7674
+ " Failed to install dict_recursive_tag_watch_callback" );
7675
+ }
7676
+
7549
7677
#endif
7550
7678
7551
7679
return m;
0 commit comments