Skip to content

Commit ba8467d

Browse files
committed
[dynamo][guards] Install dict watchers for recursive dict tag optimization
ghstack-source-id: 5b2b049 Pull-Request: #159796
1 parent 589ab3f commit ba8467d

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

torch/csrc/dynamo/guards.cpp

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834834

835835
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
836836
static int dict_version_watcher_id;
837+
static int dict_recursive_tag_watcher_id;
837838
static uint64_t global_dict_version_id = 1;
838839
static int dict_version_watch_callback(
839840
PyDict_WatchEvent event,
@@ -1545,6 +1546,27 @@ class GuardManager;
15451546
class RootGuardManager;
15461547
class DictGuardManager;
15471548

1549+
// Global registry used by the *recursive-dict-tag* optimisation.
1550+
//
1551+
// Key : `PyObject*` pointing to a watched `dict`
1552+
// Value : list of `GuardManager*` instances that have recorded that dict
1553+
//
1554+
// Why is this global?
1555+
// -------------------
1556+
// * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
1557+
// All `GuardManager`s therefore share a single watcher callback.
1558+
// * A “tag-safe root” in one compilation frame may observe the **same** object
1559+
// as another frame. Consequently multiple `GuardManager`s can end up
1560+
// watching the exact same dictionary pointer.
1561+
//
1562+
// Expected size
1563+
// -------------
1564+
// Every compilation frame contributes its tag-safe dicts to this registry, so
1565+
// the container can grow large over the lifetime of the process. That’s
1566+
// acceptable: lookup is by pointer (hash/equals = identity) and each entry
1567+
// stores only lightweight pointers.
1568+
std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1569+
15481570
/**
15491571
* Base class for the leaf guard in the GuardManager hierarchy.
15501572
*/
@@ -2613,6 +2635,7 @@ class GuardManager {
26132635

26142636
// Destructor. Runs cleanup_tag_safe_entries() first and then
// unwatch_dict_pointers(), so that this manager is no longer referenced from
// the global dict_to_guard_manager registry once it is gone (the latter is a
// no-op before Python 3.12).
virtual ~GuardManager() {
26152637
cleanup_tag_safe_entries();
2638+
unwatch_dict_pointers();
26162639
}
26172640

26182641
void cleanup_tag_safe_entries() {
@@ -2715,6 +2738,10 @@ class GuardManager {
27152738
_tensor_pointers[value] = tensor_pointers;
27162739
}
27172740

2741+
// Sets _disable_dict_tag_matching. Invoked by the shared dict-watcher callback
// (dict_recursive_tag_watch_callback) for every manager registered on a dict
// that the watcher fired for.
void disable_recursive_dict_tag_optimization() {
2742+
_disable_dict_tag_matching = true;
2743+
}
2744+
27182745
public:
27192746
// For cloning
27202747
GuardManager(
@@ -2821,6 +2848,10 @@ class GuardManager {
28212848
}
28222849

28232850
bool check_dict_pointer_tags(PyObject* value) {
2851+
if (_dict_callback_installed) {
2852+
// This means that for 3.12+, there are callbacks watching dict pointers.
2853+
return true;
2854+
}
28242855
for (auto& kv : _dict_pointers[value]) {
28252856
PyObject* dict_pointer = kv.first;
28262857
uint64_t old_tag = kv.second;
@@ -2951,6 +2982,11 @@ class GuardManager {
29512982
throw std::runtime_error(
29522983
"Could not register a callback for recursive dict tag optimization");
29532984
}
2985+
#if IS_PYTHON_3_12_PLUS
2986+
// Ideally we don't need to even register a weakref callback for value.
2987+
// But it does not hurt to be more cautious
2988+
_dict_callback_installed = watch_dict_pointers(value);
2989+
#endif
29542990
}
29552991
}
29562992
if (!result) {
@@ -2967,8 +3003,16 @@ class GuardManager {
29673003
}
29683004
GuardManager* guard_manager = static_cast<GuardManager*>(
29693005
PyCapsule_GetPointer(self_capsule, "GuardManager*"));
2970-
if (guard_manager)
3006+
if (guard_manager) {
29713007
guard_manager->_disable_dict_tag_matching = true;
3008+
// When the entry for a given value in _dict_pointers becomes invalid, it
3009+
// can only be because that value itself has been garbage-collected. At
3010+
// that moment every dictionary it referenced has also been reclaimed, and
3011+
// the dict-watch callback triggered during their finalisation has already
3012+
// invoked unwatch_dict_pointers. In short, by the time we reach this
3013+
// code the dict pointers have been unwatched automatically, so no further
3014+
// action is needed.
3015+
}
29723016
Py_RETURN_NONE;
29733017
}
29743018

@@ -3019,6 +3063,59 @@ class GuardManager {
30193063
return true;
30203064
}
30213065

3066+
/**
 * CPython 3.12 dict-watcher integration.
 *
 * Places every dict recorded in `_dict_pointers[value]` under the single,
 * process-wide recursive-dict-tag watcher and registers this manager for each
 * of them in `dict_to_guard_manager`.
 *
 * Why: while the watcher is installed, `check_dict_pointer_tags` returns early
 * instead of looping over the recorded dicts and comparing their tags. When
 * the watcher fires for one of the dicts, the callback calls
 * `disable_recursive_dict_tag_optimization()` on every registered manager,
 * which sets `_disable_dict_tag_matching`.
 *
 * @param value object whose recorded dict pointers should be watched.
 * @return true when every dict was put under watch (always true before
 *         Python 3.12, where nothing is done); false as soon as one
 *         `PyDict_Watch` call fails. In that case the Python error is cleared
 *         and the dicts handled before the failure stay watched and
 *         registered.
 */
bool watch_dict_pointers(PyObject* value) {
#if IS_PYTHON_3_12_PLUS
  auto& recorded_dicts = _dict_pointers[value];
  for (auto entry = recorded_dicts.begin(); entry != recorded_dicts.end();
       ++entry) {
    PyObject* watched_dict = entry->first;
    if (PyDict_Watch(dict_recursive_tag_watcher_id, watched_dict) != 0) {
      // Best effort: swallow the Python error and report failure to the caller.
      PyErr_Clear();
      return false;
    }
    dict_to_guard_manager[watched_dict].push_back(this);
  }
#endif
  return true;
}
3097+
3098+
/**
 * Removes this manager from the global `dict_to_guard_manager` registry for
 * every dict recorded in `_dict_pointers`, and stops watching a dict once no
 * manager is registered for it any more.
 *
 * Called from the destructor and from the shared dict-watcher callback. The
 * work is skipped when it has already run (`_dict_callbacks_unwatched`) or
 * when `_disable_dict_tag_matching` is already set. Does nothing before
 * Python 3.12.
 */
void unwatch_dict_pointers() {
#if IS_PYTHON_3_12_PLUS
  if (!_dict_callbacks_unwatched && !_disable_dict_tag_matching) {
    for (const auto& value_stashed_pointers : _dict_pointers) {
      // Bind by reference: nothing here mutates _dict_pointers, so there is
      // no need to copy the per-value container.
      const auto& stashed_pointers = value_stashed_pointers.second;

      for (const auto& stashed_pointer : stashed_pointers) {
        PyObject* dict_pointer = stashed_pointer.first;

        // Use find() rather than operator[] so that dicts which were never
        // registered (e.g. watch_dict_pointers returned early) do not leave
        // empty lists behind in the registry.
        auto registry_it = dict_to_guard_manager.find(dict_pointer);
        if (registry_it == dict_to_guard_manager.end()) {
          continue;
        }

        auto& guard_managers = registry_it->second;
        auto it =
            std::find(guard_managers.begin(), guard_managers.end(), this);
        // Erasing end() is undefined behaviour; only erase a real match.
        if (it != guard_managers.end()) {
          guard_managers.erase(it);
        }

        // The watcher id is shared by all GuardManagers. Unwatching while
        // another manager is still registered for this dict would silently
        // stop that manager from being notified, so only unwatch when the
        // last registration is gone.
        if (guard_managers.empty()) {
          if (PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) !=
              0) {
            // Mirror watch_dict_pointers: do not leave a Python error set.
            PyErr_Clear();
          }
          dict_to_guard_manager.erase(registry_it);
        }
      }
    }
  }
  _dict_callbacks_unwatched = true;
#endif
}
3118+
30223119
// Overload taking a FrameLocalsMapping*; forwards to check_nopybind_template
// and returns its result.
virtual bool check_nopybind(FrameLocalsMapping* value) {
30233120
return check_nopybind_template(value);
30243121
}
@@ -3258,6 +3355,12 @@ class GuardManager {
32583355
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
32593356
std::vector<WeakEntry> _tag_safe_entries;
32603357

3358+
// 3.12+ related helper
3359+
bool _dict_callback_installed = false;
3360+
#if IS_PYTHON_3_12_PLUS
3361+
bool _dict_callbacks_unwatched = false;
3362+
#endif
3363+
32613364
protected:
32623365
// weakref to the type of guarded value
32633366
// protected because it is used for cloning by DictGuardManager
@@ -3945,6 +4048,24 @@ void add_relational_guard_resetter_to_cloned_root(
39454048
root->add_relational_guard_resetter(std::move(guard));
39464049
}
39474050

4051+
#if IS_PYTHON_3_12_PLUS
4052+
/**
 * Shared dict-watcher callback for the recursive-dict-tag optimisation.
 *
 * Looks `dict` up in `dict_to_guard_manager`; for every GuardManager
 * registered there it first unwatches that manager's dicts and then disables
 * its recursive-dict-tag optimisation. `event`, `key` and `new_value` are not
 * inspected: every event is handled the same way.
 *
 * @return always 0.
 */
static int dict_recursive_tag_watch_callback(
    PyDict_WatchEvent event,
    PyObject* dict,
    PyObject* key,
    PyObject* new_value) noexcept {
  const auto registry_entry = dict_to_guard_manager.find(dict);
  if (registry_entry == dict_to_guard_manager.end()) {
    return 0; // keep watching
  }

  // Deliberate copy: unwatch_dict_pointers() edits the registry lists, so we
  // must not iterate the list stored in the registry while calling it.
  const std::list<GuardManager*> affected_managers = registry_entry->second;
  for (GuardManager* manager : affected_managers) {
    manager->unwatch_dict_pointers();
    manager->disable_recursive_dict_tag_optimization();
  }
  return 0; // keep watching
}
4067+
#endif
4068+
39484069
std::unique_ptr<GuardManager> make_guard_manager(
39494070
RootGuardManager* root,
39504071
std::string source,
@@ -7546,6 +7667,13 @@ PyObject* torch_c_dynamo_guards_init() {
75467667
throw std::runtime_error("Failed to install dict_version_watch_callback");
75477668
}
75487669

7670+
dict_recursive_tag_watcher_id =
7671+
PyDict_AddWatcher(dict_recursive_tag_watch_callback);
7672+
if (dict_recursive_tag_watcher_id == -1) {
7673+
throw std::runtime_error(
7674+
"Failed to install dict_recursive_tag_watch_callback");
7675+
}
7676+
75497677
#endif
75507678

75517679
return m;

0 commit comments

Comments
 (0)