Skip to content

Update test framework to support smartcam device discovery. #1477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 47 additions & 39 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,47 @@ def _get_discovery_json(data: bytes, ip: str) -> dict:
) from ex
return info

@staticmethod
def _get_connection_parameters(
discovery_result: DiscoveryResult,
) -> DeviceConnectionParameters:
"""Get connection parameters from the discovery result."""
type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)

if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_info := discovery_result.encrypt_info
):
encrypt_type = encrypt_info.sym_schm

if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])

if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {discovery_result.ip} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=discovery_result.ip,
)
return DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)

@staticmethod
def _get_device_instance(
info: dict,
Expand Down Expand Up @@ -838,55 +879,22 @@ def _get_device_instance(
config.host,
redact_data(info, NEW_DISCOVERY_REDACTORS),
)

type_ = discovery_result.device_type
if (encrypt_schm := discovery_result.mgt_encrypt_schm) is None:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
"with no mgt_encrypt_schm",
discovery_result=discovery_result.to_dict(),
host=config.host,
)

try:
if not (encrypt_type := encrypt_schm.encrypt_type) and (
encrypt_info := discovery_result.encrypt_info
):
encrypt_type = encrypt_info.sym_schm

if not (login_version := encrypt_schm.lv) and (
et := discovery_result.encrypt_type
):
# Known encrypt types are ["1","2"] and ["3"]
# Reuse the login_version attribute to pass the max to transport
login_version = max([int(i) for i in et])

if not encrypt_type:
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ "with no encryption type",
discovery_result=discovery_result.to_dict(),
host=config.host,
)
config.connection_type = DeviceConnectionParameters.from_values(
type_,
encrypt_type,
login_version=login_version,
https=encrypt_schm.is_support_https,
http_port=encrypt_schm.http_port,
)
conn_params = Discover._get_connection_parameters(discovery_result)
config.connection_type = conn_params
except KasaException as ex:
if isinstance(ex, UnsupportedDeviceError):
raise
raise UnsupportedDeviceError(
f"Unsupported device {config.host} of type {type_} "
+ f"with encrypt_type {encrypt_schm.encrypt_type}",
+ f"with encrypt_scheme {discovery_result.mgt_encrypt_schm}",
discovery_result=discovery_result.to_dict(),
host=config.host,
) from ex

if (
device_class := get_device_class_from_family(
type_, https=encrypt_schm.is_support_https
)
device_class := get_device_class_from_family(type_, https=conn_params.https)
) is None:
_LOGGER.debug("Got unsupported device type: %s", type_)
raise UnsupportedDeviceError(
Expand Down
69 changes: 58 additions & 11 deletions tests/discovery_fixtures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import copy
from collections.abc import Coroutine
from dataclasses import dataclass
from json import dumps as json_dumps
from typing import Any, TypedDict
Expand Down Expand Up @@ -34,7 +36,7 @@ class DiscoveryResponse(TypedDict):
"group_id": "REDACTED_07d902da02fa9beab8a64",
"group_name": "I01BU0tFRF9TU0lEIw==", # '#MASKED_SSID#'
"hardware_version": "3.0",
"ip": "192.168.1.192",
"ip": "127.0.0.1",
"mac": "24:2F:D0:00:00:00",
"master_device_id": "REDACTED_51f72a752213a6c45203530",
"need_account_digest": True,
Expand Down Expand Up @@ -134,7 +136,9 @@ def parametrize_discovery(


@pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator,
)
async def discovery_mock(request, mocker):
Expand Down Expand Up @@ -251,12 +255,46 @@ def patch_discovery(fixture_infos: dict[str, FixtureInfo], mocker):
first_ip = list(fixture_infos.keys())[0]
first_host = None

# Mock _run_callback_task so the tasks complete in the order they started.
# Otherwise test output is non-deterministic which affects readme examples.
callback_queue: asyncio.Queue = asyncio.Queue()
exception_queue: asyncio.Queue = asyncio.Queue()

async def process_callback_queue(finished_event: asyncio.Event) -> None:
while (finished_event.is_set() is False) or callback_queue.qsize():
coro = await callback_queue.get()
try:
await coro
except Exception as ex:
await exception_queue.put(ex)
else:
await exception_queue.put(None)
callback_queue.task_done()

async def wait_for_coro():
await callback_queue.join()
if ex := exception_queue.get_nowait():
raise ex

def _run_callback_task(self, coro: Coroutine) -> None:
callback_queue.put_nowait(coro)
task = asyncio.create_task(wait_for_coro())
self.callback_tasks.append(task)

mocker.patch(
"kasa.discover._DiscoverProtocol._run_callback_task", _run_callback_task
)

# do_discover_mock
async def mock_discover(self):
"""Call datagram_received for all mock fixtures.

Handles test cases modifying the ip and hostname of the first fixture
for discover_single testing.
"""
finished_event = asyncio.Event()
asyncio.create_task(process_callback_queue(finished_event))

for ip, dm in discovery_mocks.items():
first_ip = list(discovery_mocks.values())[0].ip
fixture_info = fixture_infos[ip]
Expand All @@ -283,10 +321,18 @@ async def mock_discover(self):
dm._datagram,
(dm.ip, port),
)
# Setting this event will stop the processing of callbacks
finished_event.set()

mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)

# query_mock
async def _query(self, request, retry_count: int = 3):
return await protos[self._host].query(request)

mocker.patch("kasa.IotProtocol.query", _query)
mocker.patch("kasa.SmartProtocol.query", _query)

def _getaddrinfo(host, *_, **__):
nonlocal first_host, first_ip
first_host = host # Store the hostname used by discover single
Expand All @@ -295,20 +341,21 @@ def _getaddrinfo(host, *_, **__):
].ip # ip could have been overridden in test
return [(None, None, None, None, (first_ip, 0))]

mocker.patch("kasa.IotProtocol.query", _query)
mocker.patch("kasa.SmartProtocol.query", _query)
mocker.patch("kasa.discover._DiscoverProtocol.do_discover", mock_discover)
mocker.patch(
"socket.getaddrinfo",
# side_effect=lambda *_, **__: [(None, None, None, None, (first_ip, 0))],
side_effect=_getaddrinfo,
)
mocker.patch("socket.getaddrinfo", side_effect=_getaddrinfo)

# Mock decrypt so it doesn't error with unencryptable empty data in the
# fixtures. The discovery result will already contain the decrypted data
# deserialized from the fixture
mocker.patch("kasa.discover.Discover._decrypt_discovery_data")

# Only return the first discovery mock to be used for testing discover single
return discovery_mocks[first_ip]


@pytest.fixture(
params=filter_fixtures("discoverable", protocol_filter={"SMART", "IOT"}),
params=filter_fixtures(
"discoverable", protocol_filter={"SMART", "SMARTCAM", "IOT"}
),
ids=idgenerator,
)
def discovery_data(request, mocker):
Expand Down
14 changes: 2 additions & 12 deletions tests/test_device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,7 @@ def _get_connection_type_device_class(discovery_info):
device_class = Discover._get_device_class(discovery_info)
dr = DiscoveryResult.from_dict(discovery_info["result"])

connection_type = DeviceConnectionParameters.from_values(
dr.device_type,
dr.mgt_encrypt_schm.encrypt_type,
login_version=dr.mgt_encrypt_schm.lv,
https=dr.mgt_encrypt_schm.is_support_https,
http_port=dr.mgt_encrypt_schm.http_port,
)
connection_type = Discover._get_connection_parameters(dr)
else:
connection_type = DeviceConnectionParameters.from_values(
DeviceFamily.IotSmartPlugSwitch.value, DeviceEncryptionType.Xor.value
Expand Down Expand Up @@ -118,11 +112,7 @@ async def test_connect_custom_port(discovery_mock, mocker, custom_port):
connection_type=ctype,
credentials=Credentials("dummy_user", "dummy_password"),
)
default_port = (
DiscoveryResult.from_dict(discovery_data["result"]).mgt_encrypt_schm.http_port
if "result" in discovery_data
else 9999
)
default_port = discovery_mock.default_port

ctype, _ = _get_connection_type_device_class(discovery_data)

Expand Down
Loading