Skip to content

Sleep between discovery packets #656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions kasa/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
on_discovered: Optional[OnDiscoveredCallable] = None,
target: str = "255.255.255.255",
discovery_packets: int = 3,
discovery_timeout: int = 5,
interface: Optional[str] = None,
on_unsupported: Optional[
Callable[[UnsupportedDeviceException], Awaitable[None]]
Expand All @@ -65,7 +66,8 @@ def __init__(

self.port = port
self.discovery_port = port or Discover.DISCOVERY_PORT
self.target = (target, self.discovery_port)
self.target = target
self.target_1 = (target, self.discovery_port)
self.target_2 = (target, Discover.DISCOVERY_PORT_2)

self.discovered_devices = {}
Expand All @@ -75,7 +77,9 @@ def __init__(
self.discovered_event = discovered_event
self.credentials = credentials
self.timeout = timeout
self.discovery_timeout = discovery_timeout
self.seen_hosts: Set[str] = set()
self.discover_task: Optional[asyncio.Task] = None

def connection_made(self, transport) -> None:
"""Set socket options for broadcasting."""
Expand All @@ -93,16 +97,21 @@ def connection_made(self, transport) -> None:
socket.SOL_SOCKET, socket.SO_BINDTODEVICE, self.interface.encode()
)

self.do_discover()
self.discover_task = asyncio.create_task(self.do_discover())

def do_discover(self) -> None:
async def do_discover(self) -> None:
"""Send number of discovery datagrams."""
req = json_dumps(Discover.DISCOVERY_QUERY)
_LOGGER.debug("[DISCOVERY] %s >> %s", self.target, Discover.DISCOVERY_QUERY)
encrypted_req = TPLinkSmartHomeProtocol.encrypt(req)
for _i in range(self.discovery_packets):
self.transport.sendto(encrypted_req[4:], self.target) # type: ignore
sleep_between_packets = self.discovery_timeout / self.discovery_packets
for i in range(self.discovery_packets):
if self.target in self.seen_hosts: # Stop sending for discover_single
break
self.transport.sendto(encrypted_req[4:], self.target_1) # type: ignore
self.transport.sendto(Discover.DISCOVERY_QUERY_2, self.target_2) # type: ignore
if i < self.discovery_packets - 1:
await asyncio.sleep(sleep_between_packets)

def datagram_received(self, data, addr) -> None:
"""Handle discovery responses."""
Expand Down Expand Up @@ -132,30 +141,36 @@ def datagram_received(self, data, addr) -> None:
self.unsupported_device_exceptions[ip] = udex
if self.on_unsupported is not None:
asyncio.ensure_future(self.on_unsupported(udex))
if self.discovered_event is not None:
self.discovered_event.set()
self._handle_discovered_event()
return
except SmartDeviceException as ex:
_LOGGER.debug(f"[DISCOVERY] Unable to find device type for {ip}: {ex}")
self.invalid_device_exceptions[ip] = ex
if self.discovered_event is not None:
self.discovered_event.set()
self._handle_discovered_event()
return

self.discovered_devices[ip] = device

if self.on_discovered is not None:
asyncio.ensure_future(self.on_discovered(device))

self._handle_discovered_event()

def _handle_discovered_event(self):
"""If discovered_event is available set it and cancel discover_task."""
if self.discovered_event is not None:
if self.discover_task:
self.discover_task.cancel()
self.discovered_event.set()

def error_received(self, ex):
"""Handle asyncio.Protocol errors."""
_LOGGER.error("Got error: %s", ex)

def connection_lost(self, ex):
"""NOP implementation of connection lost."""
def connection_lost(self, ex): # pragma: no cover
"""Cancel the discover task if running."""
if self.discover_task:
self.discover_task.cancel()


class Discover:
Expand Down Expand Up @@ -260,6 +275,7 @@ async def discover(
on_unsupported=on_unsupported,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
port=port,
),
local_addr=("0.0.0.0", 0), # noqa: S104
Expand Down Expand Up @@ -334,6 +350,7 @@ async def discover_single(
discovered_event=event,
credentials=credentials,
timeout=timeout,
discovery_timeout=discovery_timeout,
),
local_addr=("0.0.0.0", 0), # noqa: S104
)
Expand Down
121 changes: 119 additions & 2 deletions kasa/tests/test_discovery.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
# type: ignore
import asyncio
import logging
import re
import socket
from unittest.mock import MagicMock

import aiohttp
import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342
from async_timeout import timeout as asyncio_timeout

from kasa import (
Credentials,
DeviceType,
Discover,
SmartDevice,
SmartDeviceException,
TPLinkSmartHomeProtocol,
protocol,
)
from kasa.deviceconfig import (
Expand Down Expand Up @@ -198,9 +202,9 @@ async def test_discover_send(mocker):
"""Test discovery parameters."""
proto = _DiscoverProtocol()
assert proto.discovery_packets == 3
assert proto.target == ("255.255.255.255", 9999)
assert proto.target_1 == ("255.255.255.255", 9999)
transport = mocker.patch.object(proto, "transport")
proto.do_discover()
await proto.do_discover()
assert transport.sendto.call_count == proto.discovery_packets * 2


Expand Down Expand Up @@ -341,3 +345,116 @@ async def test_discover_http_client(discovery_mock, mocker):
assert x.protocol._transport._http_client.client != http_client
x.config.http_client = http_client
assert x.protocol._transport._http_client.client == http_client


LEGACY_DISCOVER_DATA = {
"system": {
"get_sysinfo": {
"alias": "#MASKED_NAME#",
"dev_name": "Smart Wi-Fi Plug",
"deviceId": "0000000000000000000000000000000000000000",
"err_code": 0,
"hwId": "00000000000000000000000000000000",
"hw_ver": "0.0",
"mac": "00:00:00:00:00:00",
"mic_type": "IOT.SMARTPLUGSWITCH",
"model": "HS100(UK)",
"sw_ver": "1.1.0 Build 201016 Rel.175121",
"updating": 0,
}
}
}


class FakeDatagramTransport(asyncio.DatagramTransport):
GHOST_PORT = 8888

def __init__(self, dp, port, do_not_reply_count, unsupported=False):
self.dp = dp
self.port = port
self.do_not_reply_count = do_not_reply_count
self.send_count = 0
if port == 9999:
self.datagram = TPLinkSmartHomeProtocol.encrypt(
json_dumps(LEGACY_DISCOVER_DATA)
)[4:]
elif port == 20002:
discovery_data = UNSUPPORTED if unsupported else AUTHENTICATION_DATA_KLAP
self.datagram = (
b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8"
+ json_dumps(discovery_data).encode()
)
else:
self.datagram = {"foo": "bar"}

def get_extra_info(self, name, default=None):
return MagicMock()

def sendto(self, data, addr=None):
ip, port = addr
if port == self.port or self.port == self.GHOST_PORT:
self.send_count += 1
if self.send_count > self.do_not_reply_count:
self.dp.datagram_received(self.datagram, (ip, self.port))


@pytest.mark.parametrize("port", [9999, 20002])
@pytest.mark.parametrize("do_not_reply_count", [0, 1, 2, 3, 4])
async def test_do_discover_drop_packets(mocker, port, do_not_reply_count):
"""Make sure that discover_single handles authenticating devices correctly."""
host = "127.0.0.1"
discovery_timeout = 1

event = asyncio.Event()
dp = _DiscoverProtocol(
target=host,
discovery_timeout=discovery_timeout,
discovery_packets=5,
discovered_event=event,
)
ft = FakeDatagramTransport(dp, port, do_not_reply_count)
dp.connection_made(ft)

timed_out = False
try:
async with asyncio_timeout(discovery_timeout):
await event.wait()
except asyncio.TimeoutError:
timed_out = True

await asyncio.sleep(0)
assert ft.send_count == do_not_reply_count + 1
assert dp.discover_task.done()
assert timed_out is False


@pytest.mark.parametrize(
"port, will_timeout",
[(FakeDatagramTransport.GHOST_PORT, True), (20002, False)],
ids=["unknownport", "unsupporteddevice"],
)
async def test_do_discover_invalid(mocker, port, will_timeout):
"""Make sure that discover_single handles authenticating devices correctly."""
host = "127.0.0.1"
discovery_timeout = 1

event = asyncio.Event()
dp = _DiscoverProtocol(
target=host,
discovery_timeout=discovery_timeout,
discovery_packets=5,
discovered_event=event,
)
ft = FakeDatagramTransport(dp, port, 0, unsupported=True)
dp.connection_made(ft)

timed_out = False
try:
async with asyncio_timeout(15):
await event.wait()
except asyncio.TimeoutError:
timed_out = True

await asyncio.sleep(0)
assert dp.discover_task.done()
assert timed_out is will_timeout