From 3427796c35db58530d24058397ba81592535e77d Mon Sep 17 00:00:00 2001 From: Federico Bond Date: Thu, 2 May 2024 10:33:09 +1000 Subject: [PATCH] fix: event handler methods are not thread-safe The _client_handlers dictionary allowed modifications during iteration without proper concurrency control. I added some reentrant locks to manage concurrent access to the _global_handlers and _client_handlers data structures. See #326 Signed-off-by: Federico Bond --- openfeature/_event_support.py | 43 +++++++++++++++++++++++------------ tests/test_client.py | 28 +++++++++++++++++++++++ 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/openfeature/_event_support.py b/openfeature/_event_support.py index 26753b05..42e6250b 100644 --- a/openfeature/_event_support.py +++ b/openfeature/_event_support.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading from collections import defaultdict from typing import TYPE_CHECKING, Dict, List @@ -15,7 +16,10 @@ from openfeature.client import OpenFeatureClient +_global_lock = threading.RLock() _global_handlers: Dict[ProviderEvent, List[EventHandler]] = defaultdict(list) + +_client_lock = threading.RLock() _client_handlers: Dict[OpenFeatureClient, Dict[ProviderEvent, List[EventHandler]]] = ( defaultdict(lambda: defaultdict(list)) ) @@ -24,20 +28,23 @@ def run_client_handlers( client: OpenFeatureClient, event: ProviderEvent, details: EventDetails ) -> None: - for handler in _client_handlers[client][event]: - handler(details) + with _client_lock: + for handler in _client_handlers[client][event]: + handler(details) def run_global_handlers(event: ProviderEvent, details: EventDetails) -> None: - for handler in _global_handlers[event]: - handler(details) + with _global_lock: + for handler in _global_handlers[event]: + handler(details) def add_client_handler( client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler ) -> None: - handlers = _client_handlers[client][event] - handlers.append(handler) + with _client_lock: + handlers = _client_handlers[client][event] + handlers.append(handler) _run_immediate_handler(client, event, handler) @@ -45,12 +52,14 @@ def add_client_handler( def remove_client_handler( client: OpenFeatureClient, event: ProviderEvent, handler: EventHandler ) -> None: - handlers = _client_handlers[client][event] - handlers.remove(handler) + with _client_lock: + handlers = _client_handlers[client][event] + handlers.remove(handler) def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None: - _global_handlers[event].append(handler) + with _global_lock: + _global_handlers[event].append(handler) from openfeature.api import get_client @@ -58,7 +67,8 @@ def add_global_handler(event: ProviderEvent, handler: EventHandler) -> None: def remove_global_handler(event: ProviderEvent, handler: EventHandler) -> None: - _global_handlers[event].remove(handler) + with _global_lock: + _global_handlers[event].remove(handler) def run_handlers_for_provider( @@ -72,9 +82,10 @@ def run_handlers_for_provider( # run the global handlers run_global_handlers(event, details) # run the handlers for clients associated to this provider - for client in _client_handlers: - if client.provider == provider: - run_client_handlers(client, event, details) + with _client_lock: + for client in _client_handlers: + if client.provider == provider: + run_client_handlers(client, event, details) def _run_immediate_handler( @@ -91,5 +102,7 @@ def _run_immediate_handler( def clear() -> None: - _global_handlers.clear() - _client_handlers.clear() + with _global_lock: + _global_handlers.clear() + with _client_lock: + _client_handlers.clear() diff --git a/tests/test_client.py b/tests/test_client.py index dc25abee..b51c460c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,6 @@ +import time +import uuid +from concurrent.futures import ThreadPoolExecutor from unittest.mock import MagicMock import pytest @@ -356,3 +359,28 @@ def test_provider_event_late_binding(): # Then spy.provider_configuration_changed.assert_called_once_with(details) + + +def test_client_handlers_thread_safety(): + provider = NoOpProvider() + set_provider(provider) + + def add_handlers_task(): + def handler(*args, **kwargs): + time.sleep(0.005) + + for _ in range(10): + time.sleep(0.01) + client = get_client(str(uuid.uuid4())) + client.add_handler(ProviderEvent.PROVIDER_CONFIGURATION_CHANGED, handler) + + def emit_events_task(): + for _ in range(10): + time.sleep(0.01) + provider.emit_provider_configuration_changed(ProviderEventDetails()) + + with ThreadPoolExecutor(max_workers=2) as executor: + f1 = executor.submit(add_handlers_task) + f2 = executor.submit(emit_events_task) + f1.result() + f2.result()