Skip to content

Commit 73d3d4a

Browse files
committed
[dynamo][guards] Install dict watchers for recrusive dict tag optimization
ghstack-source-id: e95c47f Pull-Request: #159796
1 parent 48da9a5 commit 73d3d4a

File tree

1 file changed

+129
-1
lines changed

1 file changed

+129
-1
lines changed

torch/csrc/dynamo/guards.cpp

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
834834

835835
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
836836
static int dict_version_watcher_id;
837+
static int dict_recursive_tag_watcher_id;
837838
static uint64_t global_dict_version_id = 1;
838839
static int dict_version_watch_callback(
839840
PyDict_WatchEvent event,
@@ -1557,6 +1558,27 @@ class GuardManager;
15571558
class RootGuardManager;
15581559
class DictGuardManager;
15591560

1561+
// Global registry used by the *recursive-dict-tag* optimisation.
1562+
//
1563+
// Key : `PyObject*` pointing to a watched `dict`
1564+
// Value : list of `GuardManager*` instances that have recorded that dict
1565+
//
1566+
// Why is this global?
1567+
// -------------------
1568+
// * CPython allows only a small, fixed number of dict-watcher IDs (≈64).
1569+
// All `GuardManager`s therefore share a single watcher callback.
1570+
// * A “tag-safe root” in one compilation frame may observe the **same** object
1571+
// as another frame. Consequently multiple `GuardManager`s can end up
1572+
// watching the exact same dictionary pointer.
1573+
//
1574+
// Expected size
1575+
// -------------
1576+
// Every compilation frame contributes its tag-safe dicts to this registry, so
1577+
// the container can grow large over the lifetime of the process. That’s
1578+
// acceptable: lookup is by pointer (hash/equals = identity) and each entry
1579+
// stores only lightweight pointers.
1580+
std::unordered_map<PyObject*, std::list<GuardManager*>> dict_to_guard_manager;
1581+
15601582
/**
15611583
* Base class for the leaf guard in the GuardManager hierarchy.
15621584
*/
@@ -2625,6 +2647,7 @@ class GuardManager {
26252647

26262648
virtual ~GuardManager() {
26272649
cleanup_tag_safe_entries();
2650+
unwatch_dict_pointers();
26282651
}
26292652

26302653
void cleanup_tag_safe_entries() {
@@ -2727,6 +2750,10 @@ class GuardManager {
27272750
_tensor_pointers[value] = tensor_pointers;
27282751
}
27292752

2753+
void disable_recursive_dict_tag_optimization() {
2754+
_disable_dict_tag_matching = true;
2755+
}
2756+
27302757
public:
27312758
// For cloning
27322759
GuardManager(
@@ -2833,6 +2860,10 @@ class GuardManager {
28332860
}
28342861

28352862
bool check_dict_pointer_tags(PyObject* value) {
2863+
if (_dict_callback_installed) {
2864+
// This means that for 3.12+, there are callbacks watching dict pointers.
2865+
return true;
2866+
}
28362867
for (auto& kv : _dict_pointers[value]) {
28372868
PyObject* dict_pointer = kv.first;
28382869
uint64_t old_tag = kv.second;
@@ -2963,6 +2994,11 @@ class GuardManager {
29632994
throw std::runtime_error(
29642995
"Could not register a callback for recursive dict tag optimization");
29652996
}
2997+
#if IS_PYTHON_3_12_PLUS
2998+
// Ideally we don't need to even register a weakref callback for value.
2999+
// But it does not hurt to be more cautious
3000+
_dict_callback_installed = watch_dict_pointers(value);
3001+
#endif
29663002
}
29673003
}
29683004
if (!result) {
@@ -2979,8 +3015,16 @@ class GuardManager {
29793015
}
29803016
GuardManager* guard_manager = static_cast<GuardManager*>(
29813017
PyCapsule_GetPointer(self_capsule, "GuardManager*"));
2982-
if (guard_manager)
3018+
if (guard_manager) {
29833019
guard_manager->_disable_dict_tag_matching = true;
3020+
// When the entry for a given value in _dict_pointers becomes invalid, it
3021+
// can only be because that value itself has been garbage-collected. At
3022+
// that moment every dictionary it referenced has also been reclaimed, and
3023+
// the dict-watch callback triggered during their finalisation has already
3024+
// invoked unwatch_dict_pointers. In short, by the time we reach this
3025+
// code the dict pointers have been unwatched automatically, so no further
3026+
// action is needed.
3027+
}
29843028
Py_RETURN_NONE;
29853029
}
29863030

@@ -3031,6 +3075,59 @@ class GuardManager {
30313075
return true;
30323076
}
30333077

3078+
bool watch_dict_pointers(PyObject* value) {
3079+
#if IS_PYTHON_3_12_PLUS
3080+
// -----------------------------------------------------------------------------
3081+
// CPython 3.12 dict-watcher integration
3082+
// -----------------------------------------------------------------------------
3083+
//
3084+
// We register a single watcher on all every dictionary pointer recorded by
3085+
// a tag-safe root. The watcher callback fires *once* for any structural
3086+
// change to those dictionaries
3087+
//
3088+
// Fast-path benefit
3089+
// -----------------
3090+
// In steady state we no longer need to iterate over the recorded
3091+
// dictionaries and compare their `ma_version_tag`s (the
3092+
// “are-tags-unchanged” loop that used to dominate the fast-path guard
3093+
// evaluation). The presence of an *active watcher* is itself a guarantee
3094+
// that none of the dicts has mutated; if one **does** mutate, the callback
3095+
// simply flips `_disable_dict_tag_matching = true`, causing the next guard
3096+
// evaluation to skip the recursive-dict-tag optimisation entirely.
3097+
for (auto& kv : _dict_pointers[value]) {
3098+
PyObject* dict_pointer = kv.first;
3099+
int rc = PyDict_Watch(dict_recursive_tag_watcher_id, dict_pointer);
3100+
if (rc != 0) {
3101+
PyErr_Clear();
3102+
return false;
3103+
}
3104+
dict_to_guard_manager[dict_pointer].push_back(this);
3105+
}
3106+
#endif
3107+
return true;
3108+
}
3109+
3110+
void unwatch_dict_pointers() {
3111+
#if IS_PYTHON_3_12_PLUS
3112+
if (!_dict_callbacks_unwatched && !_disable_dict_tag_matching) {
3113+
for (auto& value_stashed_pointers : _dict_pointers) {
3114+
auto stashed_pointers = value_stashed_pointers.second;
3115+
3116+
for (auto& stashed_pointer : stashed_pointers) {
3117+
PyObject* dict_pointer = stashed_pointer.first;
3118+
PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer);
3119+
auto it = std::find(
3120+
dict_to_guard_manager[dict_pointer].begin(),
3121+
dict_to_guard_manager[dict_pointer].end(),
3122+
this);
3123+
dict_to_guard_manager[dict_pointer].erase(it);
3124+
}
3125+
}
3126+
}
3127+
_dict_callbacks_unwatched = true;
3128+
#endif
3129+
}
3130+
30343131
virtual bool check_nopybind(FrameLocalsMapping* value) {
30353132
return check_nopybind_template(value);
30363133
}
@@ -3270,6 +3367,12 @@ class GuardManager {
32703367
std::unordered_map<PyObject*, std::vector<PyObject*>> _tensor_pointers;
32713368
std::vector<WeakEntry> _tag_safe_entries;
32723369

3370+
// 3.12+ related helper
3371+
bool _dict_callback_installed = false;
3372+
#if IS_PYTHON_3_12_PLUS
3373+
bool _dict_callbacks_unwatched = false;
3374+
#endif
3375+
32733376
protected:
32743377
// weakref to the type of guarded value
32753378
// protected because it is used for cloning by DictGuardManager
@@ -3957,6 +4060,24 @@ void add_relational_guard_resetter_to_cloned_root(
39574060
root->add_relational_guard_resetter(std::move(guard));
39584061
}
39594062

4063+
#if IS_PYTHON_3_12_PLUS
4064+
static int dict_recursive_tag_watch_callback(
4065+
PyDict_WatchEvent event,
4066+
PyObject* dict,
4067+
PyObject* key,
4068+
PyObject* new_value) noexcept {
4069+
auto it = dict_to_guard_manager.find(dict);
4070+
if (it != dict_to_guard_manager.end()) {
4071+
auto guard_managers = it->second;
4072+
for (auto& guard_manager : guard_managers) {
4073+
guard_manager->unwatch_dict_pointers();
4074+
guard_manager->disable_recursive_dict_tag_optimization();
4075+
}
4076+
}
4077+
return 0; // keep watching
4078+
}
4079+
#endif
4080+
39604081
std::unique_ptr<GuardManager> make_guard_manager(
39614082
RootGuardManager* root,
39624083
std::string source,
@@ -7558,6 +7679,13 @@ PyObject* torch_c_dynamo_guards_init() {
75587679
throw std::runtime_error("Failed to install dict_version_watch_callback");
75597680
}
75607681

7682+
dict_recursive_tag_watcher_id =
7683+
PyDict_AddWatcher(dict_recursive_tag_watch_callback);
7684+
if (dict_recursive_tag_watcher_id == -1) {
7685+
throw std::runtime_error(
7686+
"Failed to install dict_recursive_tag_watch_callback");
7687+
}
7688+
75617689
#endif
75627690

75637691
return m;

0 commit comments

Comments
 (0)