Skip to content

Commit 3064d53

Browse files
committed
[dynamo][guards] Install dict watchers for recrusive dict tag optimization
ghstack-source-id: 19aa6ea Pull-Request: #159796
1 parent 8d3d1c8 commit 3064d53

File tree

1 file changed

+156
-2
lines changed

1 file changed

+156
-2
lines changed

torch/csrc/dynamo/guards.cpp

Lines changed: 156 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834834

835835
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
836836
static int dict_version_watcher_id;
837+
static int dict_recursive_tag_watcher_id;
837838
static uint64_t global_dict_version_id = 1;
838839
static int dict_version_watch_callback(
839840
PyDict_WatchEvent event,
@@ -1557,6 +1558,37 @@ class GuardManager;
15571558
class RootGuardManager;
15581559
class DictGuardManager;
15591560

1561+
// Global registry used by the *recursive-dict-tag* optimisation.
1562+
//
1563+
// Key : `PyObject*` pointing to a watched `dict`
1564+
// Value : list of `GuardManager*` instances that have recorded that dict
1565+
//
1566+
// Why is this global?
1567+
// -------------------
1568+
// * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
1569+
// All `GuardManager`s therefore share a single watcher callback.
1570+
// * Different guard managers (possibly across different frames) can end up
1571+
// watching the same dictionary pointer. Therefore, we have a list of guard
1572+
// managers for each dict pointer.
1573+
//
1574+
// When is watch registered?
1575+
// * During the recording phase of recursive dict tag matching in GuardManager.
1576+
//
1577+
// When are they watched?
1578+
// * In the dict_recursive_tag_watch_callback function.
1579+
//
1580+
// When are the dict pointers unwatched?
1581+
// * If a dict is mutated or the guard manager deallocates.
1582+
// * Read `unwatch_all_saved_dict_pointers` docstring for more details.
1583+
//
1584+
// Expected size
1585+
// -------------
1586+
// Every compilation frame contributes its tag-safe dicts to this registry, so
1587+
// the container can grow large over the lifetime of the process. That’s
1588+
// acceptable: lookup is by pointer (hash/equals = identity) and each entry
1589+
// stores only lightweight pointers.
1590+
std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_managers;
1591+
15601592
/**
15611593
* Base class for the leaf guard in the GuardManager hierarchy.
15621594
*/
@@ -2625,6 +2657,7 @@ class GuardManager {
26252657

26262658
virtual ~GuardManager() {
26272659
cleanup_tag_safe_entries();
2660+
disable_recursive_dict_tag_optimization();
26282661
}
26292662

26302663
void cleanup_tag_safe_entries() {
@@ -2727,6 +2760,11 @@ class GuardManager {
27272760
_tensor_pointers[value] = tensor_pointers;
27282761
}
27292762

2763+
void disable_recursive_dict_tag_optimization() {
2764+
unwatch_all_saved_dict_pointers();
2765+
_disable_dict_tag_matching = true;
2766+
}
2767+
27302768
public:
27312769
// For cloning
27322770
GuardManager(
@@ -2833,6 +2871,10 @@ class GuardManager {
28332871
}
28342872

28352873
bool check_dict_pointer_tags(PyObject* value) {
2874+
if (_dict_callback_installed) {
2875+
// This means that for 3.12+, there are callbacks watching dict pointers.
2876+
return true;
2877+
}
28362878
for (auto& kv : _dict_pointers[value]) {
28372879
PyObject* dict_pointer = kv.first;
28382880
uint64_t old_tag = kv.second;
@@ -2963,6 +3005,11 @@ class GuardManager {
29633005
throw std::runtime_error(
29643006
"Could not register a callback for recursive dict tag optimization");
29653007
}
3008+
#if IS_PYTHON_3_12_PLUS
3009+
// Ideally we don't need to even register a weakref callback for value.
3010+
// But it does not hurt to be more cautious
3011+
_dict_callback_installed = watch_dict_pointers(value);
3012+
#endif
29663013
}
29673014
}
29683015
if (!result) {
@@ -2979,8 +3026,9 @@ class GuardManager {
29793026
}
29803027
GuardManager* guard_manager = static_cast<GuardManager*>(
29813028
PyCapsule_GetPointer(self_capsule, "GuardManager*"));
2982-
if (guard_manager)
2983-
guard_manager->_disable_dict_tag_matching = true;
3029+
if (guard_manager) {
3030+
guard_manager->disable_recursive_dict_tag_optimization();
3031+
}
29843032
Py_RETURN_NONE;
29853033
}
29863034

@@ -3031,6 +3079,81 @@ class GuardManager {
30313079
return true;
30323080
}
30333081

3082+
bool watch_dict_pointers(PyObject* value) {
3083+
#if IS_PYTHON_3_12_PLUS
3084+
// -----------------------------------------------------------------------------
3085+
// CPython 3.12 dict-watcher integration
3086+
// -----------------------------------------------------------------------------
3087+
//
3088+
// We register a single watcher on all every dictionary pointer recorded by
3089+
// a tag-safe root. The watcher callback fires *once* for any structural
3090+
// change to those dictionaries
3091+
//
3092+
// Fast-path benefit
3093+
// -----------------
3094+
// In steady state we no longer need to iterate over the recorded
3095+
// dictionaries and compare their `ma_version_tag`s (the
3096+
// “are-tags-unchanged” loop that used to dominate the fast-path guard
3097+
// evaluation). The presence of an *active watcher* is itself a guarantee
3098+
// that none of the dicts has mutated; if one **does** mutate, the callback
3099+
// simply flips `_disable_dict_tag_matching = true`, causing the next guard
3100+
// evaluation to skip the recursive-dict-tag optimisation entirely.
3101+
for (auto& kv : _dict_pointers[value]) {
3102+
PyObject* dict_pointer = kv.first;
3103+
int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer);
3104+
if (rc != 0) {
3105+
PyErr_Clear();
3106+
return false;
3107+
}
3108+
dict_to_guard_managers[dict_pointer].push_back(this);
3109+
}
3110+
#endif
3111+
return true;
3112+
}
3113+
3114+
void unwatch_all_saved_dict_pointers() {
3115+
/*
3116+
We may have recorded hundreds/thousands of dict pointers for the recursive
3117+
dict-tag optimisation. If any of those dicts mutates, we want to disable the
3118+
optimisation and then unwatch as many dict pointers as we can.
3119+
3120+
Be careful: the same dict pointer can be recorded by multiple GuardManagers.
3121+
So the flow is:
3122+
3123+
1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer].
3124+
2) If the list for that dict becomes empty, then:
3125+
- PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer)
3126+
- erase the dict_pointer entry from dict_to_guard_managers.
3127+
*/
3128+
#if IS_PYTHON_3_12_PLUS
3129+
if (!_disable_dict_tag_matching) {
3130+
for (auto& value_stashed_pointers : _dict_pointers) {
3131+
auto stashed_pointers = value_stashed_pointers.second;
3132+
3133+
for (auto& stashed_pointer : stashed_pointers) {
3134+
PyObject* dict_pointer = stashed_pointer.first;
3135+
3136+
// Delete the guard manager from the dict_to_guard_managers
3137+
auto it = std::find(
3138+
dict_to_guard_managers[dict_pointer].begin(),
3139+
dict_to_guard_managers[dict_pointer].end(),
3140+
this);
3141+
if (it != dict_to_guard_managers[dict_pointer].end()) {
3142+
dict_to_guard_managers[dict_pointer].erase(it);
3143+
}
3144+
3145+
// Unwatch the dict pointer if this was the last guard manager
3146+
// watching it.
3147+
if (dict_to_guard_managers[dict_pointer].empty()) {
3148+
PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
3149+
dict_to_guard_managers.erase(dict_pointer);
3150+
}
3151+
}
3152+
}
3153+
}
3154+
#endif
3155+
}
3156+
30343157
virtual bool check_nopybind(FrameLocalsMapping* value) {
30353158
return check_nopybind_template(value);
30363159
}
@@ -3270,6 +3393,9 @@ class GuardManager {
32703393
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
32713394
std::vector<WeakEntry> _tag_safe_entries;
32723395

3396+
// 3.12+ related helper
3397+
bool _dict_callback_installed = false;
3398+
32733399
protected:
32743400
// weakref to the type of guarded value
32753401
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4083,27 @@ void add_relational_guard_resetter_to_cloned_root(
39574083
root->add_relational_guard_resetter(std::move(guard));
39584084
}
39594085

4086+
#if IS_PYTHON_3_12_PLUS
4087+
static int dict_recursive_tag_watch_callback(
4088+
PyDict_WatchEvent event,
4089+
PyObject* dict,
4090+
PyObject* key,
4091+
PyObject* new_value) noexcept {
4092+
if (event != PyDict_EVENT_CLONED) {
4093+
auto it = dict_to_guard_managers.find(dict);
4094+
if (it != dict_to_guard_managers.end()) {
4095+
auto guard_managers = it->second;
4096+
for (auto& guard_manager : guard_managers) {
4097+
if (guard_manager) {
4098+
guard_manager->disable_recursive_dict_tag_optimization();
4099+
}
4100+
}
4101+
}
4102+
}
4103+
return 0; // keep watching
4104+
}
4105+
#endif
4106+
39604107
std::unique_ptr<GuardManager> make_guard_manager(
39614108
RootGuardManager* root,
39624109
std::string source,
@@ -7558,6 +7705,13 @@ PyObject* torch_c_dynamo_guards_init() {
75587705
throw std::runtime_error("Failed to install dict_version_watch_callback");
75597706
}
75607707

7708+
dict_recursive_tag_watcher_id =
7709+
PyDict_AddWatcher(dict_recursive_tag_watch_callback);
7710+
if (dict_recursive_tag_watcher_id == -1) {
7711+
throw std::runtime_error(
7712+
"Failed to install dict_recursive_tag_watch_callback");
7713+
}
7714+
75617715
#endif
75627716

75637717
return m;

0 commit comments

Comments
 (0)