From 85713ef59a4454c10360454bb4465013666a0e04 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Wed, 24 Apr 2024 08:00:49 -0700 Subject: [PATCH 01/20] Use new SocketPool for ESP32SPI and WIZNET5K --- adafruit_connection_manager.py | 16 +++++++++++++--- tests/conftest.py | 33 +++++++++++++++++++++++++++------ tests/get_radio_test.py | 4 ++-- 3 files changed, 42 insertions(+), 11 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index cc70f3f..353f778 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -56,6 +56,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self.recv = socket.recv self.close = socket.close self.recv_into = socket.recv_into + if hasattr(socket, "_interface"): + self._interface = socket._interface + if hasattr(socket, "_socket_pool"): + self._socket_pool = socket._socket_pool def connect(self, address: Tuple[str, int]) -> None: """Connect wrapper to add non-standard mode parameter""" @@ -93,7 +97,10 @@ def create_fake_ssl_context( * `Adafruit AirLift FeatherWing – ESP32 WiFi Co-Processor `_ """ - socket_pool.set_interface(iface) + if hasattr(socket_pool, "set_interface"): + # this is to manually support legacy hardware like the fona + socket_pool.set_interface(iface) + return _FakeSSLContext(iface) @@ -121,12 +128,15 @@ def get_radio_socketpool(radio): ssl_context = ssl.create_default_context() elif class_name == "ESP_SPIcontrol": - import adafruit_esp32spi.adafruit_esp32spi_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_esp32spi.adafruit_esp32spi_socketpool as socketpool # pylint: disable=import-outside-toplevel + pool = socketpool.SocketPool(radio) ssl_context = create_fake_ssl_context(pool, radio) elif class_name == "WIZNET5K": - import adafruit_wiznet5k.adafruit_wiznet5k_socket as pool # pylint: disable=import-outside-toplevel + import adafruit_wiznet5k.adafruit_wiznet5k_socketpool as socketpool # pylint: disable=import-outside-toplevel + + pool = socketpool.SocketPool(radio) # Note: SSL/TLS connections are not supported by the Wiznet5k library at this time ssl_context = create_fake_ssl_context(pool, radio) diff --git a/tests/conftest.py b/tests/conftest.py index 2d9bb0a..06457cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,18 +14,39 @@ def set_interface(iface): """Helper to set the global internet interface""" +class SocketPool: + name = None + + def __init__(self, *args, **kwargs): + pass + + @property + def __name__(self): + return self.name + + +class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_esp32spi_socketpool" + + +class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods + name = "adafruit_wiznet5k_socketpool" + + socketpool_module = type(sys)("socketpool") socketpool_module.SocketPool = mocket.MocketPool sys.modules["socketpool"] = socketpool_module esp32spi_module = type(sys)("adafruit_esp32spi") -esp32spi_socket_module = type(sys)("adafruit_esp32spi_socket") -esp32spi_socket_module.set_interface = set_interface +esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool") +esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool sys.modules["adafruit_esp32spi"] = esp32spi_module -sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] = esp32spi_socket_module +sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = esp32spi_socket_module wiznet5k_module = type(sys)("adafruit_wiznet5k") -wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") -wiznet5k_socket_module.set_interface = set_interface +wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") +wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool sys.modules["adafruit_wiznet5k"] = wiznet5k_module -sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module +sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( + wiznet5k_socketpool_module +) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index ea80f7e..cbd49e8 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -21,13 +21,13 @@ def test_get_radio_socketpool_wifi(): def test_get_radio_socketpool_esp32spi(): radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool.__name__ == "adafruit_esp32spi_socketpool" def test_get_radio_socketpool_wiznet5k(): radio = mocket.MockRadio.WIZNET5K() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) - assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool.__name__ == "adafruit_wiznet5k_socketpool" def test_get_radio_socketpool_unsupported(): From 8482b9fbdab469da550618b921adf06be8916417 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Wed, 24 Apr 2024 21:26:49 -0700 Subject: [PATCH 02/20] Merge fixes --- tests/conftest.py | 18 ++++++++++-------- tests/get_radio_test.py | 8 ++++---- tests/get_socket_test.py | 4 ++-- tests/ssl_context_test.py | 4 ++-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 185ef3b..bfea88d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,29 +44,32 @@ def circuitpython_socketpool_module(): @pytest.fixture -def adafruit_esp32spi_socket_module(): +def adafruit_esp32spi_socketpool_module(): esp32spi_module = type(sys)("adafruit_esp32spi") esp32spi_socket_module = type(sys)("adafruit_esp32spi_socketpool") esp32spi_socket_module.SocketPool = ESP32SPI_SocketPool sys.modules["adafruit_esp32spi"] = esp32spi_module - sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = esp32spi_socket_module + sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] = ( + esp32spi_socket_module + ) yield del sys.modules["adafruit_esp32spi"] - del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socket"] + del sys.modules["adafruit_esp32spi.adafruit_esp32spi_socketpool"] @pytest.fixture -def adafruit_wiznet5k_socket_module(): +def adafruit_wiznet5k_socketpool_module(): wiznet5k_module = type(sys)("adafruit_wiznet5k") wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool - wiznet5k_socket_module.SOCK_STREAM = 0x21 + wiznet5k_socketpool_module.SOCK_STREAM = 0x21 sys.modules["adafruit_wiznet5k"] = wiznet5k_module sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( - wiznet5k_socketpool_module + wiznet5k_socketpool_module + ) yield del sys.modules["adafruit_wiznet5k"] - del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] + del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] @pytest.fixture(autouse=True) @@ -79,4 +82,3 @@ def reset_connection_manager(monkeypatch): "adafruit_connection_manager._global_ssl_contexts", {}, ) -) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 05ea5f9..c73c97d 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -22,7 +22,7 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) @@ -30,7 +30,7 @@ def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): @@ -63,7 +63,7 @@ def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) @@ -71,7 +71,7 @@ def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index ea252cc..2937dd7 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -218,7 +218,7 @@ def test_get_socket_runtime_error_ties_again_only_once(): def test_fake_ssl_context_connect( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() @@ -237,7 +237,7 @@ def test_fake_ssl_context_connect( # pylint: disable=unused-argument def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index abc857f..9f4a8a1 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -13,7 +13,7 @@ def test_connect_esp32spi_https( # pylint: disable=unused-argument - adafruit_esp32spi_socket_module, + adafruit_esp32spi_socketpool_module, ): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() @@ -46,7 +46,7 @@ def test_connect_wifi_https( # pylint: disable=unused-argument def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argument - adafruit_wiznet5k_socket_module, + adafruit_wiznet5k_socketpool_module, ): mock_pool = mocket.MocketPool() radio = mocket.MockRadio.WIZNET5K() From ef4a49d0cd2b8d7dd7207278c5f4a00874c86bd1 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Thu, 25 Apr 2024 14:46:20 -0700 Subject: [PATCH 03/20] Close all and counts --- adafruit_connection_manager.py | 47 +++++++++--- examples/connectionmanager_helpers.py | 32 +++++++- tests/conftest.py | 6 +- tests/connection_manager_close_all_test.py | 87 ++++++++++++++++++++++ tests/free_socket_test.py | 16 ++++ 5 files changed, 173 insertions(+), 15 deletions(-) create mode 100644 tests/connection_manager_close_all_test.py diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index b591372..9e1630f 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -99,7 +99,8 @@ def create_fake_ssl_context( return _FakeSSLContext(iface) -_global_socketpool = {} +_global_connection_managers = {} +_global_socketpools = {} _global_ssl_contexts = {} @@ -113,7 +114,7 @@ def get_radio_socketpool(radio): * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ class_name = radio.__class__.__name__ - if class_name not in _global_socketpool: + if class_name not in _global_socketpools: if class_name == "Radio": import ssl # pylint: disable=import-outside-toplevel @@ -151,10 +152,10 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") - _global_socketpool[class_name] = pool + _global_socketpools[class_name] = pool _global_ssl_contexts[class_name] = ssl_context - return _global_socketpool[class_name] + return _global_socketpools[class_name] def get_radio_ssl_context(radio): @@ -186,10 +187,10 @@ def __init__( self._available_socket = {} self._open_sockets = {} - def _free_sockets(self) -> None: + def _free_sockets(self, force: bool = False) -> None: available_sockets = [] for socket, free in self._available_socket.items(): - if free: + if free or force: available_sockets.append(socket) for socket in available_sockets: @@ -203,6 +204,18 @@ def _get_key_for_socket(self, socket): except StopIteration: return None + @property + def open_sockets(self) -> int: + """Get the count of open sockets""" + return len(self._open_sockets) + + @property + def freeable_open_sockets(self) -> int: + """Get the count of freeable open sockets""" + return len( + [socket for socket, free in self._available_socket.items() if free is True] + ) + def close_socket(self, socket: SocketType) -> None: """Close a previously opened socket.""" if socket not in self._open_sockets.values(): @@ -306,11 +319,25 @@ def get_socket( # global helpers -_global_connection_manager = {} +def connection_manager_close_all( + socket_pool: Optional[SocketpoolModuleType] = None, +) -> None: + """Close all open sockets for pool""" + if socket_pool: + keys = [socket_pool] + else: + keys = _global_connection_managers.keys() + + for key in keys: + connection_manager = _global_connection_managers.get(key, None) + if connection_manager is None: + raise RuntimeError("SocketPool not managed") + + connection_manager._free_sockets(force=True) # pylint: disable=protected-access def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """Get the ConnectionManager singleton for the given pool""" - if socket_pool not in _global_connection_manager: - _global_connection_manager[socket_pool] = ConnectionManager(socket_pool) - return _global_connection_manager[socket_pool] + if socket_pool not in _global_connection_managers: + _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) + return _global_connection_managers[socket_pool] diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index 36f4af6..df383bd 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -24,14 +24,38 @@ # get request session requests = adafruit_requests.Session(pool, ssl_context) +connection_manager = adafruit_connection_manager.get_connection_manager(pool) +print("-" * 40) +print("Nothing yet opened") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") # make request print("-" * 40) -print(f"Fetching from {TEXT_URL}") +print(f"Fetching from {TEXT_URL} in a context handler") +with requests.get(TEXT_URL) as response: + response_text = response.text + print(f"Text Response {response_text}") + +print("-" * 40) +print("1 request, opened and freed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print("-" * 40) +print(f"Fetching from {TEXT_URL} not in a context handler") response = requests.get(TEXT_URL) -response_text = response.text -response.close() -print(f"Text Response {response_text}") print("-" * 40) +print("1 request, opened but not freed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") + +print("-" * 40) +print("Closing everything in the pool") +adafruit_connection_manager.connection_manager_close_all(pool) + +print("-" * 40) +print("Everything closed") +print(f"Open Sockets: {connection_manager.open_sockets}") +print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") diff --git a/tests/conftest.py b/tests/conftest.py index 08d3914..ef6c96d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,11 @@ def adafruit_wiznet5k_with_ssl_socket_module(): @pytest.fixture(autouse=True) def reset_connection_manager(monkeypatch): monkeypatch.setattr( - "adafruit_connection_manager._global_socketpool", + "adafruit_connection_manager._global_connection_managers", + {}, + ) + monkeypatch.setattr( + "adafruit_connection_manager._global_socketpools", {}, ) monkeypatch.setattr( diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py new file mode 100644 index 0000000..7ca32ee --- /dev/null +++ b/tests/connection_manager_close_all_test.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2024 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" Get Connection Manager Tests """ + +import mocket +import pytest + +import adafruit_connection_manager + + +def test_connection_manager_close_all_all(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.open_sockets == 1 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + + adafruit_connection_manager.connection_manager_close_all() + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_called_once() + + +def test_connection_manager_close_all_single(): + mock_pool_1 = mocket.MocketPool() + mock_pool_2 = mocket.MocketPool() + assert mock_pool_1 != mock_pool_2 + + connection_manager_1 = adafruit_connection_manager.get_connection_manager( + mock_pool_1 + ) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + connection_manager_2 = adafruit_connection_manager.get_connection_manager( + mock_pool_2 + ) + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + assert len(adafruit_connection_manager._global_connection_managers) == 2 + + socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_1.open_sockets == 1 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 0 + assert connection_manager_2.freeable_open_sockets == 0 + socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert connection_manager_1.open_sockets == 0 + assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_2.open_sockets == 1 + assert connection_manager_2.freeable_open_sockets == 0 + socket_1.close.assert_called_once() + socket_2.close.assert_not_called() + + +def test_connection_manager_close_all_untracked(): + mock_pool_1 = mocket.MocketPool() + with pytest.raises(RuntimeError) as context: + adafruit_connection_manager.connection_manager_close_all(mock_pool_1) + assert "SocketPool not managed" in str(context) diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py index 93f34eb..b39d6d5 100644 --- a/tests/free_socket_test.py +++ b/tests/free_socket_test.py @@ -16,6 +16,8 @@ def test_free_socket(): mock_pool.socket.return_value = mock_socket_1 connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.open_sockets == 0 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and not available socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") @@ -24,12 +26,16 @@ def test_free_socket(): assert socket in connection_manager._available_socket assert connection_manager._available_socket[socket] is False assert key in connection_manager._open_sockets + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and is available connection_manager.free_socket(socket) assert socket in connection_manager._available_socket assert connection_manager._available_socket[socket] is True assert key in connection_manager._open_sockets + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 1 def test_free_socket_not_managed(): @@ -54,26 +60,36 @@ def test_free_sockets(): ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) + assert connection_manager.open_sockets == 0 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and not available socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") assert socket_1 == mock_socket_1 assert socket_1 in connection_manager._available_socket assert connection_manager._available_socket[socket_1] is False + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") assert socket_2 == mock_socket_2 + assert connection_manager.open_sockets == 2 + assert connection_manager.freeable_open_sockets == 0 # validate socket is tracked and is available connection_manager.free_socket(socket_1) assert socket_1 in connection_manager._available_socket assert connection_manager._available_socket[socket_1] is True + assert connection_manager.open_sockets == 2 + assert connection_manager.freeable_open_sockets == 1 # validate socket is no longer tracked connection_manager._free_sockets() assert socket_1 not in connection_manager._available_socket assert socket_2 in connection_manager._available_socket mock_socket_1.close.assert_called_once() + assert connection_manager.open_sockets == 1 + assert connection_manager.freeable_open_sockets == 0 def test_get_key_for_socket(): From b0d5e027c63965ff321cc7464a9d0642057b8cb5 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Thu, 25 Apr 2024 17:40:48 -0700 Subject: [PATCH 04/20] Merge fixes --- adafruit_connection_manager.py | 1 - tests/conftest.py | 27 +++++++++++++++------------ tests/get_connection_manager_test.py | 2 +- tests/ssl_context_test.py | 2 +- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 0c6e0d0..f1cadba 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -150,7 +150,6 @@ def get_radio_socketpool(radio): import ssl # pylint: disable=import-outside-toplevel ssl_context = ssl.create_default_context() - pool.set_interface(radio) except ImportError: # if SSL not on board, default to fake_ssl_context pass diff --git a/tests/conftest.py b/tests/conftest.py index f6fc130..f105d92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,11 +10,6 @@ import pytest -# pylint: disable=unused-argument -def set_interface(iface): - """Helper to set the global internet interface""" - - class SocketPool: name = None @@ -32,6 +27,14 @@ class ESP32SPI_SocketPool(SocketPool): # pylint: disable=too-few-public-methods class WIZNET5K_SocketPool(SocketPool): # pylint: disable=too-few-public-methods name = "adafruit_wiznet5k_socketpool" + SOCK_STREAM = 0x21 + + +class WIZNET5K_With_SSL_SocketPool( + SocketPool +): # pylint: disable=too-few-public-methods + name = "adafruit_wiznet5k_socketpool" + SOCK_STREAM = 0x1 @pytest.fixture @@ -62,7 +65,6 @@ def adafruit_wiznet5k_socketpool_module(): wiznet5k_module = type(sys)("adafruit_wiznet5k") wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") wiznet5k_socketpool_module.SocketPool = WIZNET5K_SocketPool - wiznet5k_socketpool_module.SOCK_STREAM = 0x21 sys.modules["adafruit_wiznet5k"] = wiznet5k_module sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( wiznet5k_socketpool_module @@ -73,16 +75,17 @@ def adafruit_wiznet5k_socketpool_module(): @pytest.fixture -def adafruit_wiznet5k_with_ssl_socket_module(): +def adafruit_wiznet5k_with_ssl_socketpool_module(): wiznet5k_module = type(sys)("adafruit_wiznet5k") - wiznet5k_socket_module = type(sys)("adafruit_wiznet5k_socket") - wiznet5k_socket_module.set_interface = set_interface - wiznet5k_socket_module.SOCK_STREAM = 1 + wiznet5k_socketpool_module = type(sys)("adafruit_wiznet5k_socketpool") + wiznet5k_socketpool_module.SocketPool = WIZNET5K_With_SSL_SocketPool sys.modules["adafruit_wiznet5k"] = wiznet5k_module - sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] = wiznet5k_socket_module + sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] = ( + wiznet5k_socketpool_module + ) yield del sys.modules["adafruit_wiznet5k"] - del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socket"] + del sys.modules["adafruit_wiznet5k.adafruit_wiznet5k_socketpool"] @pytest.fixture(autouse=True) diff --git a/tests/get_connection_manager_test.py b/tests/get_connection_manager_test.py index 324d032..c5f7817 100644 --- a/tests/get_connection_manager_test.py +++ b/tests/get_connection_manager_test.py @@ -19,7 +19,7 @@ def test_get_connection_manager(): def test_different_connection_manager_different_pool( # pylint: disable=unused-argument - circuitpython_socketpool_module, adafruit_esp32spi_socket_module + circuitpython_socketpool_module, adafruit_esp32spi_socketpool_module ): radio_wifi = mocket.MockRadio.Radio() radio_esp = mocket.MockRadio.ESP_SPIcontrol() diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index 90c0adf..2f2e370 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -66,7 +66,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen def test_connect_wiznet5k_https_supported( # pylint: disable=unused-argument - adafruit_wiznet5k_with_ssl_socket_module, + adafruit_wiznet5k_with_ssl_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", (None, WIZNET5K_SSL_SUPPORT_VERSION)): From be9b9ddb2cea6614d8db3c3cb3279b3dc325bbd6 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Fri, 26 Apr 2024 07:04:02 -0700 Subject: [PATCH 05/20] Update hash keys --- adafruit_connection_manager.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index f1cadba..ecad474 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -110,6 +110,14 @@ def create_fake_ssl_context( _global_ssl_contexts = {} +def _get_radio_hash_key(radio): + class_name = radio.__class__.__name__ + # trying to use wifi.radio as a key results in: + # TypeError: unsupported type for __hash__: 'Radio' + # So just use the class name in this case + return class_name if class_name == "Radio" else radio + + def get_radio_socketpool(radio): """Helper to get a socket pool for common boards @@ -119,8 +127,9 @@ def get_radio_socketpool(radio): * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ - class_name = radio.__class__.__name__ - if class_name not in _global_socketpool: + key = _get_radio_hash_key(radio) + if key not in _global_socketpool: + class_name = radio.__class__.__name__ if class_name == "Radio": import ssl # pylint: disable=import-outside-toplevel @@ -160,10 +169,10 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") - _global_socketpool[class_name] = pool - _global_ssl_contexts[class_name] = ssl_context + _global_socketpool[key] = pool + _global_ssl_contexts[key] = ssl_context - return _global_socketpool[class_name] + return _global_socketpool[key] def get_radio_ssl_context(radio): @@ -175,9 +184,8 @@ def get_radio_ssl_context(radio): * Using the ESP32 WiFi Co-Processor (like the Adafruit AirLift) * Using a WIZ5500 (Like the Adafruit Ethernet FeatherWing) """ - class_name = radio.__class__.__name__ get_radio_socketpool(radio) - return _global_ssl_contexts[class_name] + return _global_ssl_contexts[_get_radio_hash_key(radio)] # main class From 778b78f77f120566631ee4eb6f37ecf5f50d9b5e Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Fri, 26 Apr 2024 08:03:14 -0700 Subject: [PATCH 06/20] Better hashing --- adafruit_connection_manager.py | 9 ++++----- tests/get_radio_test.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index ecad474..b58d042 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -111,11 +111,10 @@ def create_fake_ssl_context( def _get_radio_hash_key(radio): - class_name = radio.__class__.__name__ - # trying to use wifi.radio as a key results in: - # TypeError: unsupported type for __hash__: 'Radio' - # So just use the class name in this case - return class_name if class_name == "Radio" else radio + try: + return hash(radio) + except TypeError: + return radio.__class__.__name__ def get_radio_socketpool(radio): diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index c73c97d..c4ffde6 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -13,6 +13,18 @@ import adafruit_connection_manager +def test__get_radio_hash_key(): + radio = mocket.MockRadio.Radio() + assert adafruit_connection_manager._get_radio_hash_key(radio) == hash(radio) + + +def test__get_radio_hash_key_not_hashable(): + radio = mocket.MockRadio.Radio() + + with mock.patch("builtins.hash", side_effect=TypeError()): + assert adafruit_connection_manager._get_radio_hash_key(radio) == "Radio" + + def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument circuitpython_socketpool_module, ): From dcd6be575328361d875a18f485071cf0b43bb3a7 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Fri, 26 Apr 2024 21:51:17 -0700 Subject: [PATCH 07/20] Add release_references --- adafruit_connection_manager.py | 27 +++++-- tests/connection_manager_close_all_test.py | 91 ++++++++++++++++++++++ tests/get_radio_test.py | 26 ++++--- 3 files changed, 130 insertions(+), 14 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 9e1630f..16da447 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -320,21 +320,38 @@ def get_socket( def connection_manager_close_all( - socket_pool: Optional[SocketpoolModuleType] = None, + socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False ) -> None: """Close all open sockets for pool""" if socket_pool: - keys = [socket_pool] + socket_pools = [socket_pool] else: - keys = _global_connection_managers.keys() + socket_pools = _global_connection_managers.keys() - for key in keys: - connection_manager = _global_connection_managers.get(key, None) + for pool in socket_pools: + connection_manager = _global_connection_managers.get(pool, None) if connection_manager is None: raise RuntimeError("SocketPool not managed") connection_manager._free_sockets(force=True) # pylint: disable=protected-access + if release_references: + radio_key = None + for radio_check, pool_check in _global_socketpools.items(): + if pool == pool_check: + radio_key = radio_check + break + + if radio_key: + if radio_key in _global_socketpools: + del _global_socketpools[radio_key] + + if radio_key in _global_ssl_contexts: + del _global_ssl_contexts[radio_key] + + if pool in _global_connection_managers: + del _global_connection_managers[pool] + def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """Get the ConnectionManager singleton for the given pool""" diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py index 7ca32ee..4bea462 100644 --- a/tests/connection_manager_close_all_test.py +++ b/tests/connection_manager_close_all_test.py @@ -85,3 +85,94 @@ def test_connection_manager_close_all_untracked(): with pytest.raises(RuntimeError) as context: adafruit_connection_manager.connection_manager_close_all(mock_pool_1) assert "SocketPool not managed" in str(context) + + +def test_connection_manager_close_all_single_release_references_false( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=False + ) + + assert socket_pool_wifi in adafruit_connection_manager._global_socketpools.values() + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ssl_context_wifi in adafruit_connection_manager._global_ssl_contexts.values() + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) + + +def test_connection_manager_close_all_single_release_references_true( # pylint: disable=unused-argument + circuitpython_socketpool_module, adafruit_esp32spi_socket_module +): + radio_wifi = mocket.MockRadio.Radio() + radio_esp = mocket.MockRadio.ESP_SPIcontrol() + + socket_pool_wifi = adafruit_connection_manager.get_radio_socketpool(radio_wifi) + socket_pool_esp = adafruit_connection_manager.get_radio_socketpool(radio_esp) + + ssl_context_wifi = adafruit_connection_manager.get_radio_ssl_context(radio_wifi) + ssl_context_esp = adafruit_connection_manager.get_radio_ssl_context(radio_esp) + + connection_manager_wifi = adafruit_connection_manager.get_connection_manager( + socket_pool_wifi + ) + connection_manager_esp = adafruit_connection_manager.get_connection_manager( + socket_pool_esp + ) + + assert socket_pool_wifi != socket_pool_esp + assert ssl_context_wifi != ssl_context_esp + assert connection_manager_wifi != connection_manager_esp + + adafruit_connection_manager.connection_manager_close_all( + socket_pool_wifi, release_references=True + ) + + assert ( + socket_pool_wifi not in adafruit_connection_manager._global_socketpools.values() + ) + assert socket_pool_esp in adafruit_connection_manager._global_socketpools.values() + + assert ( + ssl_context_wifi + not in adafruit_connection_manager._global_ssl_contexts.values() + ) + assert ssl_context_esp in adafruit_connection_manager._global_ssl_contexts.values() + + assert ( + socket_pool_wifi + not in adafruit_connection_manager._global_connection_managers.keys() + ) + assert ( + socket_pool_esp + in adafruit_connection_manager._global_connection_managers.keys() + ) diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 426f785..5c43ad1 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -19,6 +19,7 @@ def test_get_radio_socketpool_wifi( # pylint: disable=unused-argument radio = mocket.MockRadio.Radio() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert isinstance(socket_pool, mocket.MocketPool) + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument @@ -27,6 +28,7 @@ def test_get_radio_socketpool_esp32spi( # pylint: disable=unused-argument radio = mocket.MockRadio.ESP_SPIcontrol() socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_esp32spi_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument @@ -36,6 +38,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument with mock.patch("sys.implementation", return_value=[9, 0, 0]): socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool.__name__ == "adafruit_wiznet5k_socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() def test_get_radio_socketpool_unsupported(): @@ -52,22 +55,25 @@ def test_get_radio_socketpool_returns_same_one( # pylint: disable=unused-argume socket_pool_1 = adafruit_connection_manager.get_radio_socketpool(radio) socket_pool_2 = adafruit_connection_manager.get_radio_socketpool(radio) assert socket_pool_1 == socket_pool_2 + assert socket_pool_1 in adafruit_connection_manager._global_socketpools.values() def test_get_radio_ssl_context_wifi( # pylint: disable=unused-argument circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, ssl.SSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_esp32spi( # pylint: disable=unused-argument adafruit_esp32spi_socket_module, ): radio = mocket.MockRadio.ESP_SPIcontrol() - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument @@ -75,8 +81,9 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument ): radio = mocket.MockRadio.WIZNET5K() with mock.patch("sys.implementation", return_value=[9, 0, 0]): - ssl_contexts = adafruit_connection_manager.get_radio_ssl_context(radio) - assert isinstance(ssl_contexts, adafruit_connection_manager._FakeSSLContext) + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, adafruit_connection_manager._FakeSSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() def test_get_radio_ssl_context_unsupported(): @@ -90,6 +97,7 @@ def test_get_radio_ssl_context_returns_same_one( # pylint: disable=unused-argum circuitpython_socketpool_module, ): radio = mocket.MockRadio.Radio() - ssl_contexts_1 = adafruit_connection_manager.get_radio_ssl_context(radio) - ssl_contexts_2 = adafruit_connection_manager.get_radio_ssl_context(radio) - assert ssl_contexts_1 == ssl_contexts_2 + ssl_context_1 = adafruit_connection_manager.get_radio_ssl_context(radio) + ssl_context_2 = adafruit_connection_manager.get_radio_ssl_context(radio) + assert ssl_context_1 == ssl_context_2 + assert ssl_context_1 in adafruit_connection_manager._global_ssl_contexts.values() From 9aa4b455249165ef74a4a8a6313d4a0ca0c6dadf Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 15:34:51 -0700 Subject: [PATCH 08/20] Code review comments --- adafruit_connection_manager.py | 154 ++++++++++----------- examples/connectionmanager_helpers.py | 16 +-- tests/close_socket_test.py | 8 +- tests/connection_manager_close_all_test.py | 56 ++++---- tests/free_socket_test.py | 75 ++++------ 5 files changed, 135 insertions(+), 174 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 16da447..9f1d565 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -35,7 +35,7 @@ if not sys.implementation.name == "circuitpython": - from typing import Optional, Tuple + from typing import List, Optional, Tuple from circuitpython_typing.socket import ( CircuitPythonSocketType, @@ -71,8 +71,7 @@ class _FakeSSLContext: def __init__(self, iface: InterfaceType) -> None: self._iface = iface - # pylint: disable=unused-argument - def wrap_socket( + def wrap_socket( # pylint: disable=unused-argument self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None ) -> _FakeSSLSocket: """Return the same socket""" @@ -184,54 +183,75 @@ def __init__( ) -> None: self._socket_pool = socket_pool # Hang onto open sockets so that we can reuse them. - self._available_socket = {} - self._open_sockets = {} + self._available_sockets = set() + self._managed_socket_by_key = {} + self._managed_socket_by_socket = {} def _free_sockets(self, force: bool = False) -> None: - available_sockets = [] - for socket, free in self._available_socket.items(): - if free or force: - available_sockets.append(socket) - + # cloning lists since items are being removed + available_sockets = list(self._available_sockets) for socket in available_sockets: self.close_socket(socket) + if force: + open_sockets = list(self._managed_socket_by_key.values()) + for socket in open_sockets: + self.close_socket(socket) - def _get_key_for_socket(self, socket): + def _get_connected_socket( # pylint: disable=too-many-arguments + self, + addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], + host: str, + port: int, + timeout: float, + is_ssl: bool, + ssl_context: Optional[SSLContextType] = None, + ): try: - return next( - key for key, value in self._open_sockets.items() if value == socket - ) - except StopIteration: - return None + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) + except (OSError, RuntimeError) as exc: + return exc - @property - def open_sockets(self) -> int: - """Get the count of open sockets""" - return len(self._open_sockets) + if is_ssl: + socket = ssl_context.wrap_socket(socket, server_hostname=host) + connect_host = host + else: + connect_host = addr_info[-1][0] + socket.settimeout(timeout) # socket read timeout + + try: + socket.connect((connect_host, port)) + except (MemoryError, OSError) as exc: + socket.close() + return exc + + return socket @property - def freeable_open_sockets(self) -> int: + def available_socket_count(self) -> int: """Get the count of freeable open sockets""" - return len( - [socket for socket, free in self._available_socket.items() if free is True] - ) + return len(self._available_sockets) + + @property + def managed_socket_count(self) -> int: + """Get the count of open sockets""" + return len(self._managed_socket_by_key) def close_socket(self, socket: SocketType) -> None: """Close a previously opened socket.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - key = self._get_key_for_socket(socket) socket.close() - del self._available_socket[socket] - del self._open_sockets[key] + key = self._managed_socket_by_socket.pop(socket) + del self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: """Mark a previously opened socket as available so it can be reused if needed.""" - if socket not in self._open_sockets.values(): + if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") - self._available_socket[socket] = True + self._available_sockets.add(socket) - # pylint: disable=too-many-branches,too-many-locals,too-many-statements def get_socket( self, host: str, @@ -247,10 +267,10 @@ def get_socket( if session_id: session_id = str(session_id) key = (host, port, proto, session_id) - if key in self._open_sockets: - socket = self._open_sockets[key] - if self._available_socket[socket]: - self._available_socket[socket] = False + if key in self._managed_socket_by_key: + socket = self._managed_socket_by_key[key] + if socket in self._available_sockets: + self._available_sockets.remove(socket) return socket raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") @@ -266,54 +286,22 @@ def get_socket( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - try_count = 0 - socket = None - last_exc = None - while try_count < 2 and socket is None: - try_count += 1 - if try_count > 1: - if any( - socket - for socket, free in self._available_socket.items() - if free is True - ): - self._free_sockets() - else: - break - - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except OSError as exc: - last_exc = exc - continue - except RuntimeError as exc: - last_exc = exc - continue - - if is_ssl: - socket = ssl_context.wrap_socket(socket, server_hostname=host) - connect_host = host - else: - connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout - - try: - socket.connect((connect_host, port)) - except MemoryError as exc: - last_exc = exc - socket.close() - socket = None - except OSError as exc: - last_exc = exc - socket.close() - socket = None - - if socket is None: - raise RuntimeError(f"Error connecting socket: {last_exc}") from last_exc - - self._available_socket[socket] = False - self._open_sockets[key] = socket - return socket + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + # Got an error, if there are any available sockets, free them and try again + if self.available_socket_count: + self._free_sockets() + result = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + if isinstance(result, Exception): + raise RuntimeError(f"Error connecting socket: {result}") from result + + self._managed_socket_by_key[key] = result + self._managed_socket_by_socket[result] = key + return result # global helpers diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index df383bd..e9fb842 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -27,8 +27,8 @@ connection_manager = adafruit_connection_manager.get_connection_manager(pool) print("-" * 40) print("Nothing yet opened") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") # make request print("-" * 40) @@ -39,8 +39,8 @@ print("-" * 40) print("1 request, opened and freed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") print("-" * 40) print(f"Fetching from {TEXT_URL} not in a context handler") @@ -48,8 +48,8 @@ print("-" * 40) print("1 request, opened but not freed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") print("-" * 40) print("Closing everything in the pool") @@ -57,5 +57,5 @@ print("-" * 40) print("Everything closed") -print(f"Open Sockets: {connection_manager.open_sockets}") -print(f"Freeable Open Sockets: {connection_manager.freeable_open_sockets}") +print(f"Open Sockets: {connection_manager.managed_socket_count}") +print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") diff --git a/tests/close_socket_test.py b/tests/close_socket_test.py index 957cb94..3927181 100644 --- a/tests/close_socket_test.py +++ b/tests/close_socket_test.py @@ -21,13 +21,13 @@ def test_close_socket(): socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert key in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key # validate socket is no longer tracked connection_manager.close_socket(socket) - assert socket not in connection_manager._available_socket - assert key not in connection_manager._open_sockets + assert socket not in connection_manager._available_sockets + assert key not in connection_manager._managed_socket_by_key def test_close_socket_not_managed(): diff --git a/tests/connection_manager_close_all_test.py b/tests/connection_manager_close_all_test.py index 4bea462..c0fa498 100644 --- a/tests/connection_manager_close_all_test.py +++ b/tests/connection_manager_close_all_test.py @@ -18,29 +18,29 @@ def test_connection_manager_close_all_all(): connection_manager_1 = adafruit_connection_manager.get_connection_manager( mock_pool_1 ) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 connection_manager_2 = adafruit_connection_manager.get_connection_manager( mock_pool_2 ) - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 assert len(adafruit_connection_manager._global_connection_managers) == 2 socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_1.open_sockets == 1 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 adafruit_connection_manager.connection_manager_close_all() - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_1.close.assert_called_once() socket_2.close.assert_called_once() @@ -53,29 +53,29 @@ def test_connection_manager_close_all_single(): connection_manager_1 = adafruit_connection_manager.get_connection_manager( mock_pool_1 ) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 connection_manager_2 = adafruit_connection_manager.get_connection_manager( mock_pool_2 ) - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 assert len(adafruit_connection_manager._global_connection_managers) == 2 socket_1 = connection_manager_1.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_1.open_sockets == 1 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 0 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 1 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 0 + assert connection_manager_2.available_socket_count == 0 socket_2 = connection_manager_2.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 adafruit_connection_manager.connection_manager_close_all(mock_pool_1) - assert connection_manager_1.open_sockets == 0 - assert connection_manager_1.freeable_open_sockets == 0 - assert connection_manager_2.open_sockets == 1 - assert connection_manager_2.freeable_open_sockets == 0 + assert connection_manager_1.managed_socket_count == 0 + assert connection_manager_1.available_socket_count == 0 + assert connection_manager_2.managed_socket_count == 1 + assert connection_manager_2.available_socket_count == 0 socket_1.close.assert_called_once() socket_2.close.assert_not_called() diff --git a/tests/free_socket_test.py b/tests/free_socket_test.py index b39d6d5..666a072 100644 --- a/tests/free_socket_test.py +++ b/tests/free_socket_test.py @@ -16,26 +16,24 @@ def test_free_socket(): mock_pool.socket.return_value = mock_socket_1 connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - assert connection_manager.open_sockets == 0 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") key = (mocket.MOCK_HOST_1, 80, "http:", None) assert socket == mock_socket_1 - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is False - assert key in connection_manager._open_sockets - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 + assert socket not in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket) - assert socket in connection_manager._available_socket - assert connection_manager._available_socket[socket] is True - assert key in connection_manager._open_sockets - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 1 + assert socket in connection_manager._available_sockets + assert key in connection_manager._managed_socket_by_key + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 1 def test_free_socket_not_managed(): @@ -60,56 +58,31 @@ def test_free_sockets(): ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - assert connection_manager.open_sockets == 0 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 0 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and not available socket_1 = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") assert socket_1 == mock_socket_1 - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is False - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 + assert socket_1 not in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 socket_2 = connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") assert socket_2 == mock_socket_2 - assert connection_manager.open_sockets == 2 - assert connection_manager.freeable_open_sockets == 0 + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 0 # validate socket is tracked and is available connection_manager.free_socket(socket_1) - assert socket_1 in connection_manager._available_socket - assert connection_manager._available_socket[socket_1] is True - assert connection_manager.open_sockets == 2 - assert connection_manager.freeable_open_sockets == 1 + assert socket_1 in connection_manager._available_sockets + assert connection_manager.managed_socket_count == 2 + assert connection_manager.available_socket_count == 1 # validate socket is no longer tracked connection_manager._free_sockets() - assert socket_1 not in connection_manager._available_socket - assert socket_2 in connection_manager._available_socket + assert socket_1 not in connection_manager._available_sockets + assert socket_2 not in connection_manager._available_sockets mock_socket_1.close.assert_called_once() - assert connection_manager.open_sockets == 1 - assert connection_manager.freeable_open_sockets == 0 - - -def test_get_key_for_socket(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - mock_pool.socket.return_value = mock_socket_1 - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate tracked socket has correct key - socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - key = (mocket.MOCK_HOST_1, 80, "http:", None) - assert connection_manager._get_key_for_socket(socket) == key - - -def test_get_key_for_socket_not_managed(): - mock_pool = mocket.MocketPool() - mock_socket_1 = mocket.Mocket() - - connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - - # validate untracked socket has no key - assert connection_manager._get_key_for_socket(mock_socket_1) is None + assert connection_manager.managed_socket_count == 1 + assert connection_manager.available_socket_count == 0 From f33d78cf757096bf760a3a0602893db078aa1811 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 15:57:46 -0700 Subject: [PATCH 09/20] rename var --- adafruit_connection_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 9f1d565..bb03328 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -184,8 +184,8 @@ def __init__( self._socket_pool = socket_pool # Hang onto open sockets so that we can reuse them. self._available_sockets = set() + self._key_by_managed_socket = {} self._managed_socket_by_key = {} - self._managed_socket_by_socket = {} def _free_sockets(self, force: bool = False) -> None: # cloning lists since items are being removed @@ -241,7 +241,7 @@ def close_socket(self, socket: SocketType) -> None: if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") socket.close() - key = self._managed_socket_by_socket.pop(socket) + key = self._key_by_managed_socket.pop(socket) del self._managed_socket_by_key[key] if socket in self._available_sockets: self._available_sockets.remove(socket) @@ -299,8 +299,8 @@ def get_socket( if isinstance(result, Exception): raise RuntimeError(f"Error connecting socket: {result}") from result + self._key_by_managed_socket[result] = key self._managed_socket_by_key[key] = result - self._managed_socket_by_socket[result] = key return result From 601ee66791d53d72572b581f28d8d92a2c0401ec Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sat, 27 Apr 2024 19:02:32 -0700 Subject: [PATCH 10/20] Little error cleanup --- adafruit_connection_manager.py | 9 +++++++-- tests/get_socket_test.py | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index bb03328..9dc8f51 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -64,7 +64,7 @@ def connect(self, address: Tuple[str, int]) -> None: try: return self._socket.connect(address, self._mode) except RuntimeError as error: - raise OSError(errno.ENOMEM) from error + raise OSError(errno.ENOMEM, str(error)) from error class _FakeSSLContext: @@ -286,18 +286,23 @@ def get_socket( host, port, 0, self._socket_pool.SOCK_STREAM )[0] + first_exception = None result = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) if isinstance(result, Exception): # Got an error, if there are any available sockets, free them and try again if self.available_socket_count: + first_exception = result self._free_sockets() result = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) if isinstance(result, Exception): - raise RuntimeError(f"Error connecting socket: {result}") from result + last_result = f", first error: {first_exception}" if first_exception else "" + raise RuntimeError( + f"Error connecting socket: {result}{last_result}" + ) from result self._key_by_managed_socket[result] = key self._managed_socket_by_key[key] = result diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index ea252cc..6be48f0 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -213,7 +213,7 @@ def test_get_socket_runtime_error_ties_again_only_once(): # try to get a socket that returns a RuntimeError twice with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2" in str(context) + assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -242,7 +242,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.return_value = mock_socket_1 - mock_socket_1.connect.side_effect = RuntimeError("RuntimeError ") + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError") radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) @@ -252,4 +252,4 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: 12" in str(context) + assert "Error connecting socket: [Errno 12] RuntimeError" in str(context) From 72fdd31c8c5cb3e8214cffff87bd096174c84093 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Sun, 28 Apr 2024 09:18:30 -0700 Subject: [PATCH 11/20] Update docs --- adafruit_connection_manager.py | 72 ++++++++++++++++++--------- examples/connectionmanager_helpers.py | 20 ++++---- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 9dc8f51..c8c209e 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -99,12 +99,13 @@ def create_fake_ssl_context( _global_connection_managers = {} +_global_key_by_socketpool = {} _global_socketpools = {} _global_ssl_contexts = {} def get_radio_socketpool(radio): - """Helper to get a socket pool for common boards + """Helper to get a socket pool for common boards. Currently supported: @@ -151,6 +152,7 @@ def get_radio_socketpool(radio): else: raise AttributeError(f"Unsupported radio class: {class_name}") + _global_key_by_socketpool[pool] = class_name _global_socketpools[class_name] = pool _global_ssl_contexts[class_name] = ssl_context @@ -158,7 +160,7 @@ def get_radio_socketpool(radio): def get_radio_ssl_context(radio): - """Helper to get ssl_contexts for common boards + """Helper to get ssl_contexts for common boards. Currently supported: @@ -175,7 +177,7 @@ def get_radio_ssl_context(radio): class ConnectionManager: - """Connection manager for sharing open sockets (aka connections).""" + """A library for managing sockets accross libraries.""" def __init__( self, @@ -228,16 +230,20 @@ def _get_connected_socket( # pylint: disable=too-many-arguments @property def available_socket_count(self) -> int: - """Get the count of freeable open sockets""" + """Get the count of available (freed) managed sockets.""" return len(self._available_sockets) @property def managed_socket_count(self) -> int: - """Get the count of open sockets""" + """Get the count of managed sockets.""" return len(self._managed_socket_by_key) def close_socket(self, socket: SocketType) -> None: - """Close a previously opened socket.""" + """ + Close a previously managed and connected socket. + + - **socket_pool** *(SocketType)* – The socket you want to close + """ if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") socket.close() @@ -247,7 +253,7 @@ def close_socket(self, socket: SocketType) -> None: self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: - """Mark a previously opened socket as available so it can be reused if needed.""" + """Mark a managed socket as available so it can be reused.""" if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") self._available_sockets.add(socket) @@ -263,7 +269,20 @@ def get_socket( is_ssl: bool = False, ssl_context: Optional[SSLContextType] = None, ) -> CircuitPythonSocketType: - """Get a new socket and connect""" + """ + Get a new socket and connect. + + - **host** *(str)* – The host you are want to connect to: "www.adaftuit.com" + - **port** *(int)* – The port you want to connect to: 80 + - **proto** *(str)* – The protocal you want to use: "http:" + - **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open + connections to the same host + - **timeout** *(float)* – Time timeout used for connecting + - **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is + "https:") + - **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL + requests + """ if session_id: session_id = str(session_id) key = (host, port, proto, session_id) @@ -315,7 +334,14 @@ def get_socket( def connection_manager_close_all( socket_pool: Optional[SocketpoolModuleType] = None, release_references: bool = False ) -> None: - """Close all open sockets for pool""" + """ + Close all open sockets for pool, optionally release references. + + - **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close + sockets for, leave blank for all SocketPools + - **release_references** *(bool)* – Set to True if you want to also clear stored references to + the SocketPool and SSL contexts + """ if socket_pool: socket_pools = [socket_pool] else: @@ -328,26 +354,24 @@ def connection_manager_close_all( connection_manager._free_sockets(force=True) # pylint: disable=protected-access - if release_references: - radio_key = None - for radio_check, pool_check in _global_socketpools.items(): - if pool == pool_check: - radio_key = radio_check - break + if not release_references: + continue - if radio_key: - if radio_key in _global_socketpools: - del _global_socketpools[radio_key] + key = _global_key_by_socketpool.pop(pool) + if key: + _global_socketpools.pop(key, None) + _global_ssl_contexts.pop(key, None) - if radio_key in _global_ssl_contexts: - del _global_ssl_contexts[radio_key] - - if pool in _global_connection_managers: - del _global_connection_managers[pool] + _global_connection_managers.pop(pool, None) def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: - """Get the ConnectionManager singleton for the given pool""" + """ + Get the ConnectionManager singleton for the given pool. + + - **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the + ConnectionManager for + """ if socket_pool not in _global_connection_managers: _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) return _global_connection_managers[socket_pool] diff --git a/examples/connectionmanager_helpers.py b/examples/connectionmanager_helpers.py index e9fb842..d4bb916 100644 --- a/examples/connectionmanager_helpers.py +++ b/examples/connectionmanager_helpers.py @@ -27,8 +27,8 @@ connection_manager = adafruit_connection_manager.get_connection_manager(pool) print("-" * 40) print("Nothing yet opened") -print(f"Open Sockets: {connection_manager.managed_socket_count}") -print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") +print(f"Managed Sockets: {connection_manager.managed_socket_count}") +print(f"Available Managed Sockets: {connection_manager.available_socket_count}") # make request print("-" * 40) @@ -38,18 +38,18 @@ print(f"Text Response {response_text}") print("-" * 40) -print("1 request, opened and freed") -print(f"Open Sockets: {connection_manager.managed_socket_count}") -print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") +print("1 request, opened and closed") +print(f"Managed Sockets: {connection_manager.managed_socket_count}") +print(f"Available Managed Sockets: {connection_manager.available_socket_count}") print("-" * 40) print(f"Fetching from {TEXT_URL} not in a context handler") response = requests.get(TEXT_URL) print("-" * 40) -print("1 request, opened but not freed") -print(f"Open Sockets: {connection_manager.managed_socket_count}") -print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") +print("1 request, opened but not closed") +print(f"Managed Sockets: {connection_manager.managed_socket_count}") +print(f"Available Managed Sockets: {connection_manager.available_socket_count}") print("-" * 40) print("Closing everything in the pool") @@ -57,5 +57,5 @@ print("-" * 40) print("Everything closed") -print(f"Open Sockets: {connection_manager.managed_socket_count}") -print(f"Freeable Open Sockets: {connection_manager.available_socket_count}") +print(f"Managed Sockets: {connection_manager.managed_socket_count}") +print(f"Available Managed Sockets: {connection_manager.available_socket_count}") From b14ed9920a651abb02e4ea38bda75556b1b8b8b9 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Tue, 30 Apr 2024 06:19:49 -0700 Subject: [PATCH 12/20] Code review updates --- adafruit_connection_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 4b5eb1d..5b8a10c 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -58,10 +58,10 @@ def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self.recv = socket.recv self.close = socket.close self.recv_into = socket.recv_into - if hasattr(socket, "_interface"): - self._interface = socket._interface - if hasattr(socket, "_socket_pool"): - self._socket_pool = socket._socket_pool + # For sockets that come from software socketpools (like the esp32api), they track + # the interface and socket pool. We need to make sure the clones do as well + self._interface = getattr(socket, "_interface", None) + self._socket_pool = getattr(socket, "_socket_pool", None) def connect(self, address: Tuple[str, int]) -> None: """Connect wrapper to add non-standard mode parameter""" From 99f8972dcec55a8a4ecb2db70aa58e5250b92910 Mon Sep 17 00:00:00 2001 From: Dan Halbert Date: Fri, 10 May 2024 16:47:37 -0400 Subject: [PATCH 13/20] Recover in more cases when a socket cannot be created. Also: - Clarify some documentation. Use sphinx argument documentation style. - Fix some typos. - Remove a few internal comments marking code sections. - Clarify an error message. - Internally, catch exceptions instead of passing them back. - Change one exception. - Update to pylint 3.1.0 so pre-commit can run under Python 3.12 --- .pre-commit-config.yaml | 2 +- adafruit_connection_manager.py | 115 +++++++++++++++------------------ 2 files changed, 52 insertions(+), 65 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77ed663..4d2e392 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/pylint - rev: v2.17.4 + rev: v3.1.0 hooks: - id: pylint name: pylint (library code) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 5b8a10c..27f8a7b 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -21,8 +21,6 @@ """ -# imports - __version__ = "0.0.0+auto.0" __repo__ = "https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager.git" @@ -31,9 +29,6 @@ WIZNET5K_SSL_SUPPORT_VERSION = (9, 1) -# typing - - if not sys.implementation.name == "circuitpython": from typing import List, Optional, Tuple @@ -46,9 +41,6 @@ ) -# ssl and pool helpers - - class _FakeSSLSocket: def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: self._socket = socket @@ -189,11 +181,8 @@ def get_radio_ssl_context(radio): return _global_ssl_contexts[_get_radio_hash_key(radio)] -# main class - - class ConnectionManager: - """A library for managing sockets accross libraries.""" + """A library for managing sockets across multiple hardware platforms and libraries.""" def __init__( self, @@ -224,23 +213,24 @@ def _get_connected_socket( # pylint: disable=too-many-arguments is_ssl: bool, ssl_context: Optional[SSLContextType] = None, ): - try: - socket = self._socket_pool.socket(addr_info[0], addr_info[1]) - except (OSError, RuntimeError) as exc: - return exc + + socket = self._socket_pool.socket(addr_info[0], addr_info[1]) if is_ssl: socket = ssl_context.wrap_socket(socket, server_hostname=host) connect_host = host else: connect_host = addr_info[-1][0] - socket.settimeout(timeout) # socket read timeout + + # Set socket read and connect timeout. + socket.settimeout(timeout) try: socket.connect((connect_host, port)) - except (MemoryError, OSError) as exc: + except (MemoryError, OSError): + # If any connect problems, clean up and re-raise the problem exception. socket.close() - return exc + raise return socket @@ -269,11 +259,16 @@ def close_socket(self, socket: SocketType) -> None: self._available_sockets.remove(socket) def free_socket(self, socket: SocketType) -> None: - """Mark a managed socket as available so it can be reused.""" + """Mark a managed socket as available so it can be reused. The socket is not closed.""" if socket not in self._managed_socket_by_key.values(): raise RuntimeError("Socket not managed") self._available_sockets.add(socket) + def _register_connected_socket(self, key, socket): + self._key_by_managed_socket[socket] = key + self._managed_socket_by_key[key] = socket + + # pylint: disable=too-many-arguments def get_socket( self, host: str, @@ -281,70 +276,65 @@ def get_socket( proto: str, session_id: Optional[str] = None, *, - timeout: float = 1, + timeout: float = 1.0, is_ssl: bool = False, ssl_context: Optional[SSLContextType] = None, ) -> CircuitPythonSocketType: """ - Get a new socket and connect. - - - **host** *(str)* – The host you are want to connect to: "www.adaftuit.com" - - **port** *(int)* – The port you want to connect to: 80 - - **proto** *(str)* – The protocal you want to use: "http:" - - **session_id** *(Optional[str])* – A unique Session ID, when wanting to have multiple open - connections to the same host - - **timeout** *(float)* – Time timeout used for connecting - - **is_ssl** *(bool)* – If the connection is to be over SSL (auto set when proto is - "https:") - - **ssl_context** *(Optional[SSLContextType])* – The SSL context to use when making SSL - requests + Get a new socket and connect to the given host. + + :param str host: host to connect to, such as ``"www.example.org"`` + :param int port: port to use for connection, such as ``80`` or ``443`` + :param str proto: connection protocol: ``"http:"``, ``"https:"``, etc. + :param Optional[str]: unique session ID, + used for multiple simultaneous connections to the same host + :param float timeout: how long to wait to connect + :param bool is_ssl: ``True`` If the connection is to be over SSL; + automatically set when ``proto`` is ``"https:"` + :param Optional[SSLContextType]: SSL context to use when making SSL requests """ if session_id: session_id = str(session_id) key = (host, port, proto, session_id) + + # Do we have already have a socket available for the requested connection? if key in self._managed_socket_by_key: socket = self._managed_socket_by_key[key] if socket in self._available_sockets: self._available_sockets.remove(socket) return socket - raise RuntimeError(f"Socket already connected to {proto}//{host}:{port}") + raise RuntimeError( + f"An existing socket is already connected to {proto}//{host}:{port}" + ) if proto == "https:": is_ssl = True if is_ssl and not ssl_context: - raise AttributeError( - "ssl_context must be set before using adafruit_requests for https" - ) + raise ValueError("ssl_context must be provided if using ssl") addr_info = self._socket_pool.getaddrinfo( host, port, 0, self._socket_pool.SOCK_STREAM )[0] - first_exception = None - result = self._get_connected_socket( - addr_info, host, port, timeout, is_ssl, ssl_context - ) - if isinstance(result, Exception): - # Got an error, if there are any available sockets, free them and try again + try: + socket = self._get_connected_socket( + addr_info, host, port, timeout, is_ssl, ssl_context + ) + self._register_connected_socket(key, socket) + return socket + except (MemoryError, OSError, RuntimeError): + # Could not get a new socket (or two, if SSL). + # If there are any available sockets, free them all and try again. if self.available_socket_count: - first_exception = result self._free_sockets() - result = self._get_connected_socket( + socket = self._get_connected_socket( addr_info, host, port, timeout, is_ssl, ssl_context ) - if isinstance(result, Exception): - last_result = f", first error: {first_exception}" if first_exception else "" - raise RuntimeError( - f"Error connecting socket: {result}{last_result}" - ) from result - - self._key_by_managed_socket[result] = key - self._managed_socket_by_key[key] = result - return result - - -# global helpers + self._register_connected_socket(key, socket) + return socket + # Re-raise exception if no sockets could be freed. + raise def connection_manager_close_all( @@ -353,9 +343,9 @@ def connection_manager_close_all( """ Close all open sockets for pool, optionally release references. - - **socket_pool** *(Optional[SocketpoolModuleType])* – A specifc SocketPool you want to close - sockets for, leave blank for all SocketPools - - **release_references** *(bool)* – Set to True if you want to also clear stored references to + :param Optional[SocketpoolModuleType] socket_pool: + a specific `SocketPool` whose sockets you want to close; `None`` means all `SocketPool`s + :param bool release_references: ``True`` if you want to also clear stored references to the SocketPool and SSL contexts """ if socket_pool: @@ -383,10 +373,7 @@ def connection_manager_close_all( def get_connection_manager(socket_pool: SocketpoolModuleType) -> ConnectionManager: """ - Get the ConnectionManager singleton for the given pool. - - - **socket_pool** *(Optional[SocketpoolModuleType])* – The SocketPool you want the - ConnectionManager for + Get or create the ConnectionManager singleton for the given pool. """ if socket_pool not in _global_connection_managers: _global_connection_managers[socket_pool] = ConnectionManager(socket_pool) From 917edc8bdd4962e9053a55c902d0d7c9ab545c50 Mon Sep 17 00:00:00 2001 From: Dan Halbert Date: Fri, 10 May 2024 17:09:43 -0400 Subject: [PATCH 14/20] use ValueError instead of AttributeError for incorrect args --- adafruit_connection_manager.py | 4 ++-- tests/get_radio_test.py | 4 ++-- tests/protocol_test.py | 4 ++-- tests/ssl_context_test.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 27f8a7b..b195d19 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -74,7 +74,7 @@ def wrap_socket( # pylint: disable=unused-argument if hasattr(self._iface, "TLS_MODE"): return _FakeSSLSocket(socket, self._iface.TLS_MODE) - raise AttributeError("This radio does not support TLS/HTTPS") + raise ValueError("This radio does not support TLS/HTTPS") def create_fake_ssl_context( @@ -159,7 +159,7 @@ def get_radio_socketpool(radio): ssl_context = create_fake_ssl_context(pool, radio) else: - raise AttributeError(f"Unsupported radio class: {class_name}") + raise ValueError(f"Unsupported radio class: {class_name}") _global_key_by_socketpool[pool] = key _global_socketpools[key] = pool diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 9844e9e..5631bdb 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -55,7 +55,7 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument def test_get_radio_socketpool_unsupported(): radio = mocket.MockRadio.Unsupported() - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: adafruit_connection_manager.get_radio_socketpool(radio) assert "Unsupported radio class" in str(context) @@ -100,7 +100,7 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument def test_get_radio_ssl_context_unsupported(): radio = mocket.MockRadio.Unsupported() - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: adafruit_connection_manager.get_radio_ssl_context(radio) assert "Unsupported radio class" in str(context) diff --git a/tests/protocol_test.py b/tests/protocol_test.py index 98b5296..50a071c 100644 --- a/tests/protocol_test.py +++ b/tests/protocol_test.py @@ -18,9 +18,9 @@ def test_get_https_no_ssl(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # verify not sending in a SSL context for a HTTPS call errors - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 443, "https:") - assert "ssl_context must be set" in str(context) + assert "ssl_context must be provided if using ssl" in str(context) def test_connect_https(): diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index 2f2e370..02bf96e 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -58,7 +58,7 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # verify a HTTPS call for a board without built in WiFi and SSL support errors - with pytest.raises(AttributeError) as context: + with pytest.raises(ValueError) as context: connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) From da3cd5fa8a9097be5b8c0f7cb293b6c58b7b082f Mon Sep 17 00:00:00 2001 From: Dan Halbert Date: Fri, 10 May 2024 21:41:40 -0400 Subject: [PATCH 15/20] redo test outputs --- adafruit_connection_manager.py | 17 +++++++++-------- tests/get_socket_test.py | 20 +++++++------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index b195d19..658f338 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -204,6 +204,11 @@ def _free_sockets(self, force: bool = False) -> None: for socket in open_sockets: self.close_socket(socket) + def _register_connected_socket(self, key, socket): + """Register a socket as managed.""" + self._key_by_managed_socket[socket] = key + self._managed_socket_by_key[key] = socket + def _get_connected_socket( # pylint: disable=too-many-arguments self, addr_info: List[Tuple[int, int, int, str, Tuple[str, int]]], @@ -264,10 +269,6 @@ def free_socket(self, socket: SocketType) -> None: raise RuntimeError("Socket not managed") self._available_sockets.add(socket) - def _register_connected_socket(self, key, socket): - self._key_by_managed_socket[socket] = key - self._managed_socket_by_key[key] = socket - # pylint: disable=too-many-arguments def get_socket( self, @@ -290,7 +291,7 @@ def get_socket( used for multiple simultaneous connections to the same host :param float timeout: how long to wait to connect :param bool is_ssl: ``True`` If the connection is to be over SSL; - automatically set when ``proto`` is ``"https:"` + automatically set when ``proto`` is ``"https:"`` :param Optional[SSLContextType]: SSL context to use when making SSL requests """ if session_id: @@ -344,9 +345,9 @@ def connection_manager_close_all( Close all open sockets for pool, optionally release references. :param Optional[SocketpoolModuleType] socket_pool: - a specific `SocketPool` whose sockets you want to close; `None`` means all `SocketPool`s - :param bool release_references: ``True`` if you want to also clear stored references to - the SocketPool and SSL contexts + a specific socket pool whose sockets you want to close; ``None`` means all socket pools + :param bool release_references: ``True`` if you also want the `ConnectionManager` to forget + all the socket pools and SSL contexts it knows about """ if socket_pool: socket_pools = [socket_pool] diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 9abbf98..46d053b 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -91,7 +91,7 @@ def test_get_socket_not_flagged_free(): # get a socket for the same host, should be a different one with pytest.raises(RuntimeError) as context: socket = connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Socket already connected" in str(context) + assert "An existing socket is already connected" in str(context) def test_get_socket_os_error(): @@ -105,9 +105,8 @@ def test_get_socket_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error(): @@ -121,9 +120,8 @@ def test_get_socket_runtime_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a RuntimeError - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: RuntimeError" in str(context) def test_get_socket_connect_memory_error(): @@ -139,9 +137,8 @@ def test_get_socket_connect_memory_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a MemoryError - with pytest.raises(RuntimeError) as context: + with pytest.raises(MemoryError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: MemoryError" in str(context) def test_get_socket_connect_os_error(): @@ -157,9 +154,8 @@ def test_get_socket_connect_os_error(): connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a OSError - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") - assert "Error connecting socket: OSError" in str(context) def test_get_socket_runtime_error_ties_again_at_least_one_free(): @@ -211,9 +207,8 @@ def test_get_socket_runtime_error_ties_again_only_once(): free_sockets_mock.assert_not_called() # try to get a socket that returns a RuntimeError twice - with pytest.raises(RuntimeError) as context: + with pytest.raises(RuntimeError): connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") - assert "Error connecting socket: error 2, first error: error 1" in str(context) free_sockets_mock.assert_called_once() @@ -248,8 +243,7 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - with pytest.raises(RuntimeError) as context: + with pytest.raises(OSError): connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) - assert "Error connecting socket: [Errno 12] RuntimeError" in str(context) From 0792b271bdb4afdefbad002638e60417abb15e3a Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Mon, 13 May 2024 06:09:34 -0700 Subject: [PATCH 16/20] Update tests --- tests/get_socket_test.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 46d053b..f61cb26 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -98,30 +98,32 @@ def test_get_socket_os_error(): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.side_effect = [ - OSError("OSError"), + OSError("OSError 1"), mock_socket_1, ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a OSError - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "OSError 1" in str(context) def test_get_socket_runtime_error(): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.side_effect = [ - RuntimeError("RuntimeError"), + RuntimeError("RuntimeError 1"), mock_socket_1, ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a RuntimeError - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "RuntimeError 1" in str(context) def test_get_socket_connect_memory_error(): @@ -132,13 +134,14 @@ def test_get_socket_connect_memory_error(): mock_socket_1, mock_socket_2, ] - mock_socket_1.connect.side_effect = MemoryError("MemoryError") + mock_socket_1.connect.side_effect = MemoryError("MemoryError 1") connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a MemoryError - with pytest.raises(MemoryError): + with pytest.raises(MemoryError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "MemoryError 1" in str(context) def test_get_socket_connect_os_error(): @@ -149,13 +152,14 @@ def test_get_socket_connect_os_error(): mock_socket_1, mock_socket_2, ] - mock_socket_1.connect.side_effect = OSError("OSError") + mock_socket_1.connect.side_effect = OSError("OSError 1") connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a OSError - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "OSError 1" in str(context) def test_get_socket_runtime_error_ties_again_at_least_one_free(): @@ -190,9 +194,9 @@ def test_get_socket_runtime_error_ties_again_only_once(): mock_socket_2 = mocket.Mocket() mock_pool.socket.side_effect = [ mock_socket_1, - RuntimeError("error 1"), - RuntimeError("error 2"), - RuntimeError("error 3"), + RuntimeError("RuntimeError 1"), + RuntimeError("RuntimeError 2"), + RuntimeError("RuntimeError 3"), mock_socket_2, ] @@ -207,8 +211,9 @@ def test_get_socket_runtime_error_ties_again_only_once(): free_sockets_mock.assert_not_called() # try to get a socket that returns a RuntimeError twice - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") + assert "RuntimeError 2" in str(context) free_sockets_mock.assert_called_once() @@ -237,13 +242,15 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.return_value = mock_socket_1 - mock_socket_1.connect.side_effect = RuntimeError("RuntimeError") + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError 1") radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) + assert "12" in str(context) + assert "RuntimeError 1" in str(context) From 0cccbe4022d2b168ecd142c334a20911c694dc32 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Mon, 13 May 2024 14:59:24 -0700 Subject: [PATCH 17/20] Support CPython --- adafruit_connection_manager.py | 10 ++++++++++ tests/get_radio_test.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 658f338..58dcb0c 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -97,6 +97,10 @@ def create_fake_ssl_context( return _FakeSSLContext(iface) +class CPythonNetwork: # pylint: disable=too-few-public-methods + """Radio object to use when using ConnectionManager in CPython.""" + + _global_connection_managers = {} _global_key_by_socketpool = {} _global_socketpools = {} @@ -158,6 +162,12 @@ def get_radio_socketpool(radio): if ssl_context is None: ssl_context = create_fake_ssl_context(pool, radio) + elif class_name == "CPythonNetwork": + import socket as pool # pylint: disable=import-outside-toplevel + import ssl # pylint: disable=import-outside-toplevel + + ssl_context = ssl.create_default_context() + else: raise ValueError(f"Unsupported radio class: {class_name}") diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 5631bdb..022aecd 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -53,6 +53,13 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument assert socket_pool in adafruit_connection_manager._global_socketpools.values() +def test_get_radio_socketpool_cpython(): + radio = adafruit_connection_manager.CPythonNetwork() + socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) + assert socket_pool.__name__ == "socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() + + def test_get_radio_socketpool_unsupported(): radio = mocket.MockRadio.Unsupported() with pytest.raises(ValueError) as context: @@ -98,6 +105,13 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() +def test_get_radio_ssl_context_cpython(): + radio = adafruit_connection_manager.CPythonNetwork() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() + + def test_get_radio_ssl_context_unsupported(): radio = mocket.MockRadio.Unsupported() with pytest.raises(ValueError) as context: From 1247dd4b7635a251795d0ae8b52de31165dc1dc5 Mon Sep 17 00:00:00 2001 From: Justin Myers Date: Mon, 10 Jun 2024 16:31:04 -0700 Subject: [PATCH 18/20] Update WIZNet version check for SSL --- adafruit_connection_manager.py | 9 +++++++-- tests/ssl_context_test.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 58dcb0c..67152d7 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -149,8 +149,13 @@ def get_radio_socketpool(radio): # versions of the Wiznet5k library or on boards withouut the ssl module # see https://docs.circuitpython.org/en/latest/shared-bindings/support_matrix.html ssl_context = None - cp_version = sys.implementation[1] - if pool.SOCK_STREAM == 1 and cp_version >= WIZNET5K_SSL_SUPPORT_VERSION: + implementation_name = sys.implementation.name + implementation_version = sys.implementation.version + if ( + pool.SOCK_STREAM == 1 + and implementation_name == "circuitpython" + and implementation_version >= WIZNET5K_SSL_SUPPORT_VERSION + ): try: import ssl # pylint: disable=import-outside-toplevel diff --git a/tests/ssl_context_test.py b/tests/ssl_context_test.py index 02bf96e..25e389e 100644 --- a/tests/ssl_context_test.py +++ b/tests/ssl_context_test.py @@ -5,6 +5,7 @@ """ SLL Context Tests """ import ssl +from collections import namedtuple from unittest import mock import mocket @@ -13,6 +14,8 @@ import adafruit_connection_manager from adafruit_connection_manager import WIZNET5K_SSL_SUPPORT_VERSION +SimpleNamespace = namedtuple("SimpleNamespace", "name version") + def test_connect_esp32spi_https( # pylint: disable=unused-argument adafruit_esp32spi_socketpool_module, @@ -53,7 +56,9 @@ def test_connect_wiznet5k_https_not_supported( # pylint: disable=unused-argumen mock_pool = mocket.MocketPool() radio = mocket.MockRadio.WIZNET5K() old_version = (WIZNET5K_SSL_SUPPORT_VERSION[0] - 1, 0, 0) - with mock.patch("sys.implementation", (None, old_version)): + with mock.patch( + "sys.implementation", SimpleNamespace("circuitpython", old_version) + ): ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) @@ -69,6 +74,9 @@ def test_connect_wiznet5k_https_supported( # pylint: disable=unused-argument adafruit_wiznet5k_with_ssl_socketpool_module, ): radio = mocket.MockRadio.WIZNET5K() - with mock.patch("sys.implementation", (None, WIZNET5K_SSL_SUPPORT_VERSION)): + with mock.patch( + "sys.implementation", + SimpleNamespace("circuitpython", WIZNET5K_SSL_SUPPORT_VERSION), + ): ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) assert isinstance(ssl_context, ssl.SSLContext) From d79982a13879ba60f884c8286e2ca0962bdf5100 Mon Sep 17 00:00:00 2001 From: foamyguy Date: Mon, 7 Oct 2024 09:24:05 -0500 Subject: [PATCH 19/20] remove deprecated get_html_theme_path() call Signed-off-by: foamyguy --- docs/conf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/conf.py b/docs/conf.py index b184b10..000d273 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -115,7 +115,6 @@ import sphinx_rtd_theme html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path(), "."] # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, From 42073559468d0c8af9bb1fe5e06fccd4d1d9a845 Mon Sep 17 00:00:00 2001 From: foamyguy Date: Tue, 14 Jan 2025 11:32:34 -0600 Subject: [PATCH 20/20] add sphinx configuration to rtd.yaml Signed-off-by: foamyguy --- .readthedocs.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index b79ec5b..fe4faae 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -8,6 +8,9 @@ # Required version: 2 +sphinx: + configuration: docs/conf.py + build: os: ubuntu-20.04 tools: