Skip to content

Commit 54b4d18

Browse files
committed
[dynamo][guards] Install dict watchers for recursive dict tag optimization
ghstack-source-id: 1e83e40 Pull-Request: #159796
1 parent 87bffca commit 54b4d18

File tree

1 file changed

+127
-1
lines changed

1 file changed

+127
-1
lines changed

torch/csrc/dynamo/guards.cpp

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834834

835835
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
836836
static int dict_version_watcher_id;
837+
static int dict_recursive_tag_watcher_id;
837838
static uint64_t global_dict_version_id = 1;
838839
static int dict_version_watch_callback(
839840
PyDict_WatchEvent event,
@@ -1557,6 +1558,27 @@ class GuardManager;
15571558
class RootGuardManager;
15581559
class DictGuardManager;
15591560

1561+
// Global registry used by the *recursive-dict-tag* optimisation.
1562+
//
1563+
// Key : `PyObject*` pointing to a watched `dict`
1564+
// Value : list of `GuardManager*` instances that have recorded that dict
1565+
//
1566+
// Why is this global?
1567+
// -------------------
1568+
// * CPython allows only a small, fixed number of dict-watcher IDs (8 in 3.12).
1569+
// All `GuardManager`s therefore share a single watcher callback.
1570+
// * A “tag-safe root” in one compilation frame may observe the **same** object
1571+
// as another frame. Consequently multiple `GuardManager`s can end up
1572+
// watching the exact same dictionary pointer.
1573+
//
1574+
// Expected size
1575+
// -------------
1576+
// Every compilation frame contributes its tag-safe dicts to this registry, so
1577+
// the container can grow large over the lifetime of the process. That’s
1578+
// acceptable: lookup is by pointer (hash/equals = identity) and each entry
1579+
// stores only lightweight pointers.
1580+
std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1581+
15601582
/**
15611583
* Base class for the leaf guard in the GuardManager hierarchy.
15621584
*/
@@ -2625,6 +2647,7 @@ class GuardManager {
26252647

26262648
// Destructor. Runs cleanup_tag_safe_entries() first, then
// unwatch_dict_pointers(), which removes `this` from the global
// `dict_to_guard_manager` registry for the dicts this manager recorded (when
// its internal flags allow it), so the dict-watcher callback is not left with
// a pointer to a destroyed manager.
virtual ~GuardManager() {
26272649
cleanup_tag_safe_entries();
2650+
unwatch_dict_pointers();
26282651
}
26292652

26302653
void cleanup_tag_safe_entries() {
@@ -2727,6 +2750,10 @@ class GuardManager {
27272750
_tensor_pointers[value] = tensor_pointers;
27282751
}
27292752

2753+
// Sets `_disable_dict_tag_matching` to true. The shared dict-watcher callback
// calls this (after unwatch_dict_pointers()) when a watched dict changes.
// NOTE(review): unwatch_dict_pointers() skips its cleanup once this flag is
// set, so callers that need the watchers removed must unwatch *before*
// calling this method.
void disable_recursive_dict_tag_optimization() {
2754+
_disable_dict_tag_matching = true;
2755+
}
2756+
27302757
public:
27312758
// For cloning
27322759
GuardManager(
@@ -2833,6 +2860,10 @@ class GuardManager {
28332860
}
28342861

28352862
bool check_dict_pointer_tags(PyObject* value) {
2863+
if (_dict_callback_installed) {
2864+
// This means that for 3.12+, there are callbacks watching dict pointers.
2865+
return true;
2866+
}
28362867
for (auto& kv : _dict_pointers[value]) {
28372868
PyObject* dict_pointer = kv.first;
28382869
uint64_t old_tag = kv.second;
@@ -2963,6 +2994,11 @@ class GuardManager {
29632994
throw std::runtime_error(
29642995
"Could not register a callback for recursive dict tag optimization");
29652996
}
2997+
#if IS_PYTHON_3_12_PLUS
2998+
// Ideally we don't need to even register a weakref callback for value.
2999+
// But it does not hurt to be more cautious
3000+
_dict_callback_installed = watch_dict_pointers(value);
3001+
#endif
29663002
}
29673003
}
29683004
if (!result) {
@@ -2979,8 +3015,10 @@ class GuardManager {
29793015
}
29803016
GuardManager* guard_manager = static_cast<GuardManager*>(
29813017
PyCapsule_GetPointer(self_capsule, "GuardManager*"));
2982-
if (guard_manager)
3018+
if (guard_manager) {
3019+
guard_manager->unwatch_dict_pointers();
29833020
guard_manager->_disable_dict_tag_matching = true;
3021+
}
29843022
Py_RETURN_NONE;
29853023
}
29863024

@@ -3031,6 +3069,61 @@ class GuardManager {
30313069
return true;
30323070
}
30333071

3072+
// Registers every dict recorded for `value` with the shared
// recursive-dict-tag watcher (only compiled in when IS_PYTHON_3_12_PLUS).
//
// What it does
// ------------
// For each dict pointer stashed in `_dict_pointers[value]`, call
// `PyDict_Watch` with `dict_recursive_tag_watcher_id`, then append `this` to
// that dict's list in the global `dict_to_guard_manager` registry so the
// watcher callback can find this manager again.
//
// Why
// ---
// One watcher is registered on every dictionary pointer recorded by a
// tag-safe root. While the watcher is active, guard evaluation no longer has
// to walk the recorded dicts and compare their version tags: an active
// watcher is itself the guarantee that none of them has mutated. When one
// does mutate, the callback flips `_disable_dict_tag_matching`, so the next
// guard evaluation skips the recursive-dict-tag optimisation entirely.
//
// Returns
// -------
// false as soon as a `PyDict_Watch` call fails (the pending Python error is
// cleared; dicts processed before the failure stay registered); true
// otherwise, including when IS_PYTHON_3_12_PLUS is false and nothing is done.
bool watch_dict_pointers(PyObject* value) {
#if IS_PYTHON_3_12_PLUS
  auto& recorded_dicts = _dict_pointers[value];
  for (auto& recorded : recorded_dicts) {
    PyObject* const watched_dict = recorded.first;
    if (PyDict_Watch(dict_recursive_tag_watcher_id, watched_dict) != 0) {
      // Best effort: swallow the Python error and let the caller fall back
      // to the non-watcher path.
      PyErr_Clear();
      return false;
    }
    dict_to_guard_manager[watched_dict].push_back(this);
  }
#endif
  return true;
}
3103+
3104+
// Detaches this GuardManager from the shared recursive-dict-tag watcher
// (only compiled in when IS_PYTHON_3_12_PLUS).
//
// For every dict pointer stashed in `_dict_pointers`, one occurrence of
// `this` is removed from that dict's list in `dict_to_guard_manager`.
//
// The CPython watcher itself is only removed from a dict once *no* manager is
// left in its list. A dict watcher is a per-dict on/off bit, not a reference
// count, and several GuardManagers may be registered on the same dict;
// calling `PyDict_Unwatch` while another manager still relies on the watcher
// would silently stop its notifications even though it keeps trusting the
// watcher on its fast path. Empty registry entries are erased so the global
// map does not accumulate dead keys.
//
// The work is skipped when it was already done (`_dict_callbacks_unwatched`)
// or when `_disable_dict_tag_matching` is already set, so callers must invoke
// this *before* disabling the optimisation. `_dict_callbacks_unwatched` is set
// on every call, making later calls no-ops.
void unwatch_dict_pointers() {
#if IS_PYTHON_3_12_PLUS
  if (!_dict_callbacks_unwatched && !_disable_dict_tag_matching) {
    for (const auto& value_stashed_pointers : _dict_pointers) {
      // Bind by reference: copying the stashed container per value is
      // unnecessary work.
      const auto& stashed_pointers = value_stashed_pointers.second;

      for (const auto& stashed_pointer : stashed_pointers) {
        PyObject* dict_pointer = stashed_pointer.first;

        // find() instead of operator[] so that dicts which were never
        // registered do not get an empty entry inserted into the registry.
        auto registry_it = dict_to_guard_manager.find(dict_pointer);
        if (registry_it == dict_to_guard_manager.end()) {
          continue;
        }

        auto& watchers = registry_it->second;
        auto self_it = std::find(watchers.begin(), watchers.end(), this);
        if (self_it != watchers.end()) {
          watchers.erase(self_it);
        }

        if (watchers.empty()) {
          // Last interested manager is gone: now it is safe to drop the
          // CPython watcher and the registry entry.
          if (PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) !=
              0) {
            // This runs from destructors and from the watcher callback;
            // neither may leave a Python exception pending.
            PyErr_Clear();
          }
          dict_to_guard_manager.erase(registry_it);
        }
      }
    }
  }
  _dict_callbacks_unwatched = true;
#endif
}
3126+
30343127
// Overload taking a FrameLocalsMapping: forwards `value` to
// check_nopybind_template() and returns its result. Declared virtual, so a
// derived class may supply its own version.
virtual bool check_nopybind(FrameLocalsMapping* value) {
30353128
return check_nopybind_template(value);
30363129
}
@@ -3270,6 +3363,12 @@ class GuardManager {
32703363
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
32713364
std::vector<WeakEntry> _tag_safe_entries;
32723365

3366+
// 3.12+ related helper
3367+
bool _dict_callback_installed = false;
3368+
#if IS_PYTHON_3_12_PLUS
3369+
bool _dict_callbacks_unwatched = false;
3370+
#endif
3371+
32733372
protected:
32743373
// weakref to the type of guarded value
32753374
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4056,26 @@ void add_relational_guard_resetter_to_cloned_root(
39574056
root->add_relational_guard_resetter(std::move(guard));
39584057
}
39594058

4059+
#if IS_PYTHON_3_12_PLUS
4060+
// Shared dict-watcher callback for the recursive-dict-tag optimisation.
//
// Looks `dict` up in the global `dict_to_guard_manager` registry and, for
// every GuardManager registered on it, first detaches that manager from all
// of its watched dicts and then disables its recursive-dict-tag optimisation.
// `event`, `key` and `new_value` are not inspected: any notification for a
// registered dict is handled the same way.
//
// Always returns 0.
static int dict_recursive_tag_watch_callback(
    PyDict_WatchEvent event,
    PyObject* dict,
    PyObject* key,
    PyObject* new_value) noexcept {
  const auto registry_entry = dict_to_guard_manager.find(dict);
  if (registry_entry == dict_to_guard_manager.end()) {
    return 0; // keep watching
  }

  // Work on a copy on purpose: unwatch_dict_pointers() edits the registry
  // (including this very list), which would invalidate a live iteration.
  const auto affected_managers = registry_entry->second;
  for (GuardManager* manager : affected_managers) {
    if (manager == nullptr) {
      continue;
    }
    // Order matters: unwatch_dict_pointers() skips its cleanup once
    // `_disable_dict_tag_matching` has been set.
    manager->unwatch_dict_pointers();
    manager->disable_recursive_dict_tag_optimization();
  }
  return 0; // keep watching
}
4077+
#endif
4078+
39604079
std::unique_ptr<GuardManager> make_guard_manager(
39614080
RootGuardManager* root,
39624081
std::string source,
@@ -7558,6 +7677,13 @@ PyObject* torch_c_dynamo_guards_init() {
75587677
throw std::runtime_error("Failed to install dict_version_watch_callback");
75597678
}
75607679

7680+
dict_recursive_tag_watcher_id =
7681+
PyDict_AddWatcher(dict_recursive_tag_watch_callback);
7682+
if (dict_recursive_tag_watcher_id == -1) {
7683+
throw std::runtime_error(
7684+
"Failed to install dict_recursive_tag_watch_callback");
7685+
}
7686+
75617687
#endif
75627688

75637689
return m;

0 commit comments

Comments
 (0)