@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834
834
835
835
static std::unordered_map<PyObject*, uint64_t > dict_version_map;
836
836
static int dict_version_watcher_id;
837
+ static int dict_recursive_tag_watcher_id;
837
838
static uint64_t global_dict_version_id = 1 ;
838
839
static int dict_version_watch_callback (
839
840
PyDict_WatchEvent event,
@@ -1557,6 +1558,37 @@ class GuardManager;
1557
1558
class RootGuardManager ;
1558
1559
class DictGuardManager ;
1559
1560
1561
+ // Global registry used by the *recursive-dict-tag* optimisation.
1562
+ //
1563
+ // Key : `PyObject*` pointing to a watched `dict`
1564
+ // Value : list of `GuardManager*` instances that have recorded that dict
1565
+ //
1566
+ // Why is this global?
1567
+ // -------------------
1568
+ // * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
1569
+ // All `GuardManager`s therefore share a single watcher callback.
1570
+ // * Different guard managers (possibly across different frames) can end up
1571
+ // watching the same dictionary pointer. Therefore, we have a list of guard
1572
+ // managers for each dict pointer.
1573
+ //
1574
+ // When is watch registered?
1575
+ // * During the recording phase of recursive dict tag matching in GuardManager.
1576
+ //
1577
+ // When are they watched?
1578
+ // * In the dict_recursive_tag_watch_callback function.
1579
+ //
1580
+ // When are the dict pointers unwatched?
1581
+ // * If a dict is mutated or the guard manager deallocates.
1582
+ // * Read `unwatch_all_saved_dict_pointers` docstring for more details.
1583
+ //
1584
+ // Expected size
1585
+ // -------------
1586
+ // Every compilation frame contributes its tag-safe dicts to this registry, so
1587
+ // the container can grow large over the lifetime of the process. That’s
1588
+ // acceptable: lookup is by pointer (hash/equals = identity) and each entry
1589
+ // stores only lightweight pointers.
1590
+ std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_managers;
1591
+
1560
1592
/* *
1561
1593
* Base class for the leaf guard in the GuardManager hierarchy.
1562
1594
*/
@@ -2625,6 +2657,7 @@ class GuardManager {
2625
2657
2626
2658
// Destructor.
//
// Besides releasing the tag-safe entries, the manager must detach itself from
// the recursive-dict-tag machinery: the global dict watcher registry stores raw
// `GuardManager*` pointers, so `this` has to be removed from it before the
// object goes away. disable_recursive_dict_tag_optimization() does that by
// calling unwatch_all_saved_dict_pointers().
virtual ~GuardManager() {
  cleanup_tag_safe_entries();
  disable_recursive_dict_tag_optimization();
}
2629
2662
2630
2663
void cleanup_tag_safe_entries () {
@@ -2727,6 +2760,11 @@ class GuardManager {
2727
2760
_tensor_pointers[value] = tensor_pointers;
2728
2761
}
2729
2762
2763
+ void disable_recursive_dict_tag_optimization () {
2764
+ unwatch_all_saved_dict_pointers ();
2765
+ _disable_dict_tag_matching = true ;
2766
+ }
2767
+
2730
2768
public:
2731
2769
// For cloning
2732
2770
GuardManager (
@@ -2833,6 +2871,10 @@ class GuardManager {
2833
2871
}
2834
2872
2835
2873
bool check_dict_pointer_tags (PyObject* value) {
2874
+ if (_dict_callback_installed) {
2875
+ // This means that for 3.12+, there are callbacks watching dict pointers.
2876
+ return true ;
2877
+ }
2836
2878
for (auto & kv : _dict_pointers[value]) {
2837
2879
PyObject* dict_pointer = kv.first ;
2838
2880
uint64_t old_tag = kv.second ;
@@ -2963,6 +3005,11 @@ class GuardManager {
2963
3005
throw std::runtime_error (
2964
3006
" Could not register a callback for recursive dict tag optimization" );
2965
3007
}
3008
+ #if IS_PYTHON_3_12_PLUS
3009
+ // Ideally we don't need to even register a weakref callback for value.
3010
+ // But it does not hurt to be more cautious
3011
+ _dict_callback_installed = watch_dict_pointers (value);
3012
+ #endif
2966
3013
}
2967
3014
}
2968
3015
if (!result) {
@@ -2979,8 +3026,9 @@ class GuardManager {
2979
3026
}
2980
3027
GuardManager* guard_manager = static_cast <GuardManager*>(
2981
3028
PyCapsule_GetPointer (self_capsule, " GuardManager*" ));
2982
- if (guard_manager)
2983
- guard_manager->_disable_dict_tag_matching = true ;
3029
+ if (guard_manager) {
3030
+ guard_manager->disable_recursive_dict_tag_optimization ();
3031
+ }
2984
3032
Py_RETURN_NONE;
2985
3033
}
2986
3034
@@ -3031,6 +3079,81 @@ class GuardManager {
3031
3079
return true ;
3032
3080
}
3033
3081
3082
+ bool watch_dict_pointers (PyObject* value) {
3083
+ #if IS_PYTHON_3_12_PLUS
3084
+ // -----------------------------------------------------------------------------
3085
+ // CPython 3.12 dict-watcher integration
3086
+ // -----------------------------------------------------------------------------
3087
+ //
3088
+ // We register a single watcher on all every dictionary pointer recorded by
3089
+ // a tag-safe root. The watcher callback fires *once* for any structural
3090
+ // change to those dictionaries
3091
+ //
3092
+ // Fast-path benefit
3093
+ // -----------------
3094
+ // In steady state we no longer need to iterate over the recorded
3095
+ // dictionaries and compare their `ma_version_tag`s (the
3096
+ // “are-tags-unchanged” loop that used to dominate the fast-path guard
3097
+ // evaluation). The presence of an *active watcher* is itself a guarantee
3098
+ // that none of the dicts has mutated; if one **does** mutate, the callback
3099
+ // simply flips `_disable_dict_tag_matching = true`, causing the next guard
3100
+ // evaluation to skip the recursive-dict-tag optimisation entirely.
3101
+ for (auto & kv : _dict_pointers[value]) {
3102
+ PyObject* dict_pointer = kv.first ;
3103
+ int rc = PyDict_Watch (dict_recursive_tag_watcher_id, dict_pointer);
3104
+ if (rc != 0 ) {
3105
+ PyErr_Clear ();
3106
+ return false ;
3107
+ }
3108
+ dict_to_guard_managers[dict_pointer].push_back (this );
3109
+ }
3110
+ #endif
3111
+ return true ;
3112
+ }
3113
+
3114
void unwatch_all_saved_dict_pointers() {
  /*
  We may have recorded hundreds/thousands of dict pointers for the recursive
  dict-tag optimisation. If any of those dicts mutates, we want to disable the
  optimisation and then unwatch as many dict pointers as we can.

  Be careful: the same dict pointer can be recorded by multiple GuardManagers.
  So the flow is:

  1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer].
  2) If the list for that dict becomes empty, then:
      - PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer)
      - erase the dict_pointer entry from dict_to_guard_managers.

  Recorded pointers that have no registry entry are skipped. This manager
  never registered them (or they were already released), so they are not
  passed to PyDict_Unwatch and no empty registry entry is created for them.

  Nothing happens once `_disable_dict_tag_matching` is set: the watchers were
  already released when the optimisation was disabled. On Python < 3.12 this
  function is a no-op.
  */
#if IS_PYTHON_3_12_PLUS
  if (_disable_dict_tag_matching) {
    return;
  }

  for (const auto& value_stashed_pointers : _dict_pointers) {
    for (const auto& stashed_pointer : value_stashed_pointers.second) {
      PyObject* dict_pointer = stashed_pointer.first;

      // One lookup per pointer; `find` (unlike operator[]) never inserts.
      auto entry = dict_to_guard_managers.find(dict_pointer);
      if (entry == dict_to_guard_managers.end()) {
        continue;
      }

      // Drop every occurrence of this guard manager so no raw pointer to it
      // can be left behind in the registry.
      entry->second.remove(this);

      // Unwatch the dict pointer if this was the last guard manager
      // watching it.
      if (entry->second.empty()) {
        PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
        dict_to_guard_managers.erase(entry);
      }
    }
  }

  // No watcher is reporting to this manager any more, so the "watchers are
  // installed" shortcut must not be taken from here on.
  _dict_callback_installed = false;
#endif
}

3156
+
3034
3157
virtual bool check_nopybind (FrameLocalsMapping* value) {
3035
3158
return check_nopybind_template (value);
3036
3159
}
@@ -3270,6 +3393,9 @@ class GuardManager {
3270
3393
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
3271
3394
std::vector<WeakEntry> _tag_safe_entries;
3272
3395
3396
+ // 3.12+ related helper
3397
+ bool _dict_callback_installed = false ;
3398
+
3273
3399
protected:
3274
3400
// weakref to the type of guarded value
3275
3401
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root(
3957
4083
root->add_relational_guard_resetter (std::move (guard));
3958
4084
}
3959
4085
4086
#if IS_PYTHON_3_12_PLUS
// Dict-watcher callback shared by all guard managers that use the
// recursive-dict-tag optimisation.
//
// CPython calls this for every watched dict that is about to change. Every
// event kind invalidates the recorded state of `dict`:
//   * ADDED / MODIFIED / DELETED / CLEARED mutate it directly;
//   * CLONED is reported for the *watched* dict when it was empty and another
//     dict is merged into it (CPython sends this single event instead of
//     per-key ADDED events), so it is a mutation as well;
//   * DEALLOCATED means the recorded pointer is about to become invalid.
// In all cases the optimisation is disabled on every guard manager registered
// for `dict`. `key` and `new_value` are not needed for that decision.
//
// Always returns 0: per the CPython dict-watcher contract a callback returns -1
// only if it sets an exception, and this one never does.
static int dict_recursive_tag_watch_callback(
    PyDict_WatchEvent event,
    PyObject* dict,
    PyObject* key,
    PyObject* new_value) noexcept {
  (void)event;
  (void)key;
  (void)new_value;

  const auto it = dict_to_guard_managers.find(dict);
  if (it == dict_to_guard_managers.end()) {
    return 0;
  }

  // Iterate over a snapshot: disable_recursive_dict_tag_optimization() removes
  // the guard manager from dict_to_guard_managers (and may erase this very
  // entry), which would invalidate `it` and the list it points to.
  const std::list<GuardManager*> guard_managers = it->second;
  for (GuardManager* guard_manager : guard_managers) {
    if (guard_manager != nullptr) {
      guard_manager->disable_recursive_dict_tag_optimization();
    }
  }
  return 0;
}
#endif
4106
+
3960
4107
std::unique_ptr<GuardManager> make_guard_manager (
3961
4108
RootGuardManager* root,
3962
4109
std::string source,
@@ -7558,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() {
7558
7705
throw std::runtime_error (" Failed to install dict_version_watch_callback" );
7559
7706
}
7560
7707
7708
+ dict_recursive_tag_watcher_id =
7709
+ PyDict_AddWatcher (dict_recursive_tag_watch_callback);
7710
+ if (dict_recursive_tag_watcher_id == -1 ) {
7711
+ throw std::runtime_error (
7712
+ " Failed to install dict_recursive_tag_watch_callback" );
7713
+ }
7714
+
7561
7715
#endif
7562
7716
7563
7717
return m;
0 commit comments