diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 9db0db4f3..e79d0651d 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -125,19 +125,19 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None return resp.status_code, response_data def _handle_response_error_code(self, resp_dict: dict, msg: str): - if ( - error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] - ) != SmartErrorCode.SUCCESS: - msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" - if error_code in SMART_TIMEOUT_ERRORS: - raise TimeoutException(msg) - if error_code in SMART_RETRYABLE_ERRORS: - raise RetryableException(msg) - if error_code in SMART_AUTHENTICATION_ERRORS: - self._handshake_done = False - self._login_token = None - raise AuthenticationException(msg) - raise SmartDeviceException(msg) + error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] + if error_code == SmartErrorCode.SUCCESS: + return + msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" + if error_code in SMART_TIMEOUT_ERRORS: + raise TimeoutException(msg) + if error_code in SMART_RETRYABLE_ERRORS: + raise RetryableException(msg) + if error_code in SMART_AUTHENTICATION_ERRORS: + self._handshake_done = False + self._login_token = None + raise AuthenticationException(msg) + raise SmartDeviceException(msg) async def send_secure_passthrough(self, request: str): """Send encrypted message as passthrough.""" diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index d942d0609..fbb37b15a 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -62,6 +62,13 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: "Unable to authenticate with %s, not retrying", self._host ) raise auex + except SmartDeviceException as ex: + _LOGGER.debug( + "Unable to connect to the device: %s, not retrying: %s", + self._host, + ex, + ) + raise ex except Exception as ex: await self.close() if retry >= retry_count: diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index a344cf66c..a9266174a 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -62,26 +62,7 @@ def get_smart_request(self, method, params=None) -> str: async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Query the device retrying for retry_count on failure.""" async with self._query_lock: - resp_dict = await self._query(request, retry_count) - - if ( - error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] - ) != SmartErrorCode.SUCCESS: - msg = ( - f"Error querying device: {self._host}: " - + f"{error_code.name}({error_code.value})" - ) - if error_code in SMART_TIMEOUT_ERRORS: - raise TimeoutException(msg) - if error_code in SMART_RETRYABLE_ERRORS: - raise RetryableException(msg) - if error_code in SMART_AUTHENTICATION_ERRORS: - raise AuthenticationException(msg) - raise SmartDeviceException(msg) - - if "result" in resp_dict: - return resp_dict["result"] - return {} + return await self._query(request, retry_count) async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): @@ -128,6 +109,11 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: raise ex await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) continue + except SmartDeviceException as ex: + # Transport would have raised RetryableException if retry makes sense. + await self.close() + _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) + raise ex except Exception as ex: if retry >= retry_count: await self.close() @@ -145,8 +131,15 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict: if isinstance(request, dict): - smart_method = next(iter(request)) - smart_params = request[smart_method] + if len(request) == 1: + smart_method = next(iter(request)) + smart_params = request[smart_method] + else: + requests = [] + for method, params in request.items(): + requests.append({"method": method, "params": params}) + smart_method = "multipleRequest" + smart_params = {"requests": requests} else: smart_method = request smart_params = None @@ -165,7 +158,40 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D _LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data), ) - return response_data + self._handle_response_error_code(response_data) + + if (result := response_data.get("result")) is None: + # Single set_ requests do not return a result + return {smart_method: None} + + if (responses := result.get("responses")) is None: + return {smart_method: result} + + # responses is returned for multipleRequest + multi_result = {} + for response in responses: + self._handle_response_error_code(response) + result = response.get("result", None) + multi_result[response["method"]] = result + return multi_result + + def _handle_response_error_code(self, resp_dict: dict): + error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] + if error_code == SmartErrorCode.SUCCESS: + return + msg = ( + f"Error querying device: {self._host}: " + + f"{error_code.name}({error_code.value})" + ) + if method := resp_dict.get("method"): + msg += f" for method: {method}" + if error_code in SMART_TIMEOUT_ERRORS: + raise TimeoutException(msg) + if error_code in SMART_RETRYABLE_ERRORS: + raise RetryableException(msg) + if error_code in SMART_AUTHENTICATION_ERRORS: + raise AuthenticationException(msg) + raise SmartDeviceException(msg) async def close(self) -> None: """Close the protocol.""" diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index e5d9effe0..97405b3f1 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -41,11 +41,18 @@ async def update(self, update_children: bool = True): raise AuthenticationException("Tapo plug requires authentication.") if self._components is None: - self._components = await self.protocol.query("component_nego") + resp = await self.protocol.query("component_nego") + self._components = resp["component_nego"] - self._info = await self.protocol.query("get_device_info") - self._usage = await self.protocol.query("get_device_usage") - self._time = await self.protocol.query("get_device_time") + req = { + "get_device_info": None, + "get_device_usage": None, + "get_device_time": None, + } + resp = await self.protocol.query(req) + self._info = resp["get_device_info"] + self._usage = resp["get_device_usage"] + self._time = resp["get_device_time"] self._last_update = self._data = { "components": self._components, diff --git a/kasa/tapo/tapoplug.py b/kasa/tapo/tapoplug.py index 84d00bc8c..9d868253e 100644 --- a/kasa/tapo/tapoplug.py +++ b/kasa/tapo/tapoplug.py @@ -39,8 +39,13 @@ async def update(self, update_children: bool = True): """Call the device endpoint and update the device data.""" await super().update(update_children) - self._energy = await self.protocol.query("get_energy_usage") - self._emeter = await self.protocol.query("get_current_power") + req = { + "get_energy_usage": None, + "get_current_power": None, + } + resp = await self.protocol.query(req) + self._energy = resp["get_energy_usage"] + self._emeter = resp["get_current_power"] self._data["energy"] = self._energy self._data["emeter"] = self._emeter @@ -71,6 +76,13 @@ def emeter_realtime(self) -> EmeterStatus: } ) + async def get_emeter_realtime(self) -> EmeterStatus: + """Retrieve current energy readings.""" + self._verify_emeter() + resp = await self.protocol.query("get_energy_usage") + self._energy = resp["get_energy_usage"] + return self.emeter_realtime + @property def emeter_today(self) -> Optional[float]: """Get the emeter value for today.""" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 095971de8..43bba825b 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None): ) -has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"}) +has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"}) no_emeter = parametrize( - "no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"} + "no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"} +) +has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"}) +no_emeter_iot = parametrize( + "no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"} ) bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"}) diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index c01c8ee3e..cd7ad4fd9 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -1,6 +1,7 @@ import copy import logging import re +import warnings from json import loads as json_loads from voluptuous import ( @@ -294,9 +295,7 @@ def __init__(self, info): async def query(self, request, retry_count: int = 3): """Implement query here so can still patch SmartProtocol.query.""" resp_dict = await self._query(request, retry_count) - if "result" in resp_dict: - return resp_dict["result"] - return {} + return resp_dict class FakeSmartTransport(BaseTransport): @@ -306,26 +305,34 @@ def __init__(self, info): ) self.info = info - @property - def needs_handshake(self) -> bool: - return False - - @property - def needs_login(self) -> bool: - return False - - async def login(self, request: str) -> None: - pass - - async def handshake(self) -> None: - pass - async def send(self, request: str): request_dict = json_loads(request) + method = request_dict["method"] + params = request_dict["params"] + if method == "multipleRequest": + responses = [] + for request in params["requests"]: + response = self._send_request(request) # type: ignore[arg-type] + response["method"] = request["method"] # type: ignore[index] + responses.append(response) + return {"result": {"responses": responses}, "error_code": 0} + else: + return self._send_request(request_dict) + + def _send_request(self, request_dict: dict): method = request_dict["method"] params = request_dict["params"] if method == "component_nego" or method[:4] == "get_": - return {"result": self.info[method], "error_code": 0} + if method in self.info: + return {"result": self.info[method], "error_code": 0} + else: + warnings.warn( + UserWarning( + f"Fixture missing expected method {method}, try to regenerate" + ), + stacklevel=1, + ) + return {"result": {}, "error_code": 0} elif method[:4] == "set_": target_method = f"get_{method[4:]}" self.info[target_method].update(params) diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index b018b4975..198e8f39e 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -12,7 +12,12 @@ from ..aestransport import AesEncyptionSession, AesTransport from ..credentials import Credentials -from ..exceptions import SmartDeviceException +from ..exceptions import ( + SMART_RETRYABLE_ERRORS, + SMART_TIMEOUT_ERRORS, + SmartDeviceException, + SmartErrorCode, +) DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati assert "result" in res +ERRORS = [e for e in SmartErrorCode if e != 0] + + +@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) +async def test_passthrough_errors(mocker, error_code): + host = "127.0.0.1" + mock_aes_device = MockAesDevice(host, 200, error_code, 0) + mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) + + transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport._handshake_done = True + transport._session_expire_at = time.time() + 86400 + transport._encryption_session = mock_aes_device.encryption_session + transport._login_token = mock_aes_device.token + + request = { + "method": "get_device_info", + "params": None, + "request_time_milis": round(time.time() * 1000), + "requestID": 1, + "terminal_uuid": "foobar", + } + with pytest.raises(SmartDeviceException): + await transport.send(json_dumps(request)) + + class MockAesDevice: class _mock_response: def __init__(self, status_code, json: dict): diff --git a/kasa/tests/test_emeter.py b/kasa/tests/test_emeter.py index 75375230a..9bc70bbaf 100644 --- a/kasa/tests/test_emeter.py +++ b/kasa/tests/test_emeter.py @@ -2,7 +2,7 @@ from kasa import EmeterStatus, SmartDeviceException -from .conftest import has_emeter, no_emeter +from .conftest import has_emeter, has_emeter_iot, no_emeter from .newfakes import CURRENT_CONSUMPTION_SCHEMA @@ -20,7 +20,7 @@ async def test_no_emeter(dev): await dev.erase_emeter_stats() -@has_emeter +@has_emeter_iot async def test_get_emeter_realtime(dev): assert dev.has_emeter @@ -28,7 +28,7 @@ async def test_get_emeter_realtime(dev): CURRENT_CONSUMPTION_SCHEMA(current_emeter) -@has_emeter +@has_emeter_iot @pytest.mark.requires_dummy async def test_get_emeter_daily(dev): assert dev.has_emeter @@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev): assert v * 1000 == v2 -@has_emeter +@has_emeter_iot @pytest.mark.requires_dummy async def test_get_emeter_monthly(dev): assert dev.has_emeter @@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev): assert v * 1000 == v2 -@has_emeter +@has_emeter_iot async def test_emeter_status(dev): assert dev.has_emeter diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index d29f4e302..1ed57ef22 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -26,20 +26,29 @@ def __init__(self, status_code, content: bytes): self.content = content +@pytest.mark.parametrize( + "error, retry_expectation", + [ + (Exception("dummy exception"), True), + (SmartDeviceException("dummy exception"), False), + ], + ids=("Exception", "SmartDeviceException"), +) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @pytest.mark.parametrize("retry_count", [1, 3, 5]) -async def test_protocol_retries(mocker, retry_count, protocol_class, transport_class): +async def test_protocol_retries( + mocker, retry_count, protocol_class, transport_class, error, retry_expectation +): host = "127.0.0.1" - conn = mocker.patch.object( - httpx.AsyncClient, "post", side_effect=Exception("dummy exception") - ) + conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) with pytest.raises(SmartDeviceException): await protocol_class(host, transport=transport_class(host)).query( DUMMY_QUERY, retry_count=retry_count ) - assert conn.call_count == retry_count + 1 + expected_count = retry_count + 1 if retry_expectation else 1 + assert conn.call_count == expected_count @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @@ -109,7 +118,7 @@ def _fail_one_less_than_retry_count(*_, **__): response = await protocol_class(host, transport=transport_class(host)).query( DUMMY_QUERY, retry_count=retry_count ) - assert "result" in response or "great" in response + assert "result" in response or "foobar" in response assert send_mock.call_count == retry_count diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 90eae16f2..47f523d00 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -8,7 +8,7 @@ from kasa import Credentials, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType -from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter, turn_on +from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol # List of all SmartXXX classes including the SmartDevice base class @@ -48,7 +48,7 @@ async def test_initial_update_emeter(dev, mocker): assert spy.call_count == expected_queries + len(dev.children) -@no_emeter +@no_emeter_iot async def test_initial_update_no_emeter(dev, mocker): """Test that the initial update performs second query if emeter is available.""" dev._last_update = None diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py new file mode 100644 index 000000000..5dbbed279 --- /dev/null +++ b/kasa/tests/test_smartprotocol.py @@ -0,0 +1,81 @@ +import errno +import json +import logging +import secrets +import struct +import sys +import time +from contextlib import nullcontext as does_not_raise +from itertools import chain + +import httpx +import pytest + +from ..aestransport import AesTransport +from ..credentials import Credentials +from ..exceptions import ( + SMART_RETRYABLE_ERRORS, + SMART_TIMEOUT_ERRORS, + SmartDeviceException, + SmartErrorCode, +) +from ..iotprotocol import IotProtocol +from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 +from ..smartprotocol import SmartProtocol + +DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} +ERRORS = [e for e in SmartErrorCode if e != 0] + + +@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) +async def test_smart_device_errors(mocker, error_code): + host = "127.0.0.1" + mock_response = {"result": {"great": "success"}, "error_code": error_code.value} + + mocker.patch.object(AesTransport, "perform_handshake") + mocker.patch.object(AesTransport, "perform_login") + + send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) + + protocol = SmartProtocol(host, transport=AesTransport(host)) + with pytest.raises(SmartDeviceException): + await protocol.query(DUMMY_QUERY, retry_count=2) + + if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): + expected_calls = 3 + else: + expected_calls = 1 + assert send_mock.call_count == expected_calls + + +@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name) +async def test_smart_device_errors_in_multiple_request(mocker, error_code): + host = "127.0.0.1" + mock_response = { + "result": { + "responses": [ + {"method": "foobar1", "result": {"great": "success"}, "error_code": 0}, + { + "method": "foobar2", + "result": {"great": "success"}, + "error_code": error_code.value, + }, + {"method": "foobar3", "result": {"great": "success"}, "error_code": 0}, + ] + }, + "error_code": 0, + } + + mocker.patch.object(AesTransport, "perform_handshake") + mocker.patch.object(AesTransport, "perform_login") + + send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) + + protocol = SmartProtocol(host, transport=AesTransport(host)) + with pytest.raises(SmartDeviceException): + await protocol.query(DUMMY_QUERY, retry_count=2) + if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): + expected_calls = 3 + else: + expected_calls = 1 + assert send_mock.call_count == expected_calls