Skip to content

Enable multiple requests in smartprotocol #584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,19 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None
return resp.status_code, response_data

def _handle_response_error_code(self, resp_dict: dict, msg: str):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
raise AuthenticationException(msg)
raise SmartDeviceException(msg)
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
raise AuthenticationException(msg)
raise SmartDeviceException(msg)

async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
Expand Down
7 changes: 7 additions & 0 deletions kasa/iotprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict:
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except SmartDeviceException as ex:
_LOGGER.debug(
"Unable to connect to the device: %s, not retrying: %s",
self._host,
ex,
)
raise ex
except Exception as ex:
await self.close()
if retry >= retry_count:
Expand Down
72 changes: 49 additions & 23 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,7 @@ def get_smart_request(self, method, params=None) -> str:
async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self._query_lock:
resp_dict = await self._query(request, retry_count)

if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)

if "result" in resp_dict:
return resp_dict["result"]
return {}
return await self._query(request, retry_count)

async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
for retry in range(retry_count + 1):
Expand Down Expand Up @@ -128,6 +109,11 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
# Transport would have raised RetryableException if retry makes sense.
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
except Exception as ex:
if retry >= retry_count:
await self.close()
Expand All @@ -145,8 +131,15 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:

async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> Dict:
if isinstance(request, dict):
smart_method = next(iter(request))
smart_params = request[smart_method]
if len(request) == 1:
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
requests = []
for method, params in request.items():
requests.append({"method": method, "params": params})
smart_method = "multipleRequest"
smart_params = {"requests": requests}
else:
smart_method = request
smart_params = None
Expand All @@ -165,7 +158,40 @@ async def _execute_query(self, request: Union[str, Dict], retry_count: int) -> D
_LOGGER.isEnabledFor(logging.DEBUG) and pf(response_data),
)

return response_data
self._handle_response_error_code(response_data)

if (result := response_data.get("result")) is None:
# Single set_ requests do not return a result
return {smart_method: None}

if (responses := result.get("responses")) is None:
return {smart_method: result}

# responses is returned for multipleRequest
multi_result = {}
for response in responses:
self._handle_response_error_code(response)
result = response.get("result", None)
multi_result[response["method"]] = result
return multi_result

def _handle_response_error_code(self, resp_dict: dict):
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
msg = (
f"Error querying device: {self._host}: "
+ f"{error_code.name}({error_code.value})"
)
if method := resp_dict.get("method"):
msg += f" for method: {method}"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)

async def close(self) -> None:
"""Close the protocol."""
Expand Down
15 changes: 11 additions & 4 deletions kasa/tapo/tapodevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ async def update(self, update_children: bool = True):
raise AuthenticationException("Tapo plug requires authentication.")

if self._components is None:
self._components = await self.protocol.query("component_nego")
resp = await self.protocol.query("component_nego")
self._components = resp["component_nego"]

self._info = await self.protocol.query("get_device_info")
self._usage = await self.protocol.query("get_device_usage")
self._time = await self.protocol.query("get_device_time")
req = {
"get_device_info": None,
"get_device_usage": None,
"get_device_time": None,
}
resp = await self.protocol.query(req)
self._info = resp["get_device_info"]
self._usage = resp["get_device_usage"]
self._time = resp["get_device_time"]

self._last_update = self._data = {
"components": self._components,
Expand Down
16 changes: 14 additions & 2 deletions kasa/tapo/tapoplug.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ async def update(self, update_children: bool = True):
"""Call the device endpoint and update the device data."""
await super().update(update_children)

self._energy = await self.protocol.query("get_energy_usage")
self._emeter = await self.protocol.query("get_current_power")
req = {
"get_energy_usage": None,
"get_current_power": None,
}
resp = await self.protocol.query(req)
self._energy = resp["get_energy_usage"]
self._emeter = resp["get_current_power"]

self._data["energy"] = self._energy
self._data["emeter"] = self._emeter
Expand Down Expand Up @@ -71,6 +76,13 @@ def emeter_realtime(self) -> EmeterStatus:
}
)

async def get_emeter_realtime(self) -> EmeterStatus:
"""Retrieve current energy readings."""
self._verify_emeter()
resp = await self.protocol.query("get_energy_usage")
self._energy = resp["get_energy_usage"]
return self.emeter_realtime

@property
def emeter_today(self) -> Optional[float]:
"""Get the emeter value for today."""
Expand Down
8 changes: 6 additions & 2 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,13 @@ def parametrize(desc, devices, protocol_filter=None, ids=None):
)


has_emeter = parametrize("has emeter", WITH_EMETER_IOT, protocol_filter={"IOT"})
has_emeter = parametrize("has emeter", WITH_EMETER, protocol_filter={"SMART", "IOT"})
no_emeter = parametrize(
"no emeter", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"SMART", "IOT"}
"no emeter", ALL_DEVICES - WITH_EMETER, protocol_filter={"SMART", "IOT"}
)
has_emeter_iot = parametrize("has emeter iot", WITH_EMETER_IOT, protocol_filter={"IOT"})
no_emeter_iot = parametrize(
"no emeter iot", ALL_DEVICES_IOT - WITH_EMETER_IOT, protocol_filter={"IOT"}
)

bulb = parametrize("bulbs", BULBS, protocol_filter={"SMART", "IOT"})
Expand Down
43 changes: 25 additions & 18 deletions kasa/tests/newfakes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import logging
import re
import warnings
from json import loads as json_loads

from voluptuous import (
Expand Down Expand Up @@ -294,9 +295,7 @@ def __init__(self, info):
async def query(self, request, retry_count: int = 3):
"""Implement query here so can still patch SmartProtocol.query."""
resp_dict = await self._query(request, retry_count)
if "result" in resp_dict:
return resp_dict["result"]
return {}
return resp_dict


class FakeSmartTransport(BaseTransport):
Expand All @@ -306,26 +305,34 @@ def __init__(self, info):
)
self.info = info

@property
def needs_handshake(self) -> bool:
return False

@property
def needs_login(self) -> bool:
return False

async def login(self, request: str) -> None:
pass

async def handshake(self) -> None:
pass

async def send(self, request: str):
request_dict = json_loads(request)
method = request_dict["method"]
params = request_dict["params"]
if method == "multipleRequest":
responses = []
for request in params["requests"]:
response = self._send_request(request) # type: ignore[arg-type]
response["method"] = request["method"] # type: ignore[index]
responses.append(response)
return {"result": {"responses": responses}, "error_code": 0}
else:
return self._send_request(request_dict)

def _send_request(self, request_dict: dict):
method = request_dict["method"]
params = request_dict["params"]
if method == "component_nego" or method[:4] == "get_":
return {"result": self.info[method], "error_code": 0}
if method in self.info:
return {"result": self.info[method], "error_code": 0}
else:
warnings.warn(
UserWarning(
f"Fixture missing expected method {method}, try to regenerate"
),
stacklevel=1,
)
return {"result": {}, "error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
self.info[target_method].update(params)
Expand Down
33 changes: 32 additions & 1 deletion kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from ..aestransport import AesEncyptionSession, AesTransport
from ..credentials import Credentials
from ..exceptions import SmartDeviceException
from ..exceptions import (
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
SmartDeviceException,
SmartErrorCode,
)

DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}}

Expand Down Expand Up @@ -105,6 +110,32 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati
assert "result" in res


ERRORS = [e for e in SmartErrorCode if e != 0]


@pytest.mark.parametrize("error_code", ERRORS, ids=lambda e: e.name)
async def test_passthrough_errors(mocker, error_code):
host = "127.0.0.1"
mock_aes_device = MockAesDevice(host, 200, error_code, 0)
mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post)

transport = AesTransport(host=host, credentials=Credentials("foo", "bar"))
transport._handshake_done = True
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session
transport._login_token = mock_aes_device.token

request = {
"method": "get_device_info",
"params": None,
"request_time_milis": round(time.time() * 1000),
"requestID": 1,
"terminal_uuid": "foobar",
}
with pytest.raises(SmartDeviceException):
await transport.send(json_dumps(request))


class MockAesDevice:
class _mock_response:
def __init__(self, status_code, json: dict):
Expand Down
10 changes: 5 additions & 5 deletions kasa/tests/test_emeter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from kasa import EmeterStatus, SmartDeviceException

from .conftest import has_emeter, no_emeter
from .conftest import has_emeter, has_emeter_iot, no_emeter
from .newfakes import CURRENT_CONSUMPTION_SCHEMA


Expand All @@ -20,15 +20,15 @@ async def test_no_emeter(dev):
await dev.erase_emeter_stats()


@has_emeter
@has_emeter_iot
async def test_get_emeter_realtime(dev):
assert dev.has_emeter

current_emeter = await dev.get_emeter_realtime()
CURRENT_CONSUMPTION_SCHEMA(current_emeter)


@has_emeter
@has_emeter_iot
@pytest.mark.requires_dummy
async def test_get_emeter_daily(dev):
assert dev.has_emeter
Expand All @@ -48,7 +48,7 @@ async def test_get_emeter_daily(dev):
assert v * 1000 == v2


@has_emeter
@has_emeter_iot
@pytest.mark.requires_dummy
async def test_get_emeter_monthly(dev):
assert dev.has_emeter
Expand All @@ -68,7 +68,7 @@ async def test_get_emeter_monthly(dev):
assert v * 1000 == v2


@has_emeter
@has_emeter_iot
async def test_emeter_status(dev):
assert dev.has_emeter

Expand Down
Loading