diff --git a/aten/src/ATen/core/CachingHostAllocator.h b/aten/src/ATen/core/CachingHostAllocator.h index 5049018d731e..a8f5f2fd7997 100644 --- a/aten/src/ATen/core/CachingHostAllocator.h +++ b/aten/src/ATen/core/CachingHostAllocator.h @@ -251,6 +251,7 @@ struct CachingHostAllocatorImpl { auto* block = reinterpret_cast(ctx); std::optional> events; + ska::flat_hash_set streams; { std::lock_guard g(block->mutex_); block->allocated_ = false; @@ -259,14 +260,19 @@ struct CachingHostAllocatorImpl { } else { events = std::vector(); events->reserve(block->streams_.size()); - for (auto stream : block->streams_) { - record_stream(events, stream); - } - block->event_count_ += events->size(); + block->event_count_ += block->streams_.size(); + // Move out streams to avoid holding the mutex during event recording + streams = std::move(block->streams_); block->streams_.clear(); } } + // Event recording must be done outside the mutex to avoid potential + // deadlocks (e.g., when Python GIL is involved) + for (auto stream : streams) { + record_stream(events, stream); + } + if (!events) { auto index = size_index(block->size_); std::lock_guard g(free_list_[index].mutex_);