@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834
834
835
835
static std::unordered_map<PyObject*, uint64_t > dict_version_map;
836
836
static int dict_version_watcher_id;
837
+ static int dict_recursive_tag_watcher_id;
837
838
static uint64_t global_dict_version_id = 1 ;
838
839
static int dict_version_watch_callback (
839
840
PyDict_WatchEvent event,
@@ -1557,6 +1558,27 @@ class GuardManager;
1557
1558
class RootGuardManager ;
1558
1559
class DictGuardManager ;
1559
1560
1561
+ // Global registry used by the *recursive-dict-tag* optimisation.
1562
+ //
1563
+ // Key : `PyObject*` pointing to a watched `dict`
1564
+ // Value : list of `GuardManager*` instances that have recorded that dict
1565
+ //
1566
+ // Why is this global?
1567
+ // -------------------
1568
+ // * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
1569
+ // All `GuardManager`s therefore share a single watcher callback.
1570
+ // * A “tag-safe root” in one compilation frame may observe the **same** object
1571
+ // as another frame. Consequently multiple `GuardManager`s can end up
1572
+ // watching the exact same dictionary pointer.
1573
+ //
1574
+ // Expected size
1575
+ // -------------
1576
+ // Every compilation frame contributes its tag-safe dicts to this registry, so
1577
+ // the container can grow large over the lifetime of the process. That’s
1578
+ // acceptable: lookup is by pointer (hash/equals = identity) and each entry
1579
+ // stores only lightweight pointers.
1580
+ std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1581
+
1560
1582
/* *
1561
1583
* Base class for the leaf guard in the GuardManager hierarchy.
1562
1584
*/
@@ -2625,6 +2647,7 @@ class GuardManager {
2625
2647
2626
2648
virtual ~GuardManager () {
2627
2649
cleanup_tag_safe_entries ();
2650
+ unwatch_dict_pointers ();
2628
2651
}
2629
2652
2630
2653
void cleanup_tag_safe_entries () {
@@ -2727,6 +2750,10 @@ class GuardManager {
2727
2750
_tensor_pointers[value] = tensor_pointers;
2728
2751
}
2729
2752
2753
+ void disable_recursive_dict_tag_optimization () {
2754
+ _disable_dict_tag_matching = true ;
2755
+ }
2756
+
2730
2757
public:
2731
2758
// For cloning
2732
2759
GuardManager (
@@ -2833,6 +2860,10 @@ class GuardManager {
2833
2860
}
2834
2861
2835
2862
bool check_dict_pointer_tags (PyObject* value) {
2863
+ if (_dict_callback_installed) {
2864
+ // This means that for 3.12+, there are callbacks watching dict pointers.
2865
+ return true ;
2866
+ }
2836
2867
for (auto & kv : _dict_pointers[value]) {
2837
2868
PyObject* dict_pointer = kv.first ;
2838
2869
uint64_t old_tag = kv.second ;
@@ -2963,6 +2994,11 @@ class GuardManager {
2963
2994
throw std::runtime_error (
2964
2995
" Could not register a callback for recursive dict tag optimization" );
2965
2996
}
2997
+ #if IS_PYTHON_3_12_PLUS
2998
+ // Ideally we don't need to even register a weakref callback for value.
2999
+ // But it does not hurt to be more cautious
3000
+ _dict_callback_installed = watch_dict_pointers (value);
3001
+ #endif
2966
3002
}
2967
3003
}
2968
3004
if (!result) {
@@ -2979,8 +3015,16 @@ class GuardManager {
2979
3015
}
2980
3016
GuardManager* guard_manager = static_cast <GuardManager*>(
2981
3017
PyCapsule_GetPointer (self_capsule, " GuardManager*" ));
2982
- if (guard_manager)
3018
+ if (guard_manager) {
2983
3019
guard_manager->_disable_dict_tag_matching = true ;
3020
+ // When the entry for a given value in _dict_pointers becomes invalid, it
3021
+ // can only be because that value itself has been garbage-collected. At
3022
+ // that moment every dictionary it referenced has also been reclaimed, and
3023
+ // the dict-watch callback triggered during their finalisation has already
3024
+ // invoked unwatch_dict_pointers. In short, by the time we reach this
3025
+ // code the dict pointers have been unwatched automatically, so no further
3026
+ // action is needed.
3027
+ }
2984
3028
Py_RETURN_NONE;
2985
3029
}
2986
3030
@@ -3031,6 +3075,59 @@ class GuardManager {
3031
3075
return true ;
3032
3076
}
3033
3077
3078
+ bool watch_dict_pointers (PyObject* value) {
3079
+ #if IS_PYTHON_3_12_PLUS
3080
+ // -----------------------------------------------------------------------------
3081
+ // CPython 3.12 dict-watcher integration
3082
+ // -----------------------------------------------------------------------------
3083
+ //
3084
+ // We register a single watcher on all every dictionary pointer recorded by
3085
+ // a tag-safe root. The watcher callback fires *once* for any structural
3086
+ // change to those dictionaries
3087
+ //
3088
+ // Fast-path benefit
3089
+ // -----------------
3090
+ // In steady state we no longer need to iterate over the recorded
3091
+ // dictionaries and compare their `ma_version_tag`s (the
3092
+ // “are-tags-unchanged” loop that used to dominate the fast-path guard
3093
+ // evaluation). The presence of an *active watcher* is itself a guarantee
3094
+ // that none of the dicts has mutated; if one **does** mutate, the callback
3095
+ // simply flips `_disable_dict_tag_matching = true`, causing the next guard
3096
+ // evaluation to skip the recursive-dict-tag optimisation entirely.
3097
+ for (auto & kv : _dict_pointers[value]) {
3098
+ PyObject* dict_pointer = kv.first ;
3099
+ int rc = PyDict_Watch (dict_recursive_tag_watcher_id, dict_pointer);
3100
+ if (rc != 0 ) {
3101
+ PyErr_Clear ();
3102
+ return false ;
3103
+ }
3104
+ dict_to_guard_manager[dict_pointer].push_back (this );
3105
+ }
3106
+ #endif
3107
+ return true ;
3108
+ }
3109
+
3110
+ void unwatch_dict_pointers () {
3111
+ #if IS_PYTHON_3_12_PLUS
3112
+ if (!_dict_callbacks_unwatched && !_disable_dict_tag_matching) {
3113
+ for (auto & value_stashed_pointers : _dict_pointers) {
3114
+ auto stashed_pointers = value_stashed_pointers.second ;
3115
+
3116
+ for (auto & stashed_pointer : stashed_pointers) {
3117
+ PyObject* dict_pointer = stashed_pointer.first ;
3118
+ PyDict_Unwatch (dict_recursive_tag_watcher_id, dict_pointer);
3119
+ auto it = std::find (
3120
+ dict_to_guard_manager[dict_pointer].begin (),
3121
+ dict_to_guard_manager[dict_pointer].end (),
3122
+ this );
3123
+ dict_to_guard_manager[dict_pointer].erase (it);
3124
+ }
3125
+ }
3126
+ }
3127
+ _dict_callbacks_unwatched = true ;
3128
+ #endif
3129
+ }
3130
+
3034
3131
virtual bool check_nopybind (FrameLocalsMapping* value) {
3035
3132
return check_nopybind_template (value);
3036
3133
}
@@ -3270,6 +3367,12 @@ class GuardManager {
3270
3367
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
3271
3368
std::vector<WeakEntry> _tag_safe_entries;
3272
3369
3370
+ // 3.12+ related helper
3371
+ bool _dict_callback_installed = false ;
3372
+ #if IS_PYTHON_3_12_PLUS
3373
+ bool _dict_callbacks_unwatched = false ;
3374
+ #endif
3375
+
3273
3376
protected:
3274
3377
// weakref to the type of guarded value
3275
3378
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4060,24 @@ void add_relational_guard_resetter_to_cloned_root(
3957
4060
root->add_relational_guard_resetter (std::move (guard));
3958
4061
}
3959
4062
4063
+ #if IS_PYTHON_3_12_PLUS
4064
+ static int dict_recursive_tag_watch_callback (
4065
+ PyDict_WatchEvent event,
4066
+ PyObject* dict,
4067
+ PyObject* key,
4068
+ PyObject* new_value) noexcept {
4069
+ auto it = dict_to_guard_manager.find (dict);
4070
+ if (it != dict_to_guard_manager.end ()) {
4071
+ auto guard_managers = it->second ;
4072
+ for (auto & guard_manager : guard_managers) {
4073
+ guard_manager->unwatch_dict_pointers ();
4074
+ guard_manager->disable_recursive_dict_tag_optimization ();
4075
+ }
4076
+ }
4077
+ return 0 ; // keep watching
4078
+ }
4079
+ #endif
4080
+
3960
4081
std::unique_ptr<GuardManager> make_guard_manager (
3961
4082
RootGuardManager* root,
3962
4083
std::string source,
@@ -7558,6 +7679,13 @@ PyObject* torch_c_dynamo_guards_init() {
7558
7679
throw std::runtime_error (" Failed to install dict_version_watch_callback" );
7559
7680
}
7560
7681
7682
+ dict_recursive_tag_watcher_id =
7683
+ PyDict_AddWatcher (dict_recursive_tag_watch_callback);
7684
+ if (dict_recursive_tag_watcher_id == -1 ) {
7685
+ throw std::runtime_error (
7686
+ " Failed to install dict_recursive_tag_watch_callback" );
7687
+ }
7688
+
7561
7689
#endif
7562
7690
7563
7691
return m;
0 commit comments