From ef4a49d0cd2b8d7dd7207278c5f4a00874c86bd1 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Thu, 25 Apr 2024 14:46:20 -0700 Subject: [PATCH 1/5] Close all and counts --- adafruit_connection_manager.py | 47 +++++++++--- examples/connectionmanager_helpers.py | 32 +++++++- tests/conftest.py | 6 +- tests/connection_manager_close_all_test.py | 87 ++++++++++++++++++++++ tests/free_socket_test.py | 16 ++++ 5 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 tests/connection_manager_close_all_test.py diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index b591372..9e1630f 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -99,7 +99,8 @@ def create_fake_ssl_context( return _FakeSSLContext(iface) -_global_socketpool = {} +_global_connection_managers = {} +_global_socketpools = {} _global_ssl_contexts = {} @@ -113,7 +114,7 @@ def get_radio_socketpool(radio): * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ class_name = radio.__class__.__name__ - if class_name not in _global_socketpool: + if class_name not in _global_socketpools: if class_name == "Radio": import ssl # pylint: disable=import-outside-toplevel @@ -151,10 +152,10 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") - _global_socketpool[class_name] = pool + _global_socketpools[class_name] = pool _global_ssl_contexts[class_name] = ssl_context - return _global_socketpool[class_name] + return _global_socketpools[class_name] def get_radio_ssl_context(radio): @@ -186,10 +187,10 @@ def __init__( self._available_socket = {} self._open_sockets = {} - def _free_sockets(self) -> None: + def _free_sockets(self, force: bool = False) -> None: available_sockets = [] for socket, free in self._available_socket.items(): - if free: + if free or force: available_sockets.append(socket) for socket in available_sockets: @@ -203,6 +204,18 @@ def _get_key_for_socket(self, socket): except StopIteration: return None + @property + def open_sockets(self) -> int: + """Get the count of open sockets""" + return len(self._open_sockets) + + @property + def freeable_open_sockets(self) -> int: + """Get the count of freeable open sockets""" + return len( + [socket for socket, free in self._available_socket.items() if free is True] + ) + def close_socket(self, socket: SocketType) -> None: """Close a previously opened socket.""" if socket not in self._open_sockets.values(): @@ -306,11 +319,25 @@ def get_socket( # global helpers -_global_connection_manager = {} +def connection_manager_close_all( + socket_pool: Optional[SocketpoolModuleType] = None, +) -> None: + """Close all open sockets for pool""" + if socket_pool: + keys = [socket_pool] + else: + keys = _global_connection_managers.keys() + + for key in keys: + connection_manager = _global_connection_managers.get(key, None) + if connection_manager is None: + raise RuntimeError("SocketPool not managed") + + connection_manager._free_sockets(force=True) # pylint: disable=protected-access def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """Get the ConnectionManager singleton for the given pool""" - if socket_pool not in _global_connection_manager: - _global_connection_manager[socket_pool] = ConnectionManager(socket_pool) - return _global_connection_manager[socket_pool] + if socket_pool not in _global_connection_managers: + _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) + return _global_connection_managers[socket_pool] diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index 36f4af6..df383bd 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -24,14 +24,38 @@ # get request session requests = adafruit_requests.Session(pool, ssl_context) +connection_manager = adafruit_connection_manager.get_connection_manager(pool) +print("-" * 40) +print("Nothing yet opened") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") # make request print("-" * 40) -print(f"Fetching from {TEXT_URL}") +print(f"Fetching from {TEXT_URL} in a context handler") +with requests.get(TEXT_URL) as response: + response_text = response.text + print(f"Text Response {response_text}") + +print("-" * 40) +print("1 request, opened and freed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print("-" * 40) +print(f"Fetching from {TEXT_URL} not in a context handler") response = requests.get(TEXT_URL) -response_text = response.text -response.close() -print(f"Text Response {response_text}") print("-" * 40) +print("1 request, opened but not freed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") + +print("-" * 40) +print("Closing everything in the pool") +adafruit_connection_manager.connection_manager_close_all(pool) + +print("-" * 40) +print("Everything closed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") diff --git a/tests/conftest.py b/tests/conftest.py index 08d3914..ef6c96d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module(): @pytest.fixture(autouse=True) def reset_connection_manager(monkeypatch): monkeypatch.setattr( - "adafruit_connection_manager._global_socketpool", + "adafruit_connection_manager._global_connection_managers", + {}, + ) + monkeypatch.setattr( + "adafruit_connection_manager._global_socketpools", {}, ) monkeypatch.setattr( diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py new file mode 100644 index 0000000..7ca32ee --- /dev/null +++ b/tests/connection_manager_close_all_test.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get Connection Manager Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_connection_manager_close_all_all(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.open_sockets == 1 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + + adafruit_connection_manager.connection_manager_close_all() + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_called_once() + + +def test_connection_manager_close_all_single(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.open_sockets == 1 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_not_called() + + +def test_connection_manager_close_all_untracked(): + mock_pool_1 = mocket.MocketPool() + with pytest.raises(RuntimeError) as context: + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert "SocketPool not managed" in str(context) diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py index 93f34eb..b39d6d5 100644 --- a/tests/free_socket_test.py +++ b/tests/free_socket_test.py @@ -16,6 +16,8 @@ def test_free_socket(): mock_pool.socket.return_value = mock_socket_1 connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.open_sockets == 0 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and not available socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") @@ -24,12 +26,16 @@ def test_free_socket(): assert socket in connection_manager._available_socket assert connection_manager._available_socket[socket] is False assert key in connection_manager._open_sockets + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and is available connection_manager.free_socket(socket) assert socket in connection_manager._available_socket assert connection_manager._available_socket[socket] is True assert key in connection_manager._open_sockets + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 1 def test_free_socket_not_managed(): @@ -54,26 +60,36 @@ def test_free_sockets(): ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.open_sockets == 0 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and not available socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") assert socket_1 == mock_socket_1 assert socket_1 in connection_manager._available_socket assert connection_manager._available_socket[socket_1] is False + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") assert socket_2 == mock_socket_2 + assert connection_manager.open_sockets == 2 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and is available connection_manager.free_socket(socket_1) assert socket_1 in connection_manager._available_socket assert connection_manager._available_socket[socket_1] is True + assert connection_manager.open_sockets == 2 + assert connection_manager.freeable_open_sockets == 1 # validate socket is no longer tracked connection_manager._free_sockets() assert socket_1 not in connection_manager._available_socket assert socket_2 in connection_manager._available_socket mock_socket_1.close.assert_called_once() + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 def test_get_key_for_socket(): From dcd6be575328361d875a18f485071cf0b43bb3a7 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Fri, 26 Apr 2024 21:51:17 -0700 Subject: [PATCH 2/5] Add release_references --- adafruit_connection_manager.py | 27 +++++-- tests/connection_manager_close_all_test.py | 91 ++++++++++++++++++++++ tests/get_radio_test.py | 26 ++++--- 3 files changed, 130 insertions(+), 14 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 9e1630f..16da447 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -320,21 +320,38 @@ def get_socket( def connection_manager_close_all( - socket_pool: Optional[SocketpoolModuleType] = None, + socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False ) -> None: """Close all open sockets for pool""" if socket_pool: - keys = [socket_pool] + socket_pools = [socket_pool] else: - keys = _global_connection_managers.keys() + socket_pools = _global_connection_managers.keys() - for key in keys: - connection_manager = _global_connection_managers.get(key, None) + for pool in socket_pools: + connection_manager = _global_connection_managers.get(pool, None) if connection_manager is None: raise RuntimeError("SocketPool not managed") connection_manager._free_sockets(force=True) # pylint: disable=protected-access + if release_references: + radio_key = None + for radio_check, pool_check in _global_socketpools.items(): + if pool == pool_check: + radio_key = radio_check + break + + if radio_key: + if radio_key in _global_socketpools: + del _global_socketpools[radio_key] + + if radio_key in _global_ssl_contexts: + del _global_ssl_contexts[radio_key] + + if pool in _global_connection_managers: + del _global_connection_managers[pool] + def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """Get the ConnectionManager singleton for the given pool""" diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py index 7ca32ee..4bea462 100644 --- a/tests/connection_manager_close_all_test.py +++ b/tests/connection_manager_close_all_test.py @@ -85,3 +85,94 @@ def test_connection_manager_close_all_untracked(): with pytest.raises(RuntimeError) as context: adafruit_connection_manager.connection_manager_close_all(mock_pool_1) assert "SocketPool not managed" in str(context) + + +def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=False + ) + + assert socket_pool_wifi in adafruit_connection_manager._global_socketpools.values() + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ssl_context_wifi in adafruit_connection_manager._global_ssl_contexts.values() + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) + + +def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=True + ) + + assert ( + socket_pool_wifi not in adafruit_connection_manager._global_socketpools.values() + ) + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ( + ssl_context_wifi + not in adafruit_connection_manager._global_ssl_contexts.values() + ) + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + not in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 426f785..5c43ad1 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -19,6 +19,7 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument radio = mocket.MockRadio.Radio() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert isinstance(socket_pool, mocket.MocketPool) + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument @@ -27,6 +28,7 @@ def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument @@ -36,6 +38,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument with mock.patch("sys.implementation", return_value=[9, 0, 0]): socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_unsupported(): @@ -52,22 +55,25 @@ def test_get_radio_socketpool_returns_same_one( # pylint: disable=unused-argume socket_pool_1 = adafruit_connection_manager.get_radio_socketpool(radio) socket_pool_2 = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool_1 == socket_pool_2 + assert socket_pool_1 in adafruit_connection_manager._global_socketpools.values() def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, ssl.SSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument adafruit_esp32spi_socket_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument @@ -75,8 +81,9 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_unsupported(): @@ -90,6 +97,7 @@ def test_get_radio_ssl_context_returns_same_one( # pylint: disable=unused-argum circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts_1 = adafruit_connection_manager.get_radio_ssl_context(radio) - ssl_contexts_2 = adafruit_connection_manager.get_radio_ssl_context(radio) - assert ssl_contexts_1 == ssl_contexts_2 + ssl_context_1 = adafruit_connection_manager.get_radio_ssl_context(radio) + ssl_context_2 = adafruit_connection_manager.get_radio_ssl_context(radio) + assert ssl_context_1 == ssl_context_2 + assert ssl_context_1 in adafruit_connection_manager._global_ssl_contexts.values() From 9aa4b455249165ef74a4a8a6313d4a0ca0c6dadf Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 15:34:51 -0700 Subject: [PATCH 3/5] Code review comments --- adafruit_connection_manager.py | 154 ++++++++++----------- examples/connectionmanager_helpers.py | 16 +-- tests/close_socket_test.py | 8 +- tests/connection_manager_close_all_test.py | 56 ++++---- tests/free_socket_test.py | 75 ++++------ 5 files changed, 135 insertions(+), 174 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 16da447..9f1d565 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -35,7 +35,7 @@ if not sys.implementation.name == "circuitpython": - from typing import Optional, Tuple + from typing import List, Optional, Tuple from circuitpython_typing.socket import ( CircuitPythonSocketType, @@ -71,8 +71,7 @@ class _FakeSSLContext: def __init__(self, iface: InterfaceType) -> None: self._iface = iface - # pylint: disable=unused-argument - def wrap_socket( + def wrap_socket( # pylint: disable=unused-argument self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None ) -> _FakeSSLSocket: """Return the same socket""" @@ -184,54 +183,75 @@ def __init__( ) -> None: self._socket_pool = socket_pool # Hang onto open sockets so that we can reuse them. - self._available_socket = {} - self._open_sockets = {} + self._available_sockets = set() + self._managed_socket_by_key = {} + self._managed_socket_by_socket = {} def _free_sockets(self, force: bool = False) -> None: - available_sockets = [] - for socket, free in self._available_socket.items(): - if free or force: - available_sockets.append(socket) - + # cloning lists since items are being removed + available_sockets = list(self._available_sockets) for socket in available_sockets: self.close_socket(socket) + if force: + open_sockets = list(self._managed_socket_by_key.values()) + for socket in open_sockets: + self.close_socket(socket) - def _get_key_for_socket(self, socket): + def _get_connected_socket( # pylint: disable=too-many-arguments + self, + addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], + host: str, + port: int, + timeout: float, + is_ssl: bool, + ssl_context: Optional[SSLContextType] = None, + ): try: - return next( - key for key, value in self._open_sockets.items() if value == socket - ) - except StopIteration: - return None + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) + except (OSError, RuntimeError) as exc: + return exc - @property - def open_sockets(self) -> int: - """Get the count of open sockets""" - return len(self._open_sockets) + if is_ssl: + socket = ssl_context.wrap_socket(socket, server_hostname=host) + connect_host = host + else: + connect_host = addr_info[-1][0] + socket.settimeout(timeout) # socket read timeout + + try: + socket.connect((connect_host, port)) + except (MemoryError, OSError) as exc: + socket.close() + return exc + + return socket @property - def freeable_open_sockets(self) -> int: + def available_socket_count(self) -> int: """Get the count of freeable open sockets""" - return len( - [socket for socket, free in self._available_socket.items() if free is True] - ) + return len(self._available_sockets) + + @property + def managed_socket_count(self) -> int: + """Get the count of open sockets""" + return len(self._managed_socket_by_key) def close_socket(self, socket: SocketType) -> None: """Close a previously opened socket.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - key = self._get_key_for_socket(socket) socket.close() - del self._available_socket[socket] - del self._open_sockets[key] + key = self._managed_socket_by_socket.pop(socket) + del self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: """Mark a previously opened socket as available so it can be reused if needed.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - self._available_socket[socket] = True + self._available_sockets.add(socket) - # pylint: disable=too-many-branches,too-many-locals,too-many-statements def get_socket( self, host: str, @@ -247,10 +267,10 @@ def get_socket( if session_id: session_id = str(session_id) key = (host, port, proto, session_id) - if key in self._open_sockets: - socket = self._open_sockets[key] - if self._available_socket[socket]: - self._available_socket[socket] = False + if key in self._managed_socket_by_key: + socket = self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) return socket raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") @@ -266,54 +286,22 @@ def get_socket( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - try_count = 0 - socket = None - last_exc = None - while try_count < 2 and socket is None: - try_count += 1 - if try_count > 1: - if any( - socket - for socket, free in self._available_socket.items() - if free is True - ): - self._free_sockets() - else: - break - - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except OSError as exc: - last_exc = exc - continue - except RuntimeError as exc: - last_exc = exc - continue - - if is_ssl: - socket = ssl_context.wrap_socket(socket, server_hostname=host) - connect_host = host - else: - connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout - - try: - socket.connect((connect_host, port)) - except MemoryError as exc: - last_exc = exc - socket.close() - socket = None - except OSError as exc: - last_exc = exc - socket.close() - socket = None - - if socket is None: - raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc - - self._available_socket[socket] = False - self._open_sockets[key] = socket - return socket + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + # Got an error, if there are any available sockets, free them and try again + if self.available_socket_count: + self._free_sockets() + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + raise RuntimeError(f"Error connecting socket: {result}") from result + + self._managed_socket_by_key[key] = result + self._managed_socket_by_socket[result] = key + return result # global helpers diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index df383bd..e9fb842 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -27,8 +27,8 @@ connection_manager = adafruit_connection_manager.get_connection_manager(pool) print("-" * 40) print("Nothing yet opened") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") # make request print("-" * 40) @@ -39,8 +39,8 @@ print("-" * 40) print("1 request, opened and freed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") print("-" * 40) print(f"Fetching from {TEXT_URL} not in a context handler") @@ -48,8 +48,8 @@ print("-" * 40) print("1 request, opened but not freed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") print("-" * 40) print("Closing everything in the pool") @@ -57,5 +57,5 @@ print("-" * 40) print("Everything closed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") diff --git a/tests/close_socket_test.py b/tests/close_socket_test.py index 957cb94..3927181 100644 --- a/tests/close_socket_test.py +++ b/tests/close_socket_test.py @@ -21,13 +21,13 @@ def test_close_socket(): socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert key in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key # validate socket is no longer tracked connection_manager.close_socket(socket) - assert socket not in connection_manager._available_socket - assert key not in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key not in connection_manager._managed_socket_by_key def test_close_socket_not_managed(): diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py index 4bea462..c0fa498 100644 --- a/tests/connection_manager_close_all_test.py +++ b/tests/connection_manager_close_all_test.py @@ -18,29 +18,29 @@ def test_connection_manager_close_all_all(): connection_manager_1 = adafruit_connection_manager.get_connection_manager( mock_pool_1 ) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 connection_manager_2 = adafruit_connection_manager.get_connection_manager( mock_pool_2 ) - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 assert len(adafruit_connection_manager._global_connection_managers) == 2 socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_1.open_sockets == 1 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 adafruit_connection_manager.connection_manager_close_all() - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_1.close.assert_called_once() socket_2.close.assert_called_once() @@ -53,29 +53,29 @@ def test_connection_manager_close_all_single(): connection_manager_1 = adafruit_connection_manager.get_connection_manager( mock_pool_1 ) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 connection_manager_2 = adafruit_connection_manager.get_connection_manager( mock_pool_2 ) - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 assert len(adafruit_connection_manager._global_connection_managers) == 2 socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_1.open_sockets == 1 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 adafruit_connection_manager.connection_manager_close_all(mock_pool_1) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 socket_1.close.assert_called_once() socket_2.close.assert_not_called() diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py index b39d6d5..666a072 100644 --- a/tests/free_socket_test.py +++ b/tests/free_socket_test.py @@ -16,26 +16,24 @@ def test_free_socket(): mock_pool.socket.return_value = mock_socket_1 connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - assert connection_manager.open_sockets == 0 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is False - assert key in connection_manager._open_sockets - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket) - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is True - assert key in connection_manager._open_sockets - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 1 + assert socket in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 1 def test_free_socket_not_managed(): @@ -60,56 +58,31 @@ def test_free_sockets(): ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - assert connection_manager.open_sockets == 0 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") assert socket_1 == mock_socket_1 - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is False - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 + assert socket_1 not in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") assert socket_2 == mock_socket_2 - assert connection_manager.open_sockets == 2 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket_1) - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is True - assert connection_manager.open_sockets == 2 - assert connection_manager.freeable_open_sockets == 1 + assert socket_1 in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 1 # validate socket is no longer tracked connection_manager._free_sockets() - assert socket_1 not in connection_manager._available_socket - assert socket_2 in connection_manager._available_socket + assert socket_1 not in connection_manager._available_sockets + assert socket_2 not in connection_manager._available_sockets mock_socket_1.close.assert_called_once() - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 - - -def test_get_key_for_socket(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - mock_pool.socket.return_value = mock_socket_1 - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate tracked socket has correct key - socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - key = (mocket.MOCK_HOST_1, 80, "http:", None) - assert connection_manager._get_key_for_socket(socket) == key - - -def test_get_key_for_socket_not_managed(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate untracked socket has no key - assert connection_manager._get_key_for_socket(mock_socket_1) is None + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 From f33d78cf757096bf760a3a0602893db078aa1811 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 15:57:46 -0700 Subject: [PATCH 4/5] rename var --- adafruit_connection_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 9f1d565..bb03328 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -184,8 +184,8 @@ def __init__( self._socket_pool = socket_pool # Hang onto open sockets so that we can reuse them. self._available_sockets = set() + self._key_by_managed_socket = {} self._managed_socket_by_key = {} - self._managed_socket_by_socket = {} def _free_sockets(self, force: bool = False) -> None: # cloning lists since items are being removed @@ -241,7 +241,7 @@ def close_socket(self, socket: SocketType) -> None: if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") socket.close() - key = self._managed_socket_by_socket.pop(socket) + key = self._key_by_managed_socket.pop(socket) del self._managed_socket_by_key[key] if socket in self._available_sockets: self._available_sockets.remove(socket) @@ -299,8 +299,8 @@ def get_socket( if isinstance(result, Exception): raise RuntimeError(f"Error connecting socket: {result}") from result + self._key_by_managed_socket[result] = key self._managed_socket_by_key[key] = result - self._managed_socket_by_socket[result] = key return result From 601ee66791d53d72572b581f28d8d92a2c0401ec Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 19:02:32 -0700 Subject: [PATCH 5/5] Little error cleanup --- adafruit_connection_manager.py | 9 +++++++-- tests/get_socket_test.py | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index bb03328..9dc8f51 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -64,7 +64,7 @@ def connect(self, address: Tuple[str, int]) -> None: try: return self._socket.connect(address, self._mode) except RuntimeError as error: - raise OSError(errno.ENOMEM) from error + raise OSError(errno.ENOMEM, str(error)) from error class _FakeSSLContext: @@ -286,18 +286,23 @@ def get_socket( host, port, 0, self._socket_pool.SOCK_STREAM )[0] + first_exception = None result = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) if isinstance(result, Exception): # Got an error, if there are any available sockets, free them and try again if self.available_socket_count: + first_exception = result self._free_sockets() result = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) if isinstance(result, Exception): - raise RuntimeError(f"Error connecting socket: {result}") from result + last_result = f", first error: {first_exception}" if first_exception else "" + raise RuntimeError( + f"Error connecting socket: {result}{last_result}" + ) from result self._key_by_managed_socket[result] = key self._managed_socket_by_key[key] = result diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index ea252cc..6be48f0 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -213,7 +213,7 @@ def test_get_socket_runtime_error_ties_again_only_once(): # try to get a socket that returns a RuntimeError twice with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2" in str(context) + assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -242,7 +242,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.return_value = mock_socket_1 - mock_socket_1.connect.side_effect = RuntimeError("RuntimeError ") + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError") radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) @@ -252,4 +252,4 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: 12" in str(context) + assert "Error connecting socket: [Errno 12] RuntimeError" in str(context)