Skip to content

Handle smartcam partial list responses #1411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions kasa/protocols/smartcamprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from dataclasses import dataclass
from pprint import pformat as pf
from typing import Any
from typing import Any, cast

from ..exceptions import (
AuthenticationError,
Expand Down Expand Up @@ -49,10 +49,13 @@ class SingleRequest:
class SmartCamProtocol(SmartProtocol):
"""Class for SmartCam Protocol."""

async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
) -> None:
pass
def _get_list_request(
self, method: str, params: dict | None, start_index: int
) -> dict:
# All smartcam requests have params
params = cast(dict, params)
module_name = next(iter(params))
return {method: {module_name: {"start_index": start_index}}}

def _handle_response_error_code(
self, resp_dict: dict, method: str, raise_on_error: bool = True
Expand Down Expand Up @@ -147,7 +150,9 @@ async def _execute_query(
if len(request) == 1 and method in {"get", "set", "do", "multipleRequest"}:
single_request = self._get_smart_camera_single_request(request)
else:
return await self._execute_multiple_query(request, retry_count)
return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else:
single_request = self._make_smart_camera_single_request(request)

Expand Down
32 changes: 24 additions & 8 deletions kasa/protocols/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ async def _query(self, request: str | dict, retry_count: int = 3) -> dict:
# make mypy happy, this should never be reached..
raise KasaException("Query reached somehow to unreachable")

async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dict:
async def _execute_multiple_query(
self, requests: dict, retry_count: int, iterate_list_pages: bool
) -> dict:
debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
multi_result: dict[str, Any] = {}
smart_method = "multipleRequest"
Expand Down Expand Up @@ -275,9 +277,11 @@ async def _execute_multiple_query(self, requests: dict, retry_count: int) -> dic
response, method, raise_on_error=raise_on_error
)
result = response.get("result", None)
await self._handle_response_lists(
result, method, retry_count=retry_count
)
request_params = rp if (rp := requests.get(method)) else None
if iterate_list_pages and result:
await self._handle_response_lists(
result, method, request_params, retry_count=retry_count
)
multi_result[method] = result

# Multi requests don't continue after errors so requery any missing.
Expand All @@ -303,7 +307,9 @@ async def _execute_query(
smart_method = next(iter(request))
smart_params = request[smart_method]
else:
return await self._execute_multiple_query(request, retry_count)
return await self._execute_multiple_query(
request, retry_count, iterate_list_pages
)
else:
smart_method = request
smart_params = None
Expand All @@ -330,12 +336,21 @@ async def _execute_query(
result = response_data.get("result")
if iterate_list_pages and result:
await self._handle_response_lists(
result, smart_method, retry_count=retry_count
result, smart_method, smart_params, retry_count=retry_count
)
return {smart_method: result}

def _get_list_request(
self, method: str, params: dict | None, start_index: int
) -> dict:
return {method: {"start_index": start_index}}

async def _handle_response_lists(
self, response_result: dict[str, Any], method: str, retry_count: int
self,
response_result: dict[str, Any],
method: str,
params: dict | None,
retry_count: int,
) -> None:
if (
response_result is None
Expand All @@ -355,8 +370,9 @@ async def _handle_response_lists(
)
)
while (list_length := len(response_result[response_list_name])) < list_sum:
request = self._get_list_request(method, params, list_length)
response = await self._execute_query(
{method: {"start_index": list_length}},
request,
retry_count=retry_count,
iterate_list_pages=False,
)
Expand Down
19 changes: 15 additions & 4 deletions tests/fakeprotocol_smartcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
*,
list_return_size=10,
is_child=False,
get_child_fixtures=True,
verbatim=False,
components_not_included=False,
):
Expand All @@ -52,9 +53,12 @@ def __init__(
self.verbatim = verbatim
if not is_child:
self.info = copy.deepcopy(info)
self.child_protocols = FakeSmartTransport._get_child_protocols(
self.info, self.fixture_name, "getChildDeviceList"
)
# We don't need to get the child fixtures if testing things like
# lists
if get_child_fixtures:
self.child_protocols = FakeSmartTransport._get_child_protocols(
self.info, self.fixture_name, "getChildDeviceList"
)
else:
self.info = info
# self.child_protocols = self._get_child_protocols()
Expand Down Expand Up @@ -229,9 +233,16 @@ async def _send_request(self, request_dict: dict):
list_key = next(
iter([key for key in result if isinstance(result[key], list)])
)
assert isinstance(params, dict)
module_name = next(iter(params))

start_index = (
start_index
if (params and (start_index := params.get("start_index")))
if (
params
and module_name
and (start_index := params[module_name].get("start_index"))
)
else 0
)

Expand Down
41 changes: 41 additions & 0 deletions tests/protocols/test_smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
KasaException,
SmartErrorCode,
)
from kasa.protocols.smartcamprotocol import SmartCamProtocol
from kasa.protocols.smartprotocol import SmartProtocol, _ChildProtocolWrapper
from kasa.smart import SmartDevice

Expand Down Expand Up @@ -373,6 +374,46 @@ async def test_smart_protocol_lists_multiple_request(mocker, list_sum, batch_siz
assert resp == response


@pytest.mark.parametrize("list_sum", [5, 10, 30])
@pytest.mark.parametrize("batch_size", [1, 2, 3, 50])
async def test_smartcam_protocol_list_request(mocker, list_sum, batch_size):
"""Test smartcam protocol list handling for lists."""
child_list = [{"foo": i} for i in range(list_sum)]

response = {
"getChildDeviceList": {
"child_device_list": child_list,
"start_index": 0,
"sum": list_sum,
},
"getChildDeviceComponentList": {
"child_component_list": child_list,
"start_index": 0,
"sum": list_sum,
},
}
request = {
"getChildDeviceList": {"childControl": {"start_index": 0}},
"getChildDeviceComponentList": {"childControl": {"start_index": 0}},
}

ft = FakeSmartCamTransport(
response,
"foobar",
list_return_size=batch_size,
components_not_included=True,
get_child_fixtures=False,
)
protocol = SmartCamProtocol(transport=ft)
query_spy = mocker.spy(protocol, "_execute_query")
resp = await protocol.query(request)
expected_count = 1 + 2 * (
int(list_sum / batch_size) + (0 if list_sum % batch_size else -1)
)
assert query_spy.call_count == expected_count
assert resp == response


async def test_incomplete_list(mocker, caplog):
"""Test for handling incomplete lists returned from queries."""
info = {
Expand Down
Loading