From 0e874a35f1ff6d4e2bfa2f13fa7333085b8acfdb Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 17:11:23 +0000 Subject: [PATCH 1/3] Update transport close/reset behaviour --- kasa/aestransport.py | 8 +++++--- kasa/exceptions.py | 4 ++++ kasa/httpclient.py | 13 +++++++++++-- kasa/iotprotocol.py | 23 ++++++++++++----------- kasa/klaptransport.py | 7 ++++++- kasa/protocol.py | 26 ++++++++++++++++---------- kasa/smartdevice.py | 4 ++++ kasa/smartprotocol.py | 23 ++++++++++++----------- kasa/tests/conftest.py | 11 +++++++++-- kasa/tests/newfakes.py | 3 +++ kasa/tests/test_httpclient.py | 5 +++-- kasa/tests/test_klapprotocol.py | 3 ++- 12 files changed, 87 insertions(+), 43 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 14a9ee6a1..c03b6a111 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -306,10 +306,12 @@ async def send(self, request: str): return await self.send_secure_passthrough(request) async def close(self) -> None: - """Mark the handshake and login as not done. + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() - Since we likely lost the connection. - """ + async def reset(self) -> None: + """Reset internal handshake and login state.""" self._handshake_done = False self._login_token = None diff --git a/kasa/exceptions.py b/kasa/exceptions.py index c0ef23b6a..8720d97b4 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -42,6 +42,10 @@ class ConnectionException(SmartDeviceException): """Connection exception for device errors.""" +class DisconnectedException(SmartDeviceException): + """Disconnected exception for device errors.""" + + class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 28a19e8bd..73c91fa4b 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -5,7 +5,12 @@ import aiohttp from .deviceconfig import DeviceConfig -from .exceptions import ConnectionException, SmartDeviceException, TimeoutException +from .exceptions import ( + ConnectionException, + DisconnectedException, + SmartDeviceException, + TimeoutException, +) from .json import loads as json_loads @@ -76,7 +81,11 @@ async def post( if return_json: response_data = json_loads(response_data.decode()) - except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: + except aiohttp.ServerDisconnectedError as ex: + raise DisconnectedException( + f"Disconnected from the device: {self._config.host}: {ex}", ex + ) from ex + except aiohttp.ClientOSError as ex: raise ConnectionException( f"Unable to connect to the device: {self._config.host}: {ex}", ex ) from ex diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index c58cc8802..aac21f103 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -6,6 +6,7 @@ from .exceptions import ( AuthenticationException, ConnectionException, + DisconnectedException, RetryableException, SmartDeviceException, TimeoutException, @@ -44,33 +45,38 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) + except DisconnectedException as sdex: + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) + raise sdex + continue except ConnectionException as sdex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -85,10 +91,5 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict: return await self._transport.send(request) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 5411314a3..c678e4483 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -348,7 +348,12 @@ async def send(self, request: str): return json_payload async def close(self) -> None: - """Mark the handshake as not done since we likely lost the connection.""" + """Close the http client and reset internal state.""" + await self.reset() + await self._http_client.close() + + async def reset(self) -> None: + """Reset internal handshake state.""" self._handshake_done = False @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index 59fea4a84..ae8eb89b1 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -80,6 +80,10 @@ async def send(self, request: str) -> Dict: async def close(self) -> None: """Close the transport. Abstract method to be overriden.""" + @abstractmethod + async def reset(self) -> None: + """Reset internal state.""" + class BaseProtocol(ABC): """Base class for all TP-Link Smart Home communication.""" @@ -139,7 +143,10 @@ async def send(self, request: str) -> Dict: return {} async def close(self) -> None: - """Close the transport. Abstract method to be overriden.""" + """Close the transport.""" + + async def reset(self) -> None: + """Reset internal state..""" class TPLinkSmartHomeProtocol(BaseProtocol): @@ -233,9 +240,9 @@ def close_without_wait(self) -> None: if writer: writer.close() - def _reset(self) -> None: - """Clear any varibles that should not survive between loops.""" - self.reader = self.writer = None + async def reset(self) -> None: + """Reset the transport.""" + await self.close() async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: """Try to query a device.""" @@ -252,12 +259,12 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: try: await self._connect(timeout) except ConnectionRefusedError as ex: - await self.close() + await self.reset() raise SmartDeviceException( f"Unable to connect to the device: {self._host}:{self._port}: {ex}" ) from ex except OSError as ex: - await self.close() + await self.reset() if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count: raise SmartDeviceException( f"Unable to connect to the device:" @@ -265,7 +272,7 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: ) from ex continue except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -290,7 +297,7 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: async with asyncio_timeout(timeout): return await self._execute_query(request) except Exception as ex: - await self.close() + await self.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( @@ -312,7 +319,7 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict: raise # make mypy happy, this should never be reached.. - await self.close() + await self.reset() raise SmartDeviceException("Query reached somehow to unreachable") def __del__(self) -> None: @@ -322,7 +329,6 @@ def __del__(self) -> None: # or in another thread so we need to make sure the call to # close is called safely with call_soon_threadsafe self.loop.call_soon_threadsafe(self.writer.close) - self._reset() @staticmethod def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]: diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 08a6bfb65..31418afcc 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -806,6 +806,10 @@ def config(self) -> DeviceConfig: """Return the device configuration.""" return self.protocol.config + async def disconnect(self): + """Disconnect and close any underlying connection resources.""" + await self.protocol.close() + @staticmethod async def connect( *, diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index c28db948e..8b876c144 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -18,6 +18,7 @@ SMART_TIMEOUT_ERRORS, AuthenticationException, ConnectionException, + DisconnectedException, RetryableException, SmartDeviceException, SmartErrorCode, @@ -65,33 +66,38 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) + except DisconnectedException as sdex: + if retry >= retry_count: + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) + raise sdex + continue except ConnectionException as sdex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex continue except AuthenticationException as auex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: - await self.close() + await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT) continue except SmartDeviceException as ex: - await self.close() + await self._transport.reset() _LOGGER.debug( "Unable to query the device: %s, not retrying: %s", self._host, @@ -167,12 +173,7 @@ def _handle_response_error_code(self, resp_dict: dict): raise SmartDeviceException(msg, error_code=error_code) async def close(self) -> None: - """Close the underlying transport. - - Some transports may close the connection, and some may - use this as a hint that they need to reconnect, or - reauthenticate. - """ + """Close the underlying transport.""" await self._transport.close() diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 12f9c2769..7addbe72a 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -15,6 +15,7 @@ Credentials, Discover, SmartBulb, + SmartDevice, SmartDimmer, SmartLightStrip, SmartPlug, @@ -416,9 +417,15 @@ async def dev(request): IP_MODEL_CACHE[ip] = model = d.model if model not in file: pytest.skip(f"skipping file {file}") - return d if d else await _discover_update_and_close(ip, username, password) + dev: SmartDevice = ( + d if d else await _discover_update_and_close(ip, username, password) + ) + else: + dev: SmartDevice = await get_device_for_file(file, protocol) + + yield dev - return await get_device_for_file(file, protocol) + await dev.disconnect() @pytest.fixture diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 78bea3340..625a4994c 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -377,6 +377,9 @@ def _send_request(self, request_dict: dict): async def close(self) -> None: pass + async def reset(self) -> None: + pass + class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py index 0a6c2beba..bcf48df2a 100644 --- a/kasa/tests/test_httpclient.py +++ b/kasa/tests/test_httpclient.py @@ -7,6 +7,7 @@ from ..deviceconfig import DeviceConfig from ..exceptions import ( ConnectionException, + DisconnectedException, SmartDeviceException, TimeoutException, ) @@ -18,8 +19,8 @@ [ ( aiohttp.ServerDisconnectedError(), - ConnectionException, - "Unable to connect to the device: ", + DisconnectedException, + "Disconnected from the device: ", ), ( aiohttp.ClientOSError(), diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 54f4a4bed..6bd142caa 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -54,9 +54,10 @@ async def read(self): [ (Exception("dummy exception"), False), (aiohttp.ServerTimeoutError("dummy exception"), True), + (aiohttp.ServerDisconnectedError("dummy exception"), True), (aiohttp.ClientOSError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException", "ConnectError"), + ids=("Exception", "SmartDeviceException", "DisconnectError", "ConnectError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) From f68acb60e7f4ca68be2d3c40cb780d9d0b168874 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 17:51:18 +0000 Subject: [PATCH 2/3] Avoid reset for all connection errors --- kasa/exceptions.py | 4 ---- kasa/httpclient.py | 9 ++------- kasa/iotprotocol.py | 7 ------- kasa/smartprotocol.py | 7 ------- kasa/tests/test_httpclient.py | 7 +++---- kasa/tests/test_klapprotocol.py | 2 +- 6 files changed, 6 insertions(+), 30 deletions(-) diff --git a/kasa/exceptions.py b/kasa/exceptions.py index 8720d97b4..c0ef23b6a 100644 --- a/kasa/exceptions.py +++ b/kasa/exceptions.py @@ -42,10 +42,6 @@ class ConnectionException(SmartDeviceException): """Connection exception for device errors.""" -class DisconnectedException(SmartDeviceException): - """Disconnected exception for device errors.""" - - class SmartErrorCode(IntEnum): """Enum for SMART Error Codes.""" diff --git a/kasa/httpclient.py b/kasa/httpclient.py index 73c91fa4b..7fe0b2c39 100644 --- a/kasa/httpclient.py +++ b/kasa/httpclient.py @@ -7,7 +7,6 @@ from .deviceconfig import DeviceConfig from .exceptions import ( ConnectionException, - DisconnectedException, SmartDeviceException, TimeoutException, ) @@ -81,13 +80,9 @@ async def post( if return_json: response_data = json_loads(response_data.decode()) - except aiohttp.ServerDisconnectedError as ex: - raise DisconnectedException( - f"Disconnected from the device: {self._config.host}: {ex}", ex - ) from ex - except aiohttp.ClientOSError as ex: + except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex: raise ConnectionException( - f"Unable to connect to the device: {self._config.host}: {ex}", ex + f"Device connection error: {self._config.host}: {ex}", ex ) from ex except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex: raise TimeoutException( diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index aac21f103..ed926101c 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -6,7 +6,6 @@ from .exceptions import ( AuthenticationException, ConnectionException, - DisconnectedException, RetryableException, SmartDeviceException, TimeoutException, @@ -45,13 +44,7 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except DisconnectedException as sdex: - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) - raise sdex - continue except ConnectionException as sdex: - await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 8b876c144..6f0648ea0 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -18,7 +18,6 @@ SMART_TIMEOUT_ERRORS, AuthenticationException, ConnectionException, - DisconnectedException, RetryableException, SmartDeviceException, SmartErrorCode, @@ -66,13 +65,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except DisconnectedException as sdex: - if retry >= retry_count: - _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) - raise sdex - continue except ConnectionException as sdex: - await self._transport.reset() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise sdex diff --git a/kasa/tests/test_httpclient.py b/kasa/tests/test_httpclient.py index bcf48df2a..e178b8189 100644 --- a/kasa/tests/test_httpclient.py +++ b/kasa/tests/test_httpclient.py @@ -7,7 +7,6 @@ from ..deviceconfig import DeviceConfig from ..exceptions import ( ConnectionException, - DisconnectedException, SmartDeviceException, TimeoutException, ) @@ -19,13 +18,13 @@ [ ( aiohttp.ServerDisconnectedError(), - DisconnectedException, - "Disconnected from the device: ", + ConnectionException, + "Device connection error: ", ), ( aiohttp.ClientOSError(), ConnectionException, - "Unable to connect to the device: ", + "Device connection error: ", ), ( aiohttp.ServerTimeoutError(), diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 6bd142caa..09ceccaef 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -57,7 +57,7 @@ async def read(self): (aiohttp.ServerDisconnectedError("dummy exception"), True), (aiohttp.ClientOSError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException", "DisconnectError", "ConnectError"), + ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) From 019b709d7ef4bc67804121a617a5a61fab2682af Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 23 Jan 2024 20:06:10 +0000 Subject: [PATCH 3/3] Tests --- kasa/tests/test_aestransport.py | 1 + kasa/tests/test_device_factory.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 4694e3631..5005c7b78 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -137,6 +137,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): await transport.send(json_dumps(request)) assert transport._login_token == mock_aes_device.token assert post_mock.call_count == call_count # Login, Handshake, Login + await transport.close() @status_parameters diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 25a13aea5..8e3e2ed60 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -69,6 +69,8 @@ async def test_connect( assert dev.config == config + await dev.disconnect() + @pytest.mark.parametrize("custom_port", [123, None]) async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):