diff --git a/adafruit_connection_manager.py b/adafruit_connection_manager.py index 658f338..58dcb0c 100644 --- a/adafruit_connection_manager.py +++ b/adafruit_connection_manager.py @@ -97,6 +97,10 @@ def create_fake_ssl_context( return _FakeSSLContext(iface) +class CPythonNetwork: # pylint: disable=too-few-public-methods + """Radio object to use when using ConnectionManager in CPython.""" + + _global_connection_managers = {} _global_key_by_socketpool = {} _global_socketpools = {} @@ -158,6 +162,12 @@ def get_radio_socketpool(radio): if ssl_context is None: ssl_context = create_fake_ssl_context(pool, radio) + elif class_name == "CPythonNetwork": + import socket as pool # pylint: disable=import-outside-toplevel + import ssl # pylint: disable=import-outside-toplevel + + ssl_context = ssl.create_default_context() + else: raise ValueError(f"Unsupported radio class: {class_name}") diff --git a/tests/get_radio_test.py b/tests/get_radio_test.py index 5631bdb..022aecd 100644 --- a/tests/get_radio_test.py +++ b/tests/get_radio_test.py @@ -53,6 +53,13 @@ def test_get_radio_socketpool_wiznet5k( # pylint: disable=unused-argument assert socket_pool in adafruit_connection_manager._global_socketpools.values() +def test_get_radio_socketpool_cpython(): + radio = adafruit_connection_manager.CPythonNetwork() + socket_pool = adafruit_connection_manager.get_radio_socketpool(radio) + assert socket_pool.__name__ == "socket" + assert socket_pool in adafruit_connection_manager._global_socketpools.values() + + def test_get_radio_socketpool_unsupported(): radio = mocket.MockRadio.Unsupported() with pytest.raises(ValueError) as context: @@ -98,6 +105,13 @@ def test_get_radio_ssl_context_wiznet5k( # pylint: disable=unused-argument assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() +def test_get_radio_ssl_context_cpython(): + radio = adafruit_connection_manager.CPythonNetwork() + ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) + assert isinstance(ssl_context, ssl.SSLContext) + assert ssl_context in adafruit_connection_manager._global_ssl_contexts.values() + + def test_get_radio_ssl_context_unsupported(): radio = mocket.MockRadio.Unsupported() with pytest.raises(ValueError) as context: diff --git a/tests/get_socket_test.py b/tests/get_socket_test.py index 46d053b..f61cb26 100644 --- a/tests/get_socket_test.py +++ b/tests/get_socket_test.py @@ -98,30 +98,32 @@ def test_get_socket_os_error(): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.side_effect = [ - OSError("OSError"), + OSError("OSError 1"), mock_socket_1, ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a OSError - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "OSError 1" in str(context) def test_get_socket_runtime_error(): mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.side_effect = [ - RuntimeError("RuntimeError"), + RuntimeError("RuntimeError 1"), mock_socket_1, ] connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to get a socket that returns a RuntimeError - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "RuntimeError 1" in str(context) def test_get_socket_connect_memory_error(): @@ -132,13 +134,14 @@ def test_get_socket_connect_memory_error(): mock_socket_1, mock_socket_2, ] - mock_socket_1.connect.side_effect = MemoryError("MemoryError") + mock_socket_1.connect.side_effect = MemoryError("MemoryError 1") connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a MemoryError - with pytest.raises(MemoryError): + with pytest.raises(MemoryError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "MemoryError 1" in str(context) def test_get_socket_connect_os_error(): @@ -149,13 +152,14 @@ def test_get_socket_connect_os_error(): mock_socket_1, mock_socket_2, ] - mock_socket_1.connect.side_effect = OSError("OSError") + mock_socket_1.connect.side_effect = OSError("OSError 1") connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) # try to connect a socket that returns a OSError - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket(mocket.MOCK_HOST_1, 80, "http:") + assert "OSError 1" in str(context) def test_get_socket_runtime_error_ties_again_at_least_one_free(): @@ -190,9 +194,9 @@ def test_get_socket_runtime_error_ties_again_only_once(): mock_socket_2 = mocket.Mocket() mock_pool.socket.side_effect = [ mock_socket_1, - RuntimeError("error 1"), - RuntimeError("error 2"), - RuntimeError("error 3"), + RuntimeError("RuntimeError 1"), + RuntimeError("RuntimeError 2"), + RuntimeError("RuntimeError 3"), mock_socket_2, ] @@ -207,8 +211,9 @@ def test_get_socket_runtime_error_ties_again_only_once(): free_sockets_mock.assert_not_called() # try to get a socket that returns a RuntimeError twice - with pytest.raises(RuntimeError): + with pytest.raises(RuntimeError) as context: connection_manager.get_socket(mocket.MOCK_HOST_2, 80, "http:") + assert "RuntimeError 2" in str(context) free_sockets_mock.assert_called_once() @@ -237,13 +242,15 @@ def test_fake_ssl_context_connect_error( # pylint: disable=unused-argument mock_pool = mocket.MocketPool() mock_socket_1 = mocket.Mocket() mock_pool.socket.return_value = mock_socket_1 - mock_socket_1.connect.side_effect = RuntimeError("RuntimeError") + mock_socket_1.connect.side_effect = RuntimeError("RuntimeError 1") radio = mocket.MockRadio.ESP_SPIcontrol() ssl_context = adafruit_connection_manager.get_radio_ssl_context(radio) connection_manager = adafruit_connection_manager.ConnectionManager(mock_pool) - with pytest.raises(OSError): + with pytest.raises(OSError) as context: connection_manager.get_socket( mocket.MOCK_HOST_1, 443, "https:", ssl_context=ssl_context ) + assert "12" in str(context) + assert "RuntimeError 1" in str(context)