@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834
834
835
835
static std::unordered_map<PyObject*, uint64_t > dict_version_map;
836
836
static int dict_version_watcher_id;
837
+ static int dict_recursive_tag_watcher_id;
837
838
static uint64_t global_dict_version_id = 1 ;
838
839
static int dict_version_watch_callback (
839
840
PyDict_WatchEvent event,
@@ -1557,6 +1558,27 @@ class GuardManager;
1557
1558
class RootGuardManager ;
1558
1559
class DictGuardManager ;
1559
1560
1561
+ // Global registry used by the *recursive-dict-tag* optimisation.
1562
+ //
1563
+ // Key : `PyObject*` pointing to a watched `dict`
1564
+ // Value : list of `GuardManager*` instances that have recorded that dict
1565
+ //
1566
+ // Why is this global?
1567
+ // -------------------
1568
+ // * CPython allows only a small, fixed number of dict-watcher IDs (8 in 3.12).
1569
+ // All `GuardManager`s therefore share a single watcher callback.
1570
+ // * A “tag-safe root” in one compilation frame may observe the **same** object
1571
+ // as another frame. Consequently multiple `GuardManager`s can end up
1572
+ // watching the exact same dictionary pointer.
1573
+ //
1574
+ // Expected size
1575
+ // -------------
1576
+ // Every compilation frame contributes its tag-safe dicts to this registry, so
1577
+ // the container can grow large over the lifetime of the process. That’s
1578
+ // acceptable: lookup is by pointer (hash/equals = identity) and each entry
1579
+ // stores only lightweight pointers.
1580
+ std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1581
+
1560
1582
/* *
1561
1583
* Base class for the leaf guard in the GuardManager hierarchy.
1562
1584
*/
@@ -2625,6 +2647,7 @@ class GuardManager {
2625
2647
2626
2648
// Destructor. Runs cleanup_tag_safe_entries() and then
// unwatch_dict_pointers(); the latter removes `this` from the global
// dict_to_guard_manager registry so the dict-watcher callback cannot reach a
// destroyed manager through it.
virtual ~GuardManager () {
2627
2649
cleanup_tag_safe_entries ();
2650
+ unwatch_dict_pointers ();
2628
2651
}
2629
2652
2630
2653
void cleanup_tag_safe_entries () {
@@ -2727,6 +2750,10 @@ class GuardManager {
2727
2750
_tensor_pointers[value] = tensor_pointers;
2728
2751
}
2729
2752
2753
// Turns off the recursive-dict-tag optimisation for this manager by setting
// `_disable_dict_tag_matching`. It does not deregister any dict watchers
// itself; the dict-watcher callback calls unwatch_dict_pointers() before
// invoking this.
+ void disable_recursive_dict_tag_optimization () {
2754
+ _disable_dict_tag_matching = true ;
2755
+ }
2756
+
2730
2757
public:
2731
2758
// For cloning
2732
2759
GuardManager (
@@ -2833,6 +2860,10 @@ class GuardManager {
2833
2860
}
2834
2861
2835
2862
bool check_dict_pointer_tags (PyObject* value) {
2863
+ if (_dict_callback_installed) {
2864
+ // This means that for 3.12+, there are callbacks watching dict pointers.
2865
+ return true ;
2866
+ }
2836
2867
for (auto & kv : _dict_pointers[value]) {
2837
2868
PyObject* dict_pointer = kv.first ;
2838
2869
uint64_t old_tag = kv.second ;
@@ -2963,6 +2994,11 @@ class GuardManager {
2963
2994
throw std::runtime_error (
2964
2995
" Could not register a callback for recursive dict tag optimization" );
2965
2996
}
2997
+ #if IS_PYTHON_3_12_PLUS
2998
+ // Ideally we don't need to even register a weakref callback for value.
2999
+ // But it does not hurt to be more cautious
3000
+ _dict_callback_installed = watch_dict_pointers (value);
3001
+ #endif
2966
3002
}
2967
3003
}
2968
3004
if (!result) {
@@ -2979,8 +3015,10 @@ class GuardManager {
2979
3015
}
2980
3016
GuardManager* guard_manager = static_cast <GuardManager*>(
2981
3017
PyCapsule_GetPointer (self_capsule, " GuardManager*" ));
2982
- if (guard_manager)
3018
+ if (guard_manager) {
3019
+ guard_manager->unwatch_dict_pointers ();
2983
3020
guard_manager->_disable_dict_tag_matching = true ;
3021
+ }
2984
3022
Py_RETURN_NONE;
2985
3023
}
2986
3024
@@ -3031,6 +3069,61 @@ class GuardManager {
3031
3069
return true ;
3032
3070
}
3033
3071
3072
+ bool watch_dict_pointers (PyObject* value) {
3073
+ #if IS_PYTHON_3_12_PLUS
3074
+ // -----------------------------------------------------------------------------
3075
+ // CPython 3.12 dict-watcher integration
3076
+ // -----------------------------------------------------------------------------
3077
+ //
3078
+ // We register a single watcher on all every dictionary pointer recorded by
3079
+ // a tag-safe root. The watcher callback fires *once* for any structural
3080
+ // change to those dictionaries
3081
+ //
3082
+ // Fast-path benefit
3083
+ // -----------------
3084
+ // In steady state we no longer need to iterate over the recorded
3085
+ // dictionaries and compare their `ma_version_tag`s (the
3086
+ // “are-tags-unchanged” loop that used to dominate the fast-path guard
3087
+ // evaluation). The presence of an *active watcher* is itself a guarantee
3088
+ // that none of the dicts has mutated; if one **does** mutate, the callback
3089
+ // simply flips `_disable_dict_tag_matching = true`, causing the next guard
3090
+ // evaluation to skip the recursive-dict-tag optimisation entirely.
3091
+ for (auto & kv : _dict_pointers[value]) {
3092
+ PyObject* dict_pointer = kv.first ;
3093
+ int rc = PyDict_Watch (dict_recursive_tag_watcher_id, dict_pointer);
3094
+ if (rc != 0 ) {
3095
+ PyErr_Clear ();
3096
+ return false ;
3097
+ }
3098
+ dict_to_guard_manager[dict_pointer].push_back (this );
3099
+ }
3100
+ #endif
3101
+ return true ;
3102
+ }
3103
+
3104
// Removes this manager from the global dict_to_guard_manager registry for
// every dictionary it has recorded, and stops watching a dictionary once no
// manager is registered for it any more. Idempotent: only the first call does
// any work. Compiles to a no-op on Python < 3.12.
//
// Differences from a naive "unwatch everything" loop, and why:
//  * Several GuardManagers can be registered for the same dict. The watcher
//    is therefore only removed from a dict when the last registered manager
//    leaves; otherwise the remaining managers would silently stop receiving
//    mutation callbacks.
//  * Only dicts for which this manager is actually present in the registry
//    are touched, so a manager that never (or only partially) installed its
//    watchers cannot disturb watches owned by others.
//  * Deregistration does not depend on `_disable_dict_tag_matching`. If the
//    flag was already set when this ran, skipping the loop would leave a
//    dangling `GuardManager*` in the registry after destruction.
//  * Registry entries whose list becomes empty are erased so the map does not
//    accumulate dead keys.
void unwatch_dict_pointers() {
#if IS_PYTHON_3_12_PLUS
  if (_dict_callbacks_unwatched) {
    return;
  }
  _dict_callbacks_unwatched = true;

  for (const auto& value_stashed_pointers : _dict_pointers) {
    // Bind by reference: copying the stashed container per value is wasted
    // work.
    const auto& stashed_pointers = value_stashed_pointers.second;

    for (const auto& stashed_pointer : stashed_pointers) {
      PyObject* dict_pointer = stashed_pointer.first;

      // find() instead of operator[] so that a lookup never inserts an entry.
      auto registry_it = dict_to_guard_manager.find(dict_pointer);
      if (registry_it == dict_to_guard_manager.end()) {
        continue;
      }

      auto& watchers = registry_it->second;
      const auto size_before = watchers.size();
      // remove() drops every occurrence, which also covers a manager that
      // registered the same dict more than once.
      watchers.remove(this);
      if (watchers.size() == size_before) {
        // This manager was not registered for the dict; leave it alone.
        continue;
      }

      if (watchers.empty()) {
        // Last observer gone: stop watching and drop the registry entry.
        PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
        dict_to_guard_manager.erase(registry_it);
      }
    }
  }
#endif
}
3126
+
3034
3127
// Virtual entry point taking a FrameLocalsMapping*; passes `value` unchanged
// to check_nopybind_template and returns its result.
virtual bool check_nopybind (FrameLocalsMapping* value) {
3035
3128
return check_nopybind_template (value);
3036
3129
}
@@ -3270,6 +3363,12 @@ class GuardManager {
3270
3363
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
3271
3364
std::vector<WeakEntry> _tag_safe_entries;
3272
3365
3366
+ // 3.12+ related helper
3367
+ bool _dict_callback_installed = false ;
3368
+ #if IS_PYTHON_3_12_PLUS
3369
+ bool _dict_callbacks_unwatched = false ;
3370
+ #endif
3371
+
3273
3372
protected:
3274
3373
// weakref to the type of guarded value
3275
3374
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4056,26 @@ void add_relational_guard_resetter_to_cloned_root(
3957
4056
root->add_relational_guard_resetter (std::move (guard));
3958
4057
}
3959
4058
4059
#if IS_PYTHON_3_12_PLUS
// Dict-watcher callback shared by every GuardManager that uses the
// recursive-dict-tag optimisation.
//
// Looks `dict` up in dict_to_guard_manager and, for each manager registered
// for it, first unwatches that manager's dictionaries and then disables the
// optimisation on it. `event`, `key` and `new_value` are not inspected: every
// event is handled the same way. Always returns 0.
static int dict_recursive_tag_watch_callback(
    PyDict_WatchEvent event,
    PyObject* dict,
    PyObject* key,
    PyObject* new_value) noexcept {
  const auto entry = dict_to_guard_manager.find(dict);
  if (entry == dict_to_guard_manager.end()) {
    return 0;
  }

  // Walk a snapshot of the list: unwatch_dict_pointers() edits the registry
  // lists (including this one) while we iterate.
  const std::list<GuardManager*> snapshot = entry->second;
  for (GuardManager* manager : snapshot) {
    if (manager == nullptr) {
      continue;
    }
    // Order matters: unwatch before flipping the disable flag.
    manager->unwatch_dict_pointers();
    manager->disable_recursive_dict_tag_optimization();
  }
  return 0;
}
#endif
4078
+
3960
4079
std::unique_ptr<GuardManager> make_guard_manager (
3961
4080
RootGuardManager* root,
3962
4081
std::string source,
@@ -7558,6 +7677,13 @@ PyObject* torch_c_dynamo_guards_init() {
7558
7677
throw std::runtime_error (" Failed to install dict_version_watch_callback" );
7559
7678
}
7560
7679
7680
+ dict_recursive_tag_watcher_id =
7681
+ PyDict_AddWatcher (dict_recursive_tag_watch_callback);
7682
+ if (dict_recursive_tag_watcher_id == -1 ) {
7683
+ throw std::runtime_error (
7684
+ " Failed to install dict_recursive_tag_watch_callback" );
7685
+ }
7686
+
7561
7687
#endif
7562
7688
7563
7689
return m;
0 commit comments