From 9c50a18f6fbbe802cf328e875a3f23e2449b6a9c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Tue, 5 Dec 2023 12:11:17 +0000 Subject: [PATCH 1/5] Add DeviceConfig handling --- kasa/__init__.py | 10 ++ kasa/aestransport.py | 23 +--- kasa/cli.py | 67 ++++++++-- kasa/connectionparams.py | 147 +++++++++++++++++++++ kasa/device_factory.py | 189 +++++++++----------------- kasa/device_type.py | 0 kasa/discover.py | 180 ++++++++++++++++--------- kasa/iotprotocol.py | 13 +- kasa/klaptransport.py | 22 ++-- kasa/protocol.py | 45 +++---- kasa/protocolfactory.py | 39 ++++++ kasa/smartdevice.py | 44 ++++--- kasa/smartprotocol.py | 16 +-- kasa/tapo/tapodevice.py | 17 ++- kasa/tests/conftest.py | 26 +++- kasa/tests/newfakes.py | 10 +- kasa/tests/test_aestransport.py | 14 +- kasa/tests/test_cli.py | 60 ++++++++- kasa/tests/test_connectionparams.py | 21 +++ kasa/tests/test_device_factory.py | 198 +++++++++++++++------------- kasa/tests/test_discovery.py | 135 ++++++++++++------- kasa/tests/test_klapprotocol.py | 106 ++++++++++----- kasa/tests/test_protocol.py | 67 ++++------ kasa/tests/test_smartdevice.py | 22 ++-- 24 files changed, 922 insertions(+), 549 deletions(-) create mode 100644 kasa/connectionparams.py mode change 100755 => 100644 kasa/device_factory.py mode change 100755 => 100644 kasa/device_type.py create mode 100644 kasa/protocolfactory.py create mode 100644 kasa/tests/test_connectionparams.py diff --git a/kasa/__init__.py b/kasa/__init__.py index 7de394c11..61e367244 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,6 +13,12 @@ """ from importlib.metadata import version +from kasa.connectionparams import ( + ConnectionParameters, + ConnectionType, + DeviceFamilyType, + EncryptType, +) from kasa.credentials import Credentials from kasa.discover import Discover from kasa.emeterstatus import EmeterStatus @@ -55,4 +61,8 @@ "AuthenticationException", "UnsupportedDeviceException", "Credentials", + "ConnectionParameters", + "ConnectionType", + "EncryptType", + "DeviceFamilyType", ] diff --git a/kasa/aestransport.py b/kasa/aestransport.py index e7dd53568..60c0df794 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .credentials import Credentials +from .connectionparams import ConnectionParameters from .exceptions import ( SMART_AUTHENTICATION_ERRORS, SMART_RETRYABLE_ERRORS, @@ -48,7 +48,6 @@ class AesTransport(BaseTransport): """ DEFAULT_PORT = 80 - DEFAULT_TIMEOUT = 5 SESSION_COOKIE_NAME = "TP_SESSIONID" COMMON_HEADERS = { "Content-Type": "application/json", @@ -58,28 +57,22 @@ class AesTransport(BaseTransport): def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + cparams: ConnectionParameters, ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, - ) + super().__init__(cparams=cparams) + self._port = cparams.port or self.DEFAULT_PORT + self._http_client: httpx.AsyncClient = ( + cparams.http_client or httpx.AsyncClient() + ) self._handshake_done = False self._encryption_session: Optional[AesEncyptionSession] = None self._session_expire_at: Optional[float] = None - self._timeout = timeout if timeout else self.DEFAULT_TIMEOUT self._session_cookie = None - self._http_client: httpx.AsyncClient = httpx.AsyncClient() self._login_token = None _LOGGER.debug("Created AES transport for %s", self._host) @@ -102,8 +95,6 @@ def hash_credentials(self, login_v2): async def client_post(self, url, params=None, data=None, json=None, headers=None): """Send an http post request to the device.""" - if not self._http_client: - self._http_client = httpx.AsyncClient() response_data = None cookies = None if self._session_cookie: diff --git a/kasa/cli.py b/kasa/cli.py index 3478c35a5..162e94865 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -12,15 +12,20 @@ from kasa import ( AuthenticationException, + ConnectionParameters, + ConnectionType, Credentials, - DeviceType, + DeviceFamilyType, Discover, + EncryptType, SmartBulb, SmartDevice, + SmartDimmer, + SmartLightStrip, + SmartPlug, SmartStrip, UnsupportedDeviceException, ) -from kasa.device_factory import DEVICE_TYPE_TO_CLASS from kasa.discover import DiscoveryResult try: @@ -49,10 +54,19 @@ def wrapper(message=None, *args, **kwargs): # --json has set it to _nop_echo echo = _do_echo -DEVICE_TYPES = [ - device_type.value - for device_type in DeviceType - if device_type in DEVICE_TYPE_TO_CLASS + +TYPE_TO_CLASS = { + "plug": SmartPlug, + "bulb": SmartBulb, + "dimmer": SmartDimmer, + "strip": SmartStrip, + "lightstrip": SmartLightStrip, +} + +ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType] + +TPLINK_DEVICE_TYPES = [ + tplink_device_type.value for tplink_device_type in DeviceFamilyType ] click.anyio_backend = "asyncio" @@ -149,7 +163,7 @@ def _device_to_serializable(val: SmartDevice): "--type", envvar="KASA_TYPE", default=None, - type=click.Choice(DEVICE_TYPES, case_sensitive=False), + type=click.Choice(list(TYPE_TO_CLASS), case_sensitive=False), ) @click.option( "--json/--no-json", @@ -158,6 +172,18 @@ def _device_to_serializable(val: SmartDevice): is_flag=True, help="Output raw device response as JSON.", ) +@click.option( + "--encrypt-type", + envvar="KASA_ENCRYPT_TYPE", + default=None, + type=click.Choice(ENCRYPT_TYPES, case_sensitive=False), +) +@click.option( + "--device-family", + envvar="KASA_DEVICE_FAMILY", + default=None, + type=click.Choice(TPLINK_DEVICE_TYPES, case_sensitive=False), +) @click.option( "--timeout", envvar="KASA_TIMEOUT", @@ -199,6 +225,8 @@ async def cli( verbose, debug, type, + encrypt_type, + device_family, json, timeout, discovery_timeout, @@ -270,12 +298,19 @@ def _nop_echo(*args, **kwargs): return await ctx.invoke(discover) if type is not None: - device_type = DeviceType.from_value(type) - dev = await SmartDevice.connect( - host, credentials=credentials, device_type=device_type, timeout=timeout + dev = TYPE_TO_CLASS[type](host) + await dev.update() + elif device_family or encrypt_type: + ctype = ConnectionType( + DeviceFamilyType(device_family), + EncryptType(encrypt_type), ) + cparams = ConnectionParameters( + host=host, credentials=credentials, timeout=timeout, connection_type=ctype + ) + dev = await SmartDevice.connect(cparams=cparams) else: - echo("No --type defined, discovering..") + echo("No --type or --device-family and --encrypt-type defined, discovering..") dev = await Discover.discover_single( host, port=port, @@ -332,8 +367,10 @@ async def discover(ctx): target = ctx.parent.params["target"] username = ctx.parent.params["username"] password = ctx.parent.params["password"] - timeout = ctx.parent.params["discovery_timeout"] verbose = ctx.parent.params["verbose"] + discovery_timeout = ctx.parent.params["discovery_timeout"] + timeout = ctx.parent.params["timeout"] + port = ctx.parent.params["port"] credentials = Credentials(username, password) @@ -354,7 +391,7 @@ async def print_unsupported(unsupported_exception: UnsupportedDeviceException): echo(f"\t{unsupported_exception}") echo() - echo(f"Discovering devices on {target} for {timeout} seconds") + echo(f"Discovering devices on {target} for {discovery_timeout} seconds") async def print_discovered(dev: SmartDevice): async with sem: @@ -376,9 +413,11 @@ async def print_discovered(dev: SmartDevice): await Discover.discover( target=target, - timeout=timeout, + discovery_timeout=discovery_timeout, on_discovered=print_discovered, on_unsupported=print_unsupported, + port=port, + timeout=timeout, credentials=credentials, ) diff --git a/kasa/connectionparams.py b/kasa/connectionparams.py new file mode 100644 index 000000000..27c85423d --- /dev/null +++ b/kasa/connectionparams.py @@ -0,0 +1,147 @@ +"""Module for holding connection parameters.""" +import logging +from dataclasses import asdict, dataclass, field, fields, is_dataclass +from enum import Enum +from typing import Dict, Optional + +import httpx + +from .credentials import Credentials +from .exceptions import SmartDeviceException + +_LOGGER = logging.getLogger(__name__) + + +class EncryptType(Enum): + """Encrypt type enum.""" + + Klap = "KLAP" + Aes = "AES" + Xor = "XOR" + + +class DeviceFamilyType(Enum): + """Encrypt type enum.""" + + IotSmartPlugSwitch = "IOT.SMARTPLUGSWITCH" + IotSmartBulb = "IOT.SMARTBULB" + SmartKasaPlug = "SMART.KASAPLUG" + SmartTapoPlug = "SMART.TAPOPLUG" + SmartTapoBulb = "SMART.TAPOBULB" + + +def _dataclass_from_dict(klass, in_val): + if is_dataclass(klass): + fieldtypes = {f.name: f.type for f in fields(klass)} + val = {} + for dict_key in in_val: + if dict_key in fieldtypes and hasattr(fieldtypes[dict_key], "from_dict"): + val[dict_key] = fieldtypes[dict_key].from_dict(in_val[dict_key]) + else: + val[dict_key] = _dataclass_from_dict( + fieldtypes[dict_key], in_val[dict_key] + ) + return klass(**val) + else: + return in_val + + +def _dataclass_to_dict(in_val): + fieldtypes = {f.name: f.type for f in fields(in_val) if f.compare} + out_val = {} + for field_name in fieldtypes: + val = getattr(in_val, field_name) + if val is None: + continue + elif hasattr(val, "to_dict"): + out_val[field_name] = val.to_dict() + elif is_dataclass(fieldtypes[field_name]): + out_val[field_name] = asdict(val) + else: + out_val[field_name] = val + return out_val + + +@dataclass +class ConnectionType: + """Class to hold the the parameters determining connection type.""" + + device_family: DeviceFamilyType + encryption_type: EncryptType + + @staticmethod + def from_values( + device_family: str, + encryption_type: str, + ) -> "ConnectionType": + """Return connection parameters from string values.""" + try: + return ConnectionType( + DeviceFamilyType(device_family), + EncryptType(encryption_type), + ) + except ValueError as ex: + raise SmartDeviceException( + f"Invalid connection parameters for {device_family}.{encryption_type}" + ) from ex + + @staticmethod + def from_dict(connection_type_dict: Dict[str, str]) -> "ConnectionType": + """Return connection parameters from dict.""" + if ( + isinstance(connection_type_dict, dict) + and (device_family := connection_type_dict.get("device_family")) + and (encryption_type := connection_type_dict.get("encryption_type")) + ): + return ConnectionType.from_values(device_family, encryption_type) + + raise SmartDeviceException( + f"Invalid connection type data for {connection_type_dict}" + ) + + def to_dict(self) -> Dict[str, str]: + """Convert connection params to dict.""" + result = { + "device_family": self.device_family.value, + "encryption_type": self.encryption_type.value, + } + return result + + +@dataclass +class ConnectionParameters: + """Class to represent paramaters that determine how to connect to devices.""" + + DEFAULT_TIMEOUT = 5 + + host: str + timeout: Optional[int] = DEFAULT_TIMEOUT + port: Optional[int] = None + credentials: Credentials = field( + default_factory=lambda: Credentials(username="", password="") + ) + connection_type: ConnectionType = field( + default_factory=lambda: ConnectionType( + DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor + ) + ) + + # compare=False will be excluded from the serialization and object comparison. + http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) + + def __post_init__(self): + if self.credentials is None: + self.credentials = Credentials(username="", password="") + if self.connection_type is None: + self.connection_type = ConnectionType( + DeviceFamilyType.IotSmartPlugSwitch, EncryptType.Xor + ) + + def to_dict(self) -> Dict[str, Dict[str, str]]: + """Convert connection params to dict.""" + return _dataclass_to_dict(self) + + @staticmethod + def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "ConnectionParameters": + """Return connection parameters from dict.""" + return _dataclass_from_dict(ConnectionParameters, cparam_dict) diff --git a/kasa/device_factory.py b/kasa/device_factory.py old mode 100755 new mode 100644 index d8a07beee..26f7538a8 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -1,123 +1,83 @@ -"""Device creation by type.""" - +"""Device creation via ConnectionParameters.""" import logging import time -from typing import Any, Dict, Optional, Tuple, Type - -from .aestransport import AesTransport -from .credentials import Credentials -from .device_type import DeviceType -from .exceptions import UnsupportedDeviceException -from .iotprotocol import IotProtocol -from .klaptransport import KlapTransport, TPlinkKlapTransportV2 -from .protocol import BaseTransport, TPLinkProtocol -from .smartbulb import SmartBulb -from .smartdevice import SmartDevice, SmartDeviceException -from .smartdimmer import SmartDimmer -from .smartlightstrip import SmartLightStrip -from .smartplug import SmartPlug -from .smartprotocol import SmartProtocol -from .smartstrip import SmartStrip -from .tapo import TapoBulb, TapoPlug - -DEVICE_TYPE_TO_CLASS = { - DeviceType.Plug: SmartPlug, - DeviceType.Bulb: SmartBulb, - DeviceType.Strip: SmartStrip, - DeviceType.Dimmer: SmartDimmer, - DeviceType.LightStrip: SmartLightStrip, - DeviceType.TapoPlug: TapoPlug, - DeviceType.TapoBulb: TapoBulb, -} +from typing import Any, Dict, Optional, Type + +from kasa.connectionparams import ConnectionParameters +from kasa.protocol import TPLinkSmartHomeProtocol +from kasa.smartbulb import SmartBulb +from kasa.smartdevice import SmartDevice +from kasa.smartdimmer import SmartDimmer +from kasa.smartlightstrip import SmartLightStrip +from kasa.smartplug import SmartPlug +from kasa.smartstrip import SmartStrip +from kasa.tapo import TapoBulb, TapoPlug + +from .exceptions import SmartDeviceException, UnsupportedDeviceException +from .protocolfactory import get_protocol _LOGGER = logging.getLogger(__name__) +GET_SYSINFO_QUERY = { + "system": {"get_sysinfo": None}, +} -async def connect( - host: str, - *, - port: Optional[int] = None, - timeout=5, - credentials: Optional[Credentials] = None, - device_type: Optional[DeviceType] = None, - protocol_class: Optional[Type[TPLinkProtocol]] = None, -) -> "SmartDevice": - """Connect to a single device by the given IP address. - - This method avoids the UDP based discovery process and - will connect directly to the device to query its type. - - It is generally preferred to avoid :func:`discover_single()` and - use this function instead as it should perform better when - the WiFi network is congested or the device is not responding - to discovery requests. - - The device type is discovered by querying the device. - - :param host: Hostname of device to query - :param device_type: Device type to use for the device. - If not given, the device type is discovered by querying the device. - If the device type is already known, it is preferred to pass it - to avoid the extra query to the device to discover its type. - :param protocol_class: Optionally provide the protocol class - to use. - :rtype: SmartDevice - :return: Object for querying/controlling found device. + +async def connect(*, cparams: ConnectionParameters) -> "SmartDevice": + """Connect to a single device by the given connection parameters. + + Do not use this function directly, use SmartDevice.Connect() """ debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) - if debug_enabled: start_time = time.perf_counter() - if device_type and (klass := DEVICE_TYPE_TO_CLASS.get(device_type)): - dev: SmartDevice = klass( - host=host, port=port, credentials=credentials, timeout=timeout - ) - if protocol_class is not None: - dev.protocol = protocol_class( - host, - transport=AesTransport( - host, port=port, credentials=credentials, timeout=timeout - ), - ) - await dev.update() + def _perf_log(has_params, perf_type): + nonlocal start_time if debug_enabled: end_time = time.perf_counter() _LOGGER.debug( - "Device %s with known type (%s) took %.2f seconds to connect", - host, - device_type.value, - end_time - start_time, + f"Device {cparams.host} with connection params {has_params} " + + f"took {end_time - start_time:.2f} seconds to {perf_type}", ) - return dev - - unknown_dev = SmartDevice( - host=host, port=port, credentials=credentials, timeout=timeout - ) - if protocol_class is not None: - # TODO this will be replaced with connection params - unknown_dev.protocol = protocol_class( - host, - transport=AesTransport( - host, port=port, credentials=credentials, timeout=timeout - ), + start_time = time.perf_counter() + + if (protocol := get_protocol(cparams=cparams)) is None: + raise UnsupportedDeviceException( + f"Unsupported device for {cparams.host}: " + + f"{cparams.connection_type.device_family.value}" ) - await unknown_dev.update() - device_class = get_device_class_from_sys_info(unknown_dev.internal_state) - dev = device_class(host=host, port=port, credentials=credentials, timeout=timeout) - # Reuse the connection from the unknown device - # so we don't have to reconnect - dev.protocol = unknown_dev.protocol - await dev.update() - if debug_enabled: - end_time = time.perf_counter() - _LOGGER.debug( - "Device %s with unknown type (%s) took %.2f seconds to connect", - host, - dev.device_type.value, - end_time - start_time, + + device_class: Optional[Type[SmartDevice]] + + if isinstance(protocol, TPLinkSmartHomeProtocol): + info = await protocol.query(GET_SYSINFO_QUERY) + _perf_log(True, "get_sysinfo") + device_class = get_device_class_from_sys_info(info) + device = device_class(cparams.host, port=cparams.port, timeout=cparams.timeout) + device.update_from_discover_info(info) + device.protocol = protocol + await device.update() + _perf_log(True, "update") + return device + elif device_class := get_device_class_from_family( + cparams.connection_type.device_family.value + ): + device = device_class( + cparams.host, + port=cparams.port, + timeout=cparams.timeout, + credentials=cparams.credentials, + ) + device.protocol = protocol + await device.update() + _perf_log(True, "update") + return device + else: + raise UnsupportedDeviceException( + f"Unsupported device for {cparams.host}: " + + f"{cparams.connection_type.device_family.value}" ) - return dev def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: @@ -147,32 +107,13 @@ def get_device_class_from_sys_info(info: Dict[str, Any]) -> Type[SmartDevice]: raise UnsupportedDeviceException("Unknown device type: %s" % type_) -def get_device_class_from_type_name(device_type: str) -> Optional[Type[SmartDevice]]: +def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice]]: """Return the device class from the type name.""" supported_device_types: dict[str, Type[SmartDevice]] = { "SMART.TAPOPLUG": TapoPlug, "SMART.TAPOBULB": TapoBulb, "SMART.KASAPLUG": TapoPlug, "IOT.SMARTPLUGSWITCH": SmartPlug, + "IOT.SMARTBULB": SmartBulb, } return supported_device_types.get(device_type) - - -def get_protocol_from_connection_name( - connection_name: str, host: str, credentials: Optional[Credentials] = None -) -> Optional[TPLinkProtocol]: - """Return the protocol from the connection name.""" - supported_device_protocols: dict[ - str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] - ] = { - "IOT.KLAP": (IotProtocol, KlapTransport), - "SMART.AES": (SmartProtocol, AesTransport), - "SMART.KLAP": (SmartProtocol, TPlinkKlapTransportV2), - } - if connection_name not in supported_device_protocols: - return None - - protocol_class, transport_class = supported_device_protocols.get(connection_name) # type: ignore - transport: BaseTransport = transport_class(host, credentials=credentials) - protocol: TPLinkProtocol = protocol_class(host, transport=transport) - return protocol diff --git a/kasa/device_type.py b/kasa/device_type.py old mode 100755 new mode 100644 diff --git a/kasa/discover.py b/kasa/discover.py index 4ec3775e9..a299fd122 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -1,39 +1,45 @@ """Discovery module for TP-Link Smart Home devices.""" import asyncio +import base64 import binascii import ipaddress import logging import socket from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast +import httpx + # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout try: - from pydantic.v1 import BaseModel, Field + from pydantic.v1 import BaseModel, ValidationError except ImportError: - from pydantic import BaseModel, Field + from pydantic import BaseModel, ValidationError +from kasa.connectionparams import ConnectionParameters, ConnectionType, EncryptType from kasa.credentials import Credentials +from kasa.device_factory import ( + get_device_class_from_family, + get_device_class_from_sys_info, +) from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import TPLinkSmartHomeProtocol +from kasa.protocolfactory import get_protocol from kasa.smartdevice import SmartDevice, SmartDeviceException -from .device_factory import ( - get_device_class_from_sys_info, - get_device_class_from_type_name, - get_protocol_from_connection_name, -) - _LOGGER = logging.getLogger(__name__) OnDiscoveredCallable = Callable[[SmartDevice], Awaitable[None]] DeviceDict = Dict[str, SmartDevice] +UNAVAILABLE_ALIAS = "Authentication required" +UNAVAILABLE_NICKNAME = base64.b64encode(UNAVAILABLE_ALIAS.encode()).decode() + class _DiscoverProtocol(asyncio.DatagramProtocol): """Implementation of the discovery protocol handler. @@ -57,14 +63,18 @@ def __init__( discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, + http_client_generator: Optional[Callable[[], httpx.AsyncClient]] = None, ) -> None: self.transport = None self.discovery_packets = discovery_packets self.interface = interface self.on_discovered = on_discovered + + self.port = port self.discovery_port = port or Discover.DISCOVERY_PORT self.target = (target, self.discovery_port) self.target_2 = (target, Discover.DISCOVERY_PORT_2) + self.discovered_devices = {} self.unsupported_device_exceptions: Dict = {} self.invalid_device_exceptions: Dict = {} @@ -73,6 +83,9 @@ def __init__( self.credentials = credentials self.timeout = timeout self.seen_hosts: Set[str] = set() + self.http_client_generator: Optional[ + Callable[[], httpx.AsyncClient] + ] = http_client_generator def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -110,13 +123,19 @@ def datagram_received(self, data, addr) -> None: self.seen_hosts.add(ip) device = None + + cparams = ConnectionParameters(host=ip, port=self.port) + if self.credentials: + cparams.credentials = self.credentials + if self.timeout: + cparams.timeout = self.timeout try: if port == self.discovery_port: - device = Discover._get_device_instance_legacy(data, ip, port) + device = Discover._get_device_instance_legacy(data, cparams) elif port == Discover.DISCOVERY_PORT_2: - device = Discover._get_device_instance( - data, ip, port, self.credentials or Credentials() - ) + if self.http_client_generator: + cparams.http_client = self.http_client_generator() + device = Discover._get_device_instance(data, cparams) else: return except UnsupportedDeviceException as udex: @@ -200,11 +219,14 @@ async def discover( *, target="255.255.255.255", on_discovered=None, - timeout=5, + discovery_timeout=5, discovery_packets=3, interface=None, on_unsupported=None, credentials=None, + port=None, + timeout=None, + http_client_generator: Optional[Callable[[], httpx.AsyncClient]] = None, ) -> DeviceDict: """Discover supported devices. @@ -240,14 +262,16 @@ async def discover( on_unsupported=on_unsupported, credentials=credentials, timeout=timeout, + port=port, + http_client_generator=http_client_generator, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) protocol = cast(_DiscoverProtocol, protocol) try: - _LOGGER.debug("Waiting %s seconds for responses...", timeout) - await asyncio.sleep(timeout) + _LOGGER.debug("Waiting %s seconds for responses...", discovery_timeout) + await asyncio.sleep(discovery_timeout) finally: transport.close() @@ -259,10 +283,11 @@ async def discover( async def discover_single( host: str, *, + discovery_timeout: int = 5, port: Optional[int] = None, - timeout=5, + timeout: Optional[int] = None, credentials: Optional[Credentials] = None, - update_parent_devices: bool = True, + httpx_asyncclient: httpx.AsyncClient = None, ) -> SmartDevice: """Discover a single device by the given IP address. @@ -275,8 +300,6 @@ async def discover_single( :param port: Optionally set a different port for the device :param timeout: Timeout for discovery :param credentials: Credentials for devices that require authentication - :param update_parent_devices: Automatically call device.update() on - devices that have children :rtype: SmartDevice :return: Object for querying/controlling found device. """ @@ -314,15 +337,20 @@ async def discover_single( discovered_event=event, credentials=credentials, timeout=timeout, + http_client_generator=lambda: httpx_asyncclient + if httpx_asyncclient + else None, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) protocol = cast(_DiscoverProtocol, protocol) try: - _LOGGER.debug("Waiting a total of %s seconds for responses...", timeout) + _LOGGER.debug( + "Waiting a total of %s seconds for responses...", discovery_timeout + ) - async with asyncio_timeout(timeout): + async with asyncio_timeout(discovery_timeout): await event.wait() except asyncio.TimeoutError as ex: raise SmartDeviceException( @@ -334,9 +362,8 @@ async def discover_single( if ip in protocol.discovered_devices: dev = protocol.discovered_devices[ip] dev.host = host - # Call device update on devices that have children - if update_parent_devices and dev.has_children: - await dev.update() + if httpx_asyncclient and hasattr(dev.protocol._transport, "http_client"): + dev.protocol._transport.http_client = httpx_asyncclient # type: ignore[union-attr] return dev elif ip in protocol.unsupported_device_exceptions: raise protocol.unsupported_device_exceptions[ip] @@ -350,99 +377,128 @@ def _get_device_class(info: dict) -> Type[SmartDevice]: """Find SmartDevice subclass for device described by passed data.""" if "result" in info: discovery_result = DiscoveryResult(**info["result"]) - dev_class = get_device_class_from_type_name(discovery_result.device_type) + dev_class = get_device_class_from_family(discovery_result.device_type) if not dev_class: raise UnsupportedDeviceException( - "Unknown device type: %s" % discovery_result.device_type + "Unknown device type: %s" % discovery_result.device_type, + discovery_result=info, ) return dev_class else: return get_device_class_from_sys_info(info) @staticmethod - def _get_device_instance_legacy(data: bytes, ip: str, port: int) -> SmartDevice: + def _get_device_instance_legacy( + data: bytes, cparams: ConnectionParameters + ) -> SmartDevice: """Get SmartDevice from legacy 9999 response.""" try: info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) except Exception as ex: raise SmartDeviceException( - f"Unable to read response from device: {ip}: {ex}" + f"Unable to read response from device: {cparams.host}: {ex}" ) from ex - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) + _LOGGER.debug("[DISCOVERY] %s << %s", cparams.host, info) device_class = Discover._get_device_class(info) - device = device_class(ip, port=port) + device = device_class(cparams.host, port=cparams.port) + sys_info = info["system"]["get_sysinfo"] + if (device_type := sys_info.get("mic_type")) or ( + device_type := sys_info.get("type") + ): + cparams.connection_type = ConnectionType.from_values( + device_family=device_type, encryption_type=EncryptType.Xor.value + ) + device.protocol = get_protocol(cparams) # type: ignore[assignment] device.update_from_discover_info(info) return device @staticmethod def _get_device_instance( - data: bytes, ip: str, port: int, credentials: Credentials + data: bytes, + cparams: ConnectionParameters, ) -> SmartDevice: """Get SmartDevice from the new 20002 response.""" try: info = json_loads(data[16:]) - discovery_result = DiscoveryResult(**info["result"]) except Exception as ex: + _LOGGER.debug("Got invalid response from device %s: %s", cparams.host, data) + raise SmartDeviceException( + f"Unable to read response from device: {cparams.host}: {ex}" + ) from ex + try: + discovery_result = DiscoveryResult(**info["result"]) + except ValidationError as ex: + _LOGGER.debug( + "Unable to parse discovery from device %s: %s", cparams.host, info + ) raise UnsupportedDeviceException( - f"Unable to read response from device: {ip}: {ex}" + f"Unable to parse discovery from device: {cparams.host}: {ex}" ) from ex type_ = discovery_result.device_type - encrypt_type_ = ( - f"{type_.split('.')[0]}.{discovery_result.mgt_encrypt_schm.encrypt_type}" - ) - if (device_class := get_device_class_from_type_name(type_)) is None: + try: + cparams.connection_type = ConnectionType.from_values( + type_, discovery_result.mgt_encrypt_schm.encrypt_type + ) + except SmartDeviceException as ex: + raise UnsupportedDeviceException( + f"Unsupported device {cparams.host} of type {type_} " + + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", + discovery_result=discovery_result.get_dict(), + ) from ex + if (device_class := get_device_class_from_family(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( - f"Unsupported device {ip} of type {type_}: {info}", + f"Unsupported device {cparams.host} of type {type_}: {info}", discovery_result=discovery_result.get_dict(), ) - if ( - protocol := get_protocol_from_connection_name( - encrypt_type_, ip, credentials=credentials + if (protocol := get_protocol(cparams)) is None: + _LOGGER.warning( + "Got unsupported connection type: %s", cparams.connection_type.to_dict() ) - ) is None: - _LOGGER.warning("Got unsupported device type: %s", encrypt_type_) raise UnsupportedDeviceException( - f"Unsupported encryption scheme {ip} of type {encrypt_type_}: {info}", + f"Unsupported encryption scheme {cparams.host} of " + + f"type {cparams.connection_type.to_dict()}: {info}", discovery_result=discovery_result.get_dict(), ) - _LOGGER.debug("[DISCOVERY] %s << %s", ip, info) - device = device_class(ip, port=port, credentials=credentials) + _LOGGER.debug("[DISCOVERY] %s << %s", cparams.host, info) + device = device_class( + cparams.host, port=cparams.port, credentials=cparams.credentials + ) device.protocol = protocol - device.update_from_discover_info(discovery_result.get_dict()) + + di = discovery_result.get_dict() + di["model"] = discovery_result.device_model + di["alias"] = UNAVAILABLE_ALIAS + di["nickname"] = UNAVAILABLE_NICKNAME + device.update_from_discover_info(di) return device class DiscoveryResult(BaseModel): """Base model for discovery result.""" - class Config: - """Class for configuring model behaviour.""" - - allow_population_by_field_name = True - class EncryptionScheme(BaseModel): """Base model for encryption scheme of discovery result.""" - is_support_https: Optional[bool] = None - encrypt_type: Optional[str] = None - http_port: Optional[int] = None - lv: Optional[int] = 1 + is_support_https: bool + encrypt_type: str + http_port: int + lv: Optional[int] = None - device_type: str = Field(alias="device_type_text") - device_model: str = Field(alias="model") - ip: str = Field(alias="alias") + device_type: str + device_model: str + ip: str mac: str mgt_encrypt_schm: EncryptionScheme + device_id: str - device_id: Optional[str] = Field(default=None, alias="device_id_hash") - owner: Optional[str] = Field(default=None, alias="device_owner_hash") hw_ver: Optional[str] = None + owner: Optional[str] = None is_support_iot_cloud: Optional[bool] = None obd_src: Optional[str] = None factory_default: Optional[bool] = None @@ -453,5 +509,5 @@ def get_dict(self) -> dict: containing only the values actually set and with aliases as field names. """ return self.dict( - by_alias=True, exclude_unset=True, exclude_none=True, exclude_defaults=True + by_alias=False, exclude_unset=True, exclude_none=True, exclude_defaults=True ) diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index fbb37b15a..e78d24e74 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -17,12 +17,11 @@ class IotProtocol(TPLinkProtocol): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self._query_lock = asyncio.Lock() @@ -39,21 +38,14 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() + except httpx.ConnectError as sdex: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self._host}: {cex}" - ) from cex except TimeoutError as tex: - await self.close() raise SmartDeviceException( f"Unable to connect to the device, timed out: {self._host}: {tex}" ) from tex @@ -70,7 +62,6 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: ) raise ex except Exception as ex: - await self.close() if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index e7bb8ae6c..6856521b5 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -53,6 +53,7 @@ from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from .connectionparams import ConnectionParameters from .credentials import Credentials from .exceptions import AuthenticationException, SmartDeviceException from .json import loads as json_loads @@ -84,25 +85,21 @@ class KlapTransport(BaseTransport): DEFAULT_PORT = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} + KASA_SETUP_EMAIL = "kasa@tp-link.net" KASA_SETUP_PASSWORD = "kasaSetup" # noqa: S105 SESSION_COOKIE_NAME = "TP_SESSIONID" def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + cparams: ConnectionParameters, ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, + super().__init__(cparams=cparams) + self._port = cparams.port or self.DEFAULT_PORT + self._http_client: httpx.AsyncClient = ( + cparams.http_client or httpx.AsyncClient() ) - self._local_seed: Optional[bytes] = None self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() @@ -116,14 +113,11 @@ def __init__( self._session_expire_at: Optional[float] = None self._session_cookie = None - self._http_client: httpx.AsyncClient = httpx.AsyncClient() _LOGGER.debug("Created KLAP transport for %s", self._host) async def client_post(self, url, params=None, data=None): """Send an http post request to the device.""" - if not self._http_client: - self._http_client = httpx.AsyncClient() response_data = None cookies = None if self._session_cookie: @@ -390,7 +384,7 @@ def generate_owner_hash(creds: Credentials): return md5(un.encode()) -class TPlinkKlapTransportV2(KlapTransport): +class KlapTransportV2(KlapTransport): """Implementation of the KLAP encryption protocol with v2 hanshake hashes.""" @staticmethod diff --git a/kasa/protocol.py b/kasa/protocol.py index f73260bf0..501b33fe5 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -24,7 +24,7 @@ from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes -from .credentials import Credentials +from .connectionparams import ConnectionParameters from .exceptions import SmartDeviceException from .json import dumps as json_dumps from .json import loads as json_loads @@ -48,17 +48,15 @@ class BaseTransport(ABC): def __init__( self, - host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + cparams: ConnectionParameters, ) -> None: """Create a protocol object.""" - self._host = host - self._port = port - self._credentials = credentials or Credentials(username="", password="") - self._timeout = timeout or self.DEFAULT_TIMEOUT + self._cparams = cparams + self._host = cparams.host + self._port = cparams.port # Set by derived classes + self._credentials = cparams.credentials + self._timeout = cparams.timeout @abstractmethod async def send(self, request: str) -> Dict: @@ -74,7 +72,6 @@ class TPLinkProtocol(ABC): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: @@ -85,6 +82,11 @@ def __init__( def _host(self): return self._transport._host + @property + def connection_parameters(self) -> ConnectionParameters: + """Return the connection parameters the device is using.""" + return self._transport._cparams + @abstractmethod async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: """Query the device for the protocol. Abstract method to be overriden.""" @@ -105,20 +107,9 @@ class _XorTransport(BaseTransport): DEFAULT_PORT = 9999 - def __init__( - self, - host: str, - *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, - ) -> None: - super().__init__( - host, - port=port or self.DEFAULT_PORT, - credentials=credentials, - timeout=timeout, - ) + def __init__(self, *, cparams: ConnectionParameters) -> None: + super().__init__(cparams=cparams) + self._port = cparams.port or self.DEFAULT_PORT async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" @@ -133,17 +124,15 @@ class TPLinkSmartHomeProtocol(TPLinkProtocol): INITIALIZATION_VECTOR = 171 DEFAULT_PORT = 9999 - DEFAULT_TIMEOUT = 5 BLOCK_SIZE = 4 def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self.reader: Optional[asyncio.StreamReader] = None self.writer: Optional[asyncio.StreamWriter] = None @@ -167,7 +156,7 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: assert isinstance(request, str) # noqa: S101 async with self.query_lock: - return await self._query(request, retry_count, self._timeout) + return await self._query(request, retry_count, self._timeout) # type: ignore[arg-type] async def _connect(self, timeout: int) -> None: """Try to connect or reconnect to the device.""" diff --git a/kasa/protocolfactory.py b/kasa/protocolfactory.py new file mode 100644 index 000000000..5d76bd3db --- /dev/null +++ b/kasa/protocolfactory.py @@ -0,0 +1,39 @@ +"""Module for protocol factory class.""" +from typing import Optional, Tuple, Type + +from .aestransport import AesTransport +from .connectionparams import ConnectionParameters +from .iotprotocol import IotProtocol +from .klaptransport import KlapTransport, KlapTransportV2 +from .protocol import ( + BaseTransport, + TPLinkProtocol, + TPLinkSmartHomeProtocol, + _XorTransport, +) +from .smartprotocol import SmartProtocol + + +def get_protocol( + cparams: ConnectionParameters, +) -> Optional[TPLinkProtocol]: + """Return the protocol from the connection name.""" + protocol_name = cparams.connection_type.device_family.value.split(".")[0] + protocol_transport_key = ( + protocol_name + "." + cparams.connection_type.encryption_type.value + ) + supported_device_protocols: dict[ + str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] + ] = { + "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), + "IOT.KLAP": (IotProtocol, KlapTransport), + "SMART.AES": (SmartProtocol, AesTransport), + "SMART.KLAP": (SmartProtocol, KlapTransportV2), + } + if protocol_transport_key not in supported_device_protocols: + return None + + protocol_class, transport_class = supported_device_protocols.get( + protocol_transport_key + ) # type: ignore + return protocol_class(transport=transport_class(cparams=cparams)) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 5ad94a9f4..bbc7e9a0c 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -19,6 +19,7 @@ from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Set +from .connectionparams import ConnectionParameters from .credentials import Credentials from .device_type import DeviceType from .emeterstatus import EmeterStatus @@ -201,8 +202,9 @@ def __init__( """ self.host = host self.port = port + cparams = ConnectionParameters(host=host, port=port, timeout=timeout) self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( - host, transport=_XorTransport(host, port=port, timeout=timeout) + transport=_XorTransport(cparams=cparams), ) self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) @@ -394,7 +396,7 @@ class itself as @requires_update will be affected for other properties. """ return self._sys_info # type: ignore - @property # type: ignore + @property @requires_update def model(self) -> str: """Return device model.""" @@ -760,7 +762,7 @@ def internal_state(self) -> Any: The returned object contains the raw results from the last update call. This should only be used for debugging purposes. """ - return self._last_update + return self._last_update or self._discovery_info def __repr__(self): if self._last_update is None: @@ -771,19 +773,21 @@ def __repr__(self): f" - dev specific: {self.state_information}>" ) + @property + def connection_parameters(self) -> ConnectionParameters: + """Return the connection parameters the device is using.""" + return self.protocol.connection_parameters + @staticmethod async def connect( - host: str, *, - port: Optional[int] = None, - timeout=5, - credentials: Optional[Credentials] = None, - device_type: Optional[DeviceType] = None, + host: Optional[str] = None, + cparams: Optional[ConnectionParameters] = None, ) -> "SmartDevice": - """Connect to a single device by the given IP address. + """Connect to a single device by the given hostname or connection parameters. This method avoids the UDP based discovery process and - will connect directly to the device to query its type. + will connect directly to the device. It is generally preferred to avoid :func:`discover_single()` and use this function instead as it should perform better when @@ -793,19 +797,17 @@ async def connect( The device type is discovered by querying the device. :param host: Hostname of device to query - :param device_type: Device type to use for the device. - If not given, the device type is discovered by querying the device. - If the device type is already known, it is preferred to pass it - to avoid the extra query to the device to discover its type. + :param cparams: Connection parameters to ensure the correct protocol + and connection options are used. :rtype: SmartDevice :return: Object for querying/controlling found device. """ from .device_factory import connect # pylint: disable=import-outside-toplevel - return await connect( - host=host, - port=port, - timeout=timeout, - credentials=credentials, - device_type=device_type, - ) + if host and cparams or (not host and not cparams): + raise SmartDeviceException( + "One of host or cparams must be provded and not both" + ) + if host: + cparams = ConnectionParameters(host=host) + return await connect(cparams=cparams) # type: ignore[arg-type] diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index 443d1def1..b15645645 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -38,12 +38,11 @@ class SmartProtocol(TPLinkProtocol): def __init__( self, - host: str, *, transport: BaseTransport, ) -> None: """Create a protocol object.""" - super().__init__(host, transport=transport) + super().__init__(transport=transport) self._terminal_uuid: str = base64.b64encode(md5(uuid.uuid4().bytes)).decode() self._request_id_generator = SnowflakeId(1, 1) self._query_lock = asyncio.Lock() @@ -68,22 +67,15 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: for retry in range(retry_count + 1): try: return await self._execute_query(request, retry) - except httpx.CloseError as sdex: - await self.close() + except httpx.ConnectError as sdex: if retry >= retry_count: _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue - except httpx.ConnectError as cex: - await self.close() - raise SmartDeviceException( - f"Unable to connect to the device: {self._host}: {cex}" - ) from cex except TimeoutError as tex: if retry >= retry_count: - await self.close() raise SmartDeviceException( "Unable to connect to the device, " + f"timed out: {self._host}: {tex}" @@ -91,20 +83,17 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) continue except AuthenticationException as auex: - await self.close() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) @@ -116,7 +105,6 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: raise ex except Exception as ex: if retry >= retry_count: - await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {ex}" diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 97405b3f1..c80542c4c 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Set, cast from ..aestransport import AesTransport +from ..connectionparams import ConnectionParameters from ..credentials import Credentials from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice @@ -28,11 +29,14 @@ def __init__( self._components: Optional[Dict[str, Any]] = None self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None + cparams = ConnectionParameters( + host=host, + port=port, + credentials=credentials, # type: ignore[arg-type] + timeout=timeout, + ) self.protocol = SmartProtocol( - host, - transport=AesTransport( - host, credentials=credentials, timeout=timeout, port=port - ), + transport=AesTransport(cparams=cparams), ) async def update(self, update_children: bool = True): @@ -66,7 +70,7 @@ async def update(self, update_children: bool = True): @property def sys_info(self) -> Dict[str, Any]: """Returns the device info.""" - return self._info + return self._info # type: ignore @property def model(self) -> str: @@ -135,7 +139,7 @@ def mac(self) -> str: @property def device_id(self) -> str: """Return the device id.""" - return str(self._info.get("device_id")) + return str(self._info.get("device_id")) # type: ignore @property def internal_state(self) -> Any: @@ -180,3 +184,4 @@ async def turn_off(self, **kwargs): def update_from_discover_info(self, info): """Update state from info from the discover call.""" self._discovery_info = info + self._info = info diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 43bba825b..8c23e4c14 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -428,20 +428,42 @@ class _DiscoveryMock: default_port: int discovery_data: dict query_data: dict + device_type: str + encrypt_type: str port_override: Optional[int] = None if "discovery_result" in all_fixture_data: discovery_data = {"result": all_fixture_data["discovery_result"]} + device_type = all_fixture_data["discovery_result"]["device_type"] + encrypt_type = all_fixture_data["discovery_result"]["mgt_encrypt_schm"][ + "encrypt_type" + ] datagram = ( b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" + json_dumps(discovery_data).encode() ) - dm = _DiscoveryMock("127.0.0.123", 20002, discovery_data, all_fixture_data) + dm = _DiscoveryMock( + "127.0.0.123", + 20002, + discovery_data, + all_fixture_data, + device_type, + encrypt_type, + ) else: sys_info = all_fixture_data["system"]["get_sysinfo"] discovery_data = {"system": {"get_sysinfo": sys_info}} + device_type = sys_info.get("mic_type") or sys_info.get("type") + encrypt_type = "XOR" datagram = TPLinkSmartHomeProtocol.encrypt(json_dumps(discovery_data))[4:] - dm = _DiscoveryMock("127.0.0.123", 9999, discovery_data, all_fixture_data) + dm = _DiscoveryMock( + "127.0.0.123", + 9999, + discovery_data, + all_fixture_data, + device_type, + encrypt_type, + ) def mock_discover(self): port = ( diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index cd7ad4fd9..4f9077ea4 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -15,6 +15,8 @@ Schema, ) +from ..connectionparams import ConnectionParameters +from ..credentials import Credentials from ..protocol import BaseTransport, TPLinkSmartHomeProtocol from ..smartprotocol import SmartProtocol @@ -290,7 +292,9 @@ def success(res): class FakeSmartProtocol(SmartProtocol): def __init__(self, info): - super().__init__("127.0.0.123", transport=FakeSmartTransport(info)) + super().__init__( + transport=FakeSmartTransport(info), + ) async def query(self, request, retry_count: int = 3): """Implement query here so can still patch SmartProtocol.query.""" @@ -301,7 +305,9 @@ async def query(self, request, retry_count: int = 3): class FakeSmartTransport(BaseTransport): def __init__(self, info): super().__init__( - "127.0.0.123", + cparams=ConnectionParameters( + "127.0.0.123", credentials=Credentials("", "") + ), ) self.info = info diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index 198e8f39e..f42181ae7 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -11,6 +11,7 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from ..aestransport import AesEncyptionSession, AesTransport +from ..connectionparams import ConnectionParameters from ..credentials import Credentials from ..exceptions import ( SMART_RETRYABLE_ERRORS, @@ -58,7 +59,9 @@ async def test_handshake( mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + ) assert transport._encryption_session is None assert transport._handshake_done is False @@ -74,7 +77,9 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session @@ -91,13 +96,14 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati mock_aes_device = MockAesDevice(host, status_code, error_code, inner_error_code) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + transport = AesTransport( + cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session transport._login_token = mock_aes_device.token - un, pw = transport.hash_credentials(True) request = { "method": "get_device_info", "params": None, diff --git a/kasa/tests/test_cli.py b/kasa/tests/test_cli.py index c46015ea0..1983b6ccb 100644 --- a/kasa/tests/test_cli.py +++ b/kasa/tests/test_cli.py @@ -4,10 +4,26 @@ import pytest from asyncclick.testing import CliRunner -from kasa import AuthenticationException, SmartDevice, UnsupportedDeviceException -from kasa.cli import alias, brightness, cli, emeter, raw_command, state, sysinfo, toggle -from kasa.device_factory import DEVICE_TYPE_TO_CLASS -from kasa.discover import Discover +from kasa import ( + AuthenticationException, + Credentials, + SmartDevice, + TPLinkSmartHomeProtocol, + UnsupportedDeviceException, +) +from kasa.cli import ( + TYPE_TO_CLASS, + alias, + brightness, + cli, + emeter, + raw_command, + state, + sysinfo, + toggle, +) +from kasa.discover import Discover, DiscoveryResult +from kasa.smartprotocol import SmartProtocol from .conftest import device_iot, handle_turn_on, new_discovery, turn_on @@ -145,9 +161,11 @@ async def _state(dev: SmartDevice): ) mocker.patch("kasa.cli.state", new=_state) - for subclass in DEVICE_TYPE_TO_CLASS.values(): - mocker.patch.object(subclass, "update") + mocker.patch("kasa.IotProtocol.query", return_value=discovery_mock.query_data) + mocker.patch("kasa.SmartProtocol.query", return_value=discovery_mock.query_data) + + dr = DiscoveryResult(**discovery_mock.discovery_data["result"]) runner = CliRunner() res = await runner.invoke( cli, @@ -158,6 +176,10 @@ async def _state(dev: SmartDevice): "foo", "--password", "bar", + "--device-family", + dr.device_type, + "--encrypt-type", + dr.mgt_encrypt_schm.encrypt_type, ], ) assert res.exit_code == 0 @@ -166,7 +188,7 @@ async def _state(dev: SmartDevice): @device_iot -async def test_without_device_type(discovery_data: dict, dev, mocker): +async def test_without_device_type(dev, mocker): """Test connecting without the device type.""" runner = CliRunner() mocker.patch("kasa.discover.Discover.discover_single", return_value=dev) @@ -342,3 +364,27 @@ async def test_host_auth_failed(discovery_mock, mocker): assert res.exit_code != 0 assert isinstance(res.exception, AuthenticationException) + + +@pytest.mark.parametrize("device_type", list(TYPE_TO_CLASS)) +async def test_type_param(device_type, mocker): + """Test for handling only one of username or password supplied.""" + runner = CliRunner() + + result_device = FileNotFoundError + pass_dev = click.make_pass_decorator(SmartDevice) + + @pass_dev + async def _state(dev: SmartDevice): + nonlocal result_device + result_device = dev + + mocker.patch("kasa.cli.state", new=_state) + expected_type = TYPE_TO_CLASS[device_type] + mocker.patch.object(expected_type, "update") + res = await runner.invoke( + cli, + ["--type", device_type, "--host", "127.0.0.1"], + ) + assert res.exit_code == 0 + assert isinstance(result_device, expected_type) diff --git a/kasa/tests/test_connectionparams.py b/kasa/tests/test_connectionparams.py new file mode 100644 index 000000000..baede2e95 --- /dev/null +++ b/kasa/tests/test_connectionparams.py @@ -0,0 +1,21 @@ +from json import dumps as json_dumps +from json import loads as json_loads + +import httpx + +from kasa.connectionparams import ( + ConnectionParameters, + ConnectionType, + DeviceFamilyType, + EncryptType, +) +from kasa.credentials import Credentials + + +def test_serialization(): + cp = ConnectionParameters(host="Foo", http_client=httpx.AsyncClient()) + cp_dict = cp.to_dict() + cp_json = json_dumps(cp_dict) + cp2_dict = json_loads(cp_json) + cp2 = ConnectionParameters.from_dict(cp2_dict) + assert cp == cp2 diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index eb12b3b0d..aa00d2728 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -2,6 +2,7 @@ import logging from typing import Type +import httpx import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( @@ -15,122 +16,139 @@ SmartLightStrip, SmartPlug, ) -from kasa.device_factory import ( - DEVICE_TYPE_TO_CLASS, - connect, - get_protocol_from_connection_name, +from kasa.connectionparams import ( + ConnectionParameters, + ConnectionType, + DeviceFamilyType, + EncryptType, ) +from kasa.device_factory import connect from kasa.discover import DiscoveryResult -from kasa.iotprotocol import IotProtocol -from kasa.protocol import TPLinkProtocol, TPLinkSmartHomeProtocol +from kasa.protocolfactory import get_protocol -@pytest.mark.parametrize("custom_port", [123, None]) -async def test_connect(discovery_data: dict, mocker, custom_port): - """Make sure that connect returns an initialized SmartDevice instance.""" - host = "127.0.0.1" +def _get_connection_type_device_class(the_fixture_data): + if "discovery_result" in the_fixture_data: + discovery_info = {"result": the_fixture_data["discovery_result"]} + device_class = Discover._get_device_class(discovery_info) + dr = DiscoveryResult(**discovery_info["result"]) - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - dev = await connect(host, port=custom_port) + connection_type = ConnectionType.from_values( + dr.device_type, dr.mgt_encrypt_schm.encrypt_type + ) else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port) - assert issubclass(dev.__class__, SmartDevice) - assert dev.port == custom_port or dev.port == 9999 + connection_type = ConnectionType.from_values( + DeviceFamilyType.IotSmartPlugSwitch.value, EncryptType.Xor.value + ) + device_class = Discover._get_device_class(the_fixture_data) + return connection_type, device_class -@pytest.mark.parametrize("custom_port", [123, None]) -@pytest.mark.parametrize( - ("device_type", "klass"), - ( - (DeviceType.Plug, SmartPlug), - (DeviceType.Bulb, SmartBulb), - (DeviceType.Dimmer, SmartDimmer), - (DeviceType.LightStrip, SmartLightStrip), - (DeviceType.Unknown, SmartDevice), - ), -) -async def test_connect_passed_device_type( - discovery_data: dict, + +async def test_connect( + all_fixture_data: dict, mocker, - device_type: DeviceType, - klass: Type[SmartDevice], - custom_port, ): - """Make sure that connect with a passed device type.""" + """Test that if the protocol is passed in it gets set correctly.""" host = "127.0.0.1" + ctype, device_class = _get_connection_type_device_class(all_fixture_data) - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - dev = await connect(host, port=custom_port) - else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - dev = await connect(host, port=custom_port, device_type=device_type) - assert isinstance(dev, klass) - assert dev.port == custom_port or dev.port == 9999 + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + cparams = ConnectionParameters( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + protocol_class = get_protocol(cparams).__class__ -async def test_connect_query_fails(discovery_data: dict, mocker): - """Make sure that connect fails when query fails.""" + dev = await connect( + cparams=cparams, + ) + assert isinstance(dev, device_class) + assert isinstance(dev.protocol, protocol_class) + + assert dev.connection_parameters == cparams + + +@pytest.mark.parametrize("custom_port", [123, None]) +async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): + """Make sure that connect returns an initialized SmartDevice instance.""" host = "127.0.0.1" - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) - with pytest.raises(SmartDeviceException): - await connect(host) + ctype, _ = _get_connection_type_device_class(all_fixture_data) + cparams = ConnectionParameters(host=host, port=custom_port, connection_type=ctype) + default_port = 80 if "discovery_result" in all_fixture_data else 9999 + + ctype, _ = _get_connection_type_device_class(all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + dev = await connect(cparams=cparams) + assert issubclass(dev.__class__, SmartDevice) + assert dev.port == custom_port or dev.port == default_port async def test_connect_logs_connect_time( - discovery_data: dict, caplog: pytest.LogCaptureFixture, mocker + all_fixture_data: dict, caplog: pytest.LogCaptureFixture, mocker ): """Test that the connect time is logged when debug logging is enabled.""" + ctype, _ = _get_connection_type_device_class(all_fixture_data) + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + host = "127.0.0.1" - if "result" in discovery_data: - with pytest.raises(SmartDeviceException): - await connect(host) - else: - mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=discovery_data) - logging.getLogger("kasa").setLevel(logging.DEBUG) - await connect(host) - assert "seconds to connect" in caplog.text + cparams = ConnectionParameters( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + logging.getLogger("kasa").setLevel(logging.DEBUG) + await connect( + cparams=cparams, + ) + assert "seconds to update" in caplog.text -async def test_connect_pass_protocol( - all_fixture_data: dict, - mocker, -): - """Test that if the protocol is passed in it's gets set correctly.""" - if "discovery_result" in all_fixture_data: - discovery_info = {"result": all_fixture_data["discovery_result"]} - device_class = Discover._get_device_class(discovery_info) - else: - device_class = Discover._get_device_class(all_fixture_data) +async def test_connect_query_fails(all_fixture_data: dict, mocker): + """Make sure that connect fails when query fails.""" + host = "127.0.0.1" + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", side_effect=SmartDeviceException) + mocker.patch("kasa.IotProtocol.query", side_effect=SmartDeviceException) + mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException) - device_type = list(DEVICE_TYPE_TO_CLASS.keys())[ - list(DEVICE_TYPE_TO_CLASS.values()).index(device_class) - ] + ctype, _ = _get_connection_type_device_class(all_fixture_data) + cparams = ConnectionParameters( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype + ) + with pytest.raises(SmartDeviceException): + await connect(cparams=cparams) + + +async def test_connect_http_client(all_fixture_data, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" host = "127.0.0.1" - if "discovery_result" in all_fixture_data: - mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) - mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) - dr = DiscoveryResult(**discovery_info["result"]) - connection_name = ( - dr.device_type.split(".")[0] + "." + dr.mgt_encrypt_schm.encrypt_type - ) - protocol_class = get_protocol_from_connection_name( - connection_name, host - ).__class__ - else: - mocker.patch( - "kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data - ) - protocol_class = TPLinkSmartHomeProtocol + ctype, _ = _get_connection_type_device_class(all_fixture_data) - dev = await connect( - host, - device_type=device_type, - protocol_class=protocol_class, - credentials=Credentials("", ""), + mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) + mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) + + http_client = httpx.AsyncClient() + + cparams = ConnectionParameters( + host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) - assert isinstance(dev.protocol, protocol_class) + dev = await connect(cparams=cparams) + if ctype.encryption_type != EncryptType.Xor: + assert dev.protocol._transport._http_client != http_client + + cparams = ConnectionParameters( + host=host, + credentials=Credentials("foor", "bar"), + connection_type=ctype, + http_client=http_client, + ) + dev = await connect(cparams=cparams) + if ctype.encryption_type != EncryptType.Xor: + assert dev.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 18798ab90..970b750ec 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -1,21 +1,29 @@ # type: ignore +import logging import re import socket +import httpx import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 from kasa import ( + Credentials, DeviceType, Discover, SmartDevice, SmartDeviceException, - SmartStrip, protocol, ) +from kasa.connectionparams import ( + ConnectionParameters, + ConnectionType, + DeviceFamilyType, + EncryptType, +) from kasa.discover import DiscoveryResult, _DiscoverProtocol, json_dumps from kasa.exceptions import AuthenticationException, UnsupportedDeviceException -from .conftest import bulb, bulb_iot, dimmer, lightstrip, plug, strip +from .conftest import bulb, bulb_iot, dimmer, lightstrip, new_discovery, plug, strip UNSUPPORTED = { "result": { @@ -89,13 +97,23 @@ async def test_discover_single(discovery_mock, custom_port, mocker): host = "127.0.0.1" discovery_mock.ip = host discovery_mock.port_override = custom_port - update_mock = mocker.patch.object(SmartStrip, "update") - x = await Discover.discover_single(host, port=custom_port) + device_class = Discover._get_device_class(discovery_mock.discovery_data) + update_mock = mocker.patch.object(device_class, "update") + + x = await Discover.discover_single( + host, port=custom_port, credentials=Credentials("", "") + ) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None assert x.port == custom_port or x.port == discovery_mock.default_port - assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) + assert update_mock.call_count == 0 + + ct = ConnectionType.from_values( + discovery_mock.device_type, discovery_mock.encrypt_type + ) + cp = ConnectionParameters(host=host, port=custom_port, connection_type=ct) + assert x.connection_parameters == cp async def test_discover_single_hostname(discovery_mock, mocker): @@ -104,47 +122,39 @@ async def test_discover_single_hostname(discovery_mock, mocker): ip = "127.0.0.1" discovery_mock.ip = ip - update_mock = mocker.patch.object(SmartStrip, "update") + device_class = Discover._get_device_class(discovery_mock.discovery_data) + update_mock = mocker.patch.object(device_class, "update") - x = await Discover.discover_single(host) + x = await Discover.discover_single(host, credentials=Credentials("", "")) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None assert x.host == host - assert (update_mock.call_count > 0) == isinstance(x, SmartStrip) + assert update_mock.call_count == 0 mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) with pytest.raises(SmartDeviceException): - x = await Discover.discover_single(host) + x = await Discover.discover_single(host, credentials=Credentials("", "")) -async def test_discover_single_unsupported(mocker): +async def test_discover_single_unsupported(unsupported_device_info, mocker): """Make sure that discover_single handles unsupported devices correctly.""" host = "127.0.0.1" - def mock_discover(self): - if discovery_data: - data = ( - b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" - + json_dumps(discovery_data).encode() - ) - self.datagram_received(data, (host, 20002)) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) - # Test with a valid unsupported response - discovery_data = UNSUPPORTED with pytest.raises( UnsupportedDeviceException, - match=f"Unsupported device {host} of type SMART.TAPOXMASTREE: {re.escape(str(UNSUPPORTED))}", ): await Discover.discover_single(host) - # Test with no response - discovery_data = None + +async def test_discover_single_no_response(mocker): + """Make sure that discover_single handles no response correctly.""" + host = "127.0.0.1" + mocker.patch.object(_DiscoverProtocol, "do_discover") with pytest.raises( SmartDeviceException, match=f"Timed out getting discovery response for {host}" ): - await Discover.discover_single(host, timeout=0.001) + await Discover.discover_single(host, discovery_timeout=0) INVALIDS = [ @@ -241,52 +251,83 @@ async def test_discover_invalid_responses(msg, data, mocker): } -async def test_discover_single_authentication(mocker): +@new_discovery +async def test_discover_single_authentication(discovery_mock, mocker): """Make sure that discover_single handles authenticating devices correctly.""" host = "127.0.0.1" - - def mock_discover(self): - if discovery_data: - data = ( - b"\x02\x00\x00\x01\x01[\x00\x00\x00\x00\x00\x00W\xcev\xf8" - + json_dumps(discovery_data).encode() - ) - self.datagram_received(data, (host, 20002)) - - mocker.patch.object(_DiscoverProtocol, "do_discover", mock_discover) + discovery_mock.ip = host + device_class = Discover._get_device_class(discovery_mock.discovery_data) mocker.patch.object( - SmartDevice, + device_class, "update", side_effect=AuthenticationException("Failed to authenticate"), ) - # Test with a valid unsupported response - discovery_data = AUTHENTICATION_DATA_KLAP with pytest.raises( AuthenticationException, match="Failed to authenticate", ): - device = await Discover.discover_single(host) + device = await Discover.discover_single( + host, credentials=Credentials("foo", "bar") + ) await device.update() - mocker.patch.object(SmartDevice, "update") - device = await Discover.discover_single(host) + mocker.patch.object(device_class, "update") + device = await Discover.discover_single(host, credentials=Credentials("foo", "bar")) await device.update() - assert device.device_type == DeviceType.Plug + assert isinstance(device, device_class) -async def test_device_update_from_new_discovery_info(): +@new_discovery +async def test_device_update_from_new_discovery_info(discovery_data): device = SmartDevice("127.0.0.7") - discover_info = DiscoveryResult(**AUTHENTICATION_DATA_KLAP["result"]) + discover_info = DiscoveryResult(**discovery_data["result"]) discover_dump = discover_info.get_dict() + discover_dump["alias"] = "foobar" + discover_dump["model"] = discover_dump["device_model"] device.update_from_discover_info(discover_dump) - assert device.alias == discover_dump["alias"] + assert device.alias == "foobar" assert device.mac == discover_dump["mac"].replace("-", ":") - assert device.model == discover_dump["model"] + assert device.model == discover_dump["device_model"] with pytest.raises( SmartDeviceException, match=re.escape("You need to await update() to access the data"), ): assert device.supported_modules + + +async def test_discover_single_http_client(discovery_mock, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + discovery_mock.ip = host + + http_client = httpx.AsyncClient() + + x = await Discover.discover_single(host) + if discovery_mock.default_port == 20002: + assert x.protocol._transport._http_client != http_client + + x = await Discover.discover_single(host, httpx_asyncclient=http_client) + if discovery_mock.default_port == 20002: + assert x.protocol._transport._http_client == http_client + + +async def test_discover_http_client(discovery_mock, mocker): + """Make sure that discover_single returns an initialized SmartDevice instance.""" + host = "127.0.0.1" + discovery_mock.ip = host + + http_client = httpx.AsyncClient() + + def gen(): + return http_client + + devs = await Discover.discover(discovery_timeout=0) + if discovery_mock.default_port == 20002: + assert devs[host].protocol._transport._http_client != http_client + + devs = await Discover.discover(discovery_timeout=0, http_client_generator=gen) + if discovery_mock.default_port == 20002: + assert devs[host].protocol._transport._http_client == http_client diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 1ed57ef22..6263757d6 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -11,10 +11,16 @@ import pytest from ..aestransport import AesTransport +from ..connectionparams import ConnectionParameters from ..credentials import Credentials from ..exceptions import AuthenticationException, SmartDeviceException from ..iotprotocol import IotProtocol -from ..klaptransport import KlapEncryptionSession, KlapTransport, _sha256 +from ..klaptransport import ( + KlapEncryptionSession, + KlapTransport, + KlapTransportV2, + _sha256, +) from ..smartprotocol import SmartProtocol DUMMY_QUERY = {"foobar": {"foo": "bar", "bar": "foo"}} @@ -42,8 +48,10 @@ async def test_protocol_retries( ): host = "127.0.0.1" conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) + + cparams = ConnectionParameters(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(cparams=cparams)).query( DUMMY_QUERY, retry_count=retry_count ) @@ -60,10 +68,11 @@ async def test_protocol_no_retry_on_connection_error( conn = mocker.patch.object( httpx.AsyncClient, "post", - side_effect=httpx.ConnectError("foo"), + side_effect=AuthenticationException("foo"), ) + cparams = ConnectionParameters(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(cparams=cparams)).query( DUMMY_QUERY, retry_count=5 ) @@ -81,8 +90,9 @@ async def test_protocol_retry_recoverable_error( "post", side_effect=httpx.CloseError("foo"), ) + cparams = ConnectionParameters(host) with pytest.raises(SmartDeviceException): - await protocol_class(host, transport=transport_class(host)).query( + await protocol_class(transport=transport_class(cparams=cparams)).query( DUMMY_QUERY, retry_count=5 ) @@ -115,7 +125,8 @@ def _fail_one_less_than_retry_count(*_, **__): side_effect=_fail_one_less_than_retry_count, ) - response = await protocol_class(host, transport=transport_class(host)).query( + cparams = ConnectionParameters(host) + response = await protocol_class(transport=transport_class(cparams=cparams)).query( DUMMY_QUERY, retry_count=retry_count ) assert "result" in response or "foobar" in response @@ -136,7 +147,9 @@ def _return_encrypted(*_, **__): seed = secrets.token_bytes(16) auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - protocol = IotProtocol("127.0.0.1", transport=KlapTransport("127.0.0.1")) + + cparams = ConnectionParameters("127.0.0.1") + protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) protocol._transport._handshake_done = True protocol._transport._session_expire_at = time.time() + 86400 @@ -196,30 +209,37 @@ def test_encrypt_unicode(): ], ids=("client", "blank", "kasa_setup", "shouldfail"), ) -async def test_handshake1(mocker, device_credentials, expectation): +@pytest.mark.parametrize( + "transport_class, seed_auth_hash_calc", + [ + pytest.param(KlapTransport, lambda c, s, a: c + a, id="KLAP"), + pytest.param(KlapTransportV2, lambda c, s, a: c + s + a, id="KLAPV2"), + ], +) +async def test_handshake1( + mocker, device_credentials, expectation, transport_class, seed_auth_hash_calc +): async def _return_handshake1_response(url, params=None, data=None, *_, **__): nonlocal client_seed, server_seed, device_auth_hash client_seed = data - client_seed_auth_hash = _sha256(data + device_auth_hash) - - return _mock_response(200, server_seed + client_seed_auth_hash) + seed_auth_hash = _sha256( + seed_auth_hash_calc(client_seed, server_seed, device_auth_hash) + ) + return _mock_response(200, server_seed + seed_auth_hash) client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(device_credentials) + device_auth_hash = transport_class.generate_auth_hash(device_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(cparams=cparams)) - protocol._transport.http_client = httpx.AsyncClient() with expectation: ( local_seed, @@ -233,31 +253,51 @@ async def _return_handshake1_response(url, params=None, data=None, *_, **__): await protocol.close() -async def test_handshake(mocker): +@pytest.mark.parametrize( + "transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2", + [ + pytest.param( + KlapTransport, lambda c, s, a: c + a, lambda c, s, a: s + a, id="KLAP" + ), + pytest.param( + KlapTransportV2, + lambda c, s, a: c + s + a, + lambda c, s, a: s + c + a, + id="KLAPV2", + ), + ], +) +async def test_handshake( + mocker, transport_class, seed_auth_hash_calc1, seed_auth_hash_calc2 +): async def _return_handshake_response(url, params=None, data=None, *_, **__): - nonlocal response_status, client_seed, server_seed, device_auth_hash + nonlocal client_seed, server_seed, device_auth_hash if url == "http://127.0.0.1/app/handshake1": client_seed = data - client_seed_auth_hash = _sha256(data + device_auth_hash) + seed_auth_hash = _sha256( + seed_auth_hash_calc1(client_seed, server_seed, device_auth_hash) + ) - return _mock_response(200, server_seed + client_seed_auth_hash) + return _mock_response(200, server_seed + seed_auth_hash) elif url == "http://127.0.0.1/app/handshake2": + seed_auth_hash = _sha256( + seed_auth_hash_calc2(client_seed, server_seed, device_auth_hash) + ) + assert data == seed_auth_hash return _mock_response(response_status, b"") client_seed = None server_seed = secrets.token_bytes(16) client_credentials = Credentials("foo", "bar") - device_auth_hash = KlapTransport.generate_auth_hash(client_credentials) + device_auth_hash = transport_class.generate_auth_hash(client_credentials) mocker.patch.object( httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(cparams=cparams)) protocol._transport.http_client = httpx.AsyncClient() response_status = 200 @@ -273,7 +313,7 @@ async def _return_handshake_response(url, params=None, data=None, *_, **__): async def test_query(mocker): async def _return_response(url, params=None, data=None, *_, **__): - nonlocal client_seed, server_seed, device_auth_hash, protocol, seq + nonlocal client_seed, server_seed, device_auth_hash, seq if url == "http://127.0.0.1/app/handshake1": client_seed = data @@ -303,10 +343,8 @@ async def _return_response(url, params=None, data=None, *_, **__): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) for _ in range(10): resp = await protocol.query({}) @@ -350,10 +388,8 @@ async def _return_response(url, params=None, data=None, *_, **__): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - protocol = IotProtocol( - "127.0.0.1", - transport=KlapTransport("127.0.0.1", credentials=client_credentials), - ) + cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 7bd6342b4..550e2e758 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -9,6 +9,7 @@ import pytest +from ..connectionparams import ConnectionParameters from ..exceptions import SmartDeviceException from ..protocol import ( BaseTransport, @@ -31,10 +32,11 @@ def aio_mock_writer(_, __): return reader, writer conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) + cparams = ConnectionParameters("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=retry_count) + await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + {}, retry_count=retry_count + ) assert conn.call_count == retry_count + 1 @@ -44,10 +46,11 @@ async def test_protocol_no_retry_on_unreachable(mocker): "asyncio.open_connection", side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), ) + cparams = ConnectionParameters("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + {}, retry_count=5 + ) assert conn.call_count == 1 @@ -57,10 +60,11 @@ async def test_protocol_no_retry_connection_refused(mocker): "asyncio.open_connection", side_effect=ConnectionRefusedError, ) + cparams = ConnectionParameters("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + {}, retry_count=5 + ) assert conn.call_count == 1 @@ -70,10 +74,11 @@ async def test_protocol_retry_recoverable_error(mocker): "asyncio.open_connection", side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), ) + cparams = ConnectionParameters("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ).query({}, retry_count=5) + await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + {}, retry_count=5 + ) assert conn.call_count == 6 @@ -107,9 +112,8 @@ def aio_mock_writer(_, __): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ) + cparams = ConnectionParameters("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}, retry_count=retry_count) assert response == {"great": "success"} @@ -137,9 +141,8 @@ def aio_mock_writer(_, __): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1") - ) + cparams = ConnectionParameters("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -173,9 +176,8 @@ def aio_mock_writer(_, port): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - protocol = TPLinkSmartHomeProtocol( - "127.0.0.1", transport=_XorTransport("127.0.0.1", port=custom_port) - ) + cparams = ConnectionParameters("127.0.0.1", port=custom_port) + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -271,18 +273,14 @@ def _get_subclasses(of_class): def test_protocol_init_signature(class_name_obj): params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) - assert len(params) == 3 + assert len(params) == 2 assert ( params[0].name == "self" and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) assert ( - params[1].name == "host" - and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - assert ( - params[2].name == "transport" - and params[2].kind == inspect.Parameter.KEYWORD_ONLY + params[1].name == "transport" + and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) @@ -292,20 +290,11 @@ def test_protocol_init_signature(class_name_obj): def test_transport_init_signature(class_name_obj): params = list(inspect.signature(class_name_obj[1].__init__).parameters.values()) - assert len(params) == 5 + assert len(params) == 2 assert ( params[0].name == "self" and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) assert ( - params[1].name == "host" - and params[1].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ) - assert params[2].name == "port" and params[2].kind == inspect.Parameter.KEYWORD_ONLY - assert ( - params[3].name == "credentials" - and params[3].kind == inspect.Parameter.KEYWORD_ONLY - ) - assert ( - params[4].name == "timeout" and params[4].kind == inspect.Parameter.KEYWORD_ONLY + params[1].name == "cparams" and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 47f523d00..46b9220c2 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -5,7 +5,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import kasa -from kasa import Credentials, SmartDevice, SmartDeviceException +from kasa import ConnectionParameters, Credentials, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on @@ -238,22 +238,18 @@ async def test_create_smart_device_with_timeout(): async def test_create_thin_wrapper(): """Make sure thin wrapper is created with the correct device type.""" mock = Mock() - with patch("kasa.device_factory.connect", return_value=mock) as connect: - dev = await SmartDevice.connect( - host="test_host", - port=1234, - timeout=100, - credentials=Credentials("username", "password"), - device_type=DeviceType.Strip, - ) - assert dev is mock - - connect.assert_called_once_with( + cparams = ConnectionParameters( host="test_host", port=1234, timeout=100, credentials=Credentials("username", "password"), - device_type=DeviceType.Strip, + ) + with patch("kasa.device_factory.connect", return_value=mock) as connect: + dev = await SmartDevice.connect(cparams=cparams) + assert dev is mock + + connect.assert_called_once_with( + cparams=cparams, ) From 20c868ff7c55d6ceb66fad301fe6d3edfb0703f2 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Wed, 20 Dec 2023 12:03:23 +0000 Subject: [PATCH 2/5] Update post review --- kasa/__init__.py | 8 +- kasa/aestransport.py | 25 ++++--- kasa/cli.py | 6 +- kasa/credentials.py | 4 +- kasa/device_factory.py | 30 ++++---- kasa/device_type.py | 0 kasa/{connectionparams.py => deviceconfig.py} | 7 +- kasa/discover.py | 75 +++++++------------ kasa/iotprotocol.py | 4 + kasa/klaptransport.py | 24 +++--- kasa/protocol.py | 24 +++--- kasa/protocolfactory.py | 10 +-- kasa/smartdevice.py | 22 +++--- kasa/smartprotocol.py | 6 ++ kasa/tapo/tapodevice.py | 6 +- kasa/tests/conftest.py | 2 +- kasa/tests/newfakes.py | 6 +- kasa/tests/test_aestransport.py | 11 +-- kasa/tests/test_connectionparams.py | 21 ------ kasa/tests/test_device_factory.py | 34 ++++----- kasa/tests/test_deviceconfig.py | 21 ++++++ kasa/tests/test_discovery.py | 40 +++++----- kasa/tests/test_klapprotocol.py | 43 +++++------ kasa/tests/test_protocol.py | 32 ++++---- kasa/tests/test_smartdevice.py | 8 +- kasa/tests/test_smartprotocol.py | 8 +- 26 files changed, 244 insertions(+), 233 deletions(-) mode change 100644 => 100755 kasa/device_factory.py mode change 100644 => 100755 kasa/device_type.py rename kasa/{connectionparams.py => deviceconfig.py} (95%) delete mode 100644 kasa/tests/test_connectionparams.py create mode 100644 kasa/tests/test_deviceconfig.py diff --git a/kasa/__init__.py b/kasa/__init__.py index 61e367244..f5b795bdc 100755 --- a/kasa/__init__.py +++ b/kasa/__init__.py @@ -13,13 +13,13 @@ """ from importlib.metadata import version -from kasa.connectionparams import ( - ConnectionParameters, +from kasa.credentials import Credentials +from kasa.deviceconfig import ( ConnectionType, + DeviceConfig, DeviceFamilyType, EncryptType, ) -from kasa.credentials import Credentials from kasa.discover import Discover from kasa.emeterstatus import EmeterStatus from kasa.exceptions import ( @@ -61,7 +61,7 @@ "AuthenticationException", "UnsupportedDeviceException", "Credentials", - "ConnectionParameters", + "DeviceConfig", "ConnectionType", "EncryptType", "DeviceFamilyType", diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 60c0df794..117af63bc 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -16,7 +16,7 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .connectionparams import ConnectionParameters +from .deviceconfig import DeviceConfig from .exceptions import ( SMART_AUTHENTICATION_ERRORS, SMART_RETRYABLE_ERRORS, @@ -58,14 +58,13 @@ class AesTransport(BaseTransport): def __init__( self, *, - cparams: ConnectionParameters, + config: DeviceConfig, ) -> None: - super().__init__(cparams=cparams) - self._port = cparams.port or self.DEFAULT_PORT + super().__init__(config=config) + self._port = config.port or self.DEFAULT_PORT + + self._default_http_client: Optional[httpx.AsyncClient] = None - self._http_client: httpx.AsyncClient = ( - cparams.http_client or httpx.AsyncClient() - ) self._handshake_done = False self._encryption_session: Optional[AesEncyptionSession] = None @@ -77,6 +76,14 @@ def __init__( _LOGGER.debug("Created AES transport for %s", self._host) + @property + def _http_client(self) -> httpx.AsyncClient: + if self._config.http_client: + return self._config.http_client + if not self._default_http_client: + self._default_http_client = httpx.AsyncClient() + return self._default_http_client + def hash_credentials(self, login_v2): """Hash the credentials.""" if login_v2: @@ -259,8 +266,8 @@ async def send(self, request: str): async def close(self) -> None: """Close the protocol.""" - client = self._http_client - self._http_client = None + client = self._default_http_client + self._default_http_client = None self._handshake_done = False self._login_token = None if client: diff --git a/kasa/cli.py b/kasa/cli.py index 162e94865..821e78693 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -12,9 +12,9 @@ from kasa import ( AuthenticationException, - ConnectionParameters, ConnectionType, Credentials, + DeviceConfig, DeviceFamilyType, Discover, EncryptType, @@ -305,10 +305,10 @@ def _nop_echo(*args, **kwargs): DeviceFamilyType(device_family), EncryptType(encrypt_type), ) - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=credentials, timeout=timeout, connection_type=ctype ) - dev = await SmartDevice.connect(cparams=cparams) + dev = await SmartDevice.connect(config=config) else: echo("No --type or --device-family and --encrypt-type defined, discovering..") dev = await Discover.discover_single( diff --git a/kasa/credentials.py b/kasa/credentials.py index a56f5710d..4ae4df356 100644 --- a/kasa/credentials.py +++ b/kasa/credentials.py @@ -8,5 +8,5 @@ class Credentials: """Credentials for authentication.""" - username: Optional[str] = field(default=None, repr=False) - password: Optional[str] = field(default=None, repr=False) + username: Optional[str] = field(default="", repr=False) + password: Optional[str] = field(default="", repr=False) diff --git a/kasa/device_factory.py b/kasa/device_factory.py old mode 100644 new mode 100755 index 26f7538a8..254c4509c --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -1,9 +1,9 @@ -"""Device creation via ConnectionParameters.""" +"""Device creation via DeviceConfig.""" import logging import time from typing import Any, Dict, Optional, Type -from kasa.connectionparams import ConnectionParameters +from kasa.deviceconfig import DeviceConfig from kasa.protocol import TPLinkSmartHomeProtocol from kasa.smartbulb import SmartBulb from kasa.smartdevice import SmartDevice @@ -23,7 +23,7 @@ } -async def connect(*, cparams: ConnectionParameters) -> "SmartDevice": +async def connect(*, config: DeviceConfig) -> "SmartDevice": """Connect to a single device by the given connection parameters. Do not use this function directly, use SmartDevice.Connect() @@ -37,15 +37,15 @@ def _perf_log(has_params, perf_type): if debug_enabled: end_time = time.perf_counter() _LOGGER.debug( - f"Device {cparams.host} with connection params {has_params} " + f"Device {config.host} with connection params {has_params} " + f"took {end_time - start_time:.2f} seconds to {perf_type}", ) start_time = time.perf_counter() - if (protocol := get_protocol(cparams=cparams)) is None: + if (protocol := get_protocol(config=config)) is None: raise UnsupportedDeviceException( - f"Unsupported device for {cparams.host}: " - + f"{cparams.connection_type.device_family.value}" + f"Unsupported device for {config.host}: " + + f"{config.connection_type.device_family.value}" ) device_class: Optional[Type[SmartDevice]] @@ -54,20 +54,20 @@ def _perf_log(has_params, perf_type): info = await protocol.query(GET_SYSINFO_QUERY) _perf_log(True, "get_sysinfo") device_class = get_device_class_from_sys_info(info) - device = device_class(cparams.host, port=cparams.port, timeout=cparams.timeout) + device = device_class(config.host, port=config.port, timeout=config.timeout) device.update_from_discover_info(info) device.protocol = protocol await device.update() _perf_log(True, "update") return device elif device_class := get_device_class_from_family( - cparams.connection_type.device_family.value + config.connection_type.device_family.value ): device = device_class( - cparams.host, - port=cparams.port, - timeout=cparams.timeout, - credentials=cparams.credentials, + config.host, + port=config.port, + timeout=config.timeout, + credentials=config.credentials, ) device.protocol = protocol await device.update() @@ -75,8 +75,8 @@ def _perf_log(has_params, perf_type): return device else: raise UnsupportedDeviceException( - f"Unsupported device for {cparams.host}: " - + f"{cparams.connection_type.device_family.value}" + f"Unsupported device for {config.host}: " + + f"{config.connection_type.device_family.value}" ) diff --git a/kasa/device_type.py b/kasa/device_type.py old mode 100644 new mode 100755 diff --git a/kasa/connectionparams.py b/kasa/deviceconfig.py similarity index 95% rename from kasa/connectionparams.py rename to kasa/deviceconfig.py index 27c85423d..c6994286b 100644 --- a/kasa/connectionparams.py +++ b/kasa/deviceconfig.py @@ -109,7 +109,7 @@ def to_dict(self) -> Dict[str, str]: @dataclass -class ConnectionParameters: +class DeviceConfig: """Class to represent paramaters that determine how to connect to devices.""" DEFAULT_TIMEOUT = 5 @@ -126,6 +126,7 @@ class ConnectionParameters: ) ) + uses_http: bool = False # compare=False will be excluded from the serialization and object comparison. http_client: Optional[httpx.AsyncClient] = field(default=None, compare=False) @@ -142,6 +143,6 @@ def to_dict(self) -> Dict[str, Dict[str, str]]: return _dataclass_to_dict(self) @staticmethod - def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "ConnectionParameters": + def from_dict(cparam_dict: Dict[str, Dict[str, str]]) -> "DeviceConfig": """Return connection parameters from dict.""" - return _dataclass_from_dict(ConnectionParameters, cparam_dict) + return _dataclass_from_dict(DeviceConfig, cparam_dict) diff --git a/kasa/discover.py b/kasa/discover.py index a299fd122..21a681752 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -7,23 +7,21 @@ import socket from typing import Awaitable, Callable, Dict, Optional, Set, Type, cast -import httpx - # When support for cpython older than 3.11 is dropped # async_timeout can be replaced with asyncio.timeout from async_timeout import timeout as asyncio_timeout try: - from pydantic.v1 import BaseModel, ValidationError + from pydantic.v1 import BaseModel, ValidationError # pragma: no cover except ImportError: - from pydantic import BaseModel, ValidationError + from pydantic import BaseModel, ValidationError # pragma: no cover -from kasa.connectionparams import ConnectionParameters, ConnectionType, EncryptType from kasa.credentials import Credentials from kasa.device_factory import ( get_device_class_from_family, get_device_class_from_sys_info, ) +from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads @@ -63,7 +61,6 @@ def __init__( discovered_event: Optional[asyncio.Event] = None, credentials: Optional[Credentials] = None, timeout: Optional[int] = None, - http_client_generator: Optional[Callable[[], httpx.AsyncClient]] = None, ) -> None: self.transport = None self.discovery_packets = discovery_packets @@ -83,9 +80,6 @@ def __init__( self.credentials = credentials self.timeout = timeout self.seen_hosts: Set[str] = set() - self.http_client_generator: Optional[ - Callable[[], httpx.AsyncClient] - ] = http_client_generator def connection_made(self, transport) -> None: """Set socket options for broadcasting.""" @@ -124,18 +118,17 @@ def datagram_received(self, data, addr) -> None: device = None - cparams = ConnectionParameters(host=ip, port=self.port) + config = DeviceConfig(host=ip, port=self.port) if self.credentials: - cparams.credentials = self.credentials + config.credentials = self.credentials if self.timeout: - cparams.timeout = self.timeout + config.timeout = self.timeout try: if port == self.discovery_port: - device = Discover._get_device_instance_legacy(data, cparams) + device = Discover._get_device_instance_legacy(data, config) elif port == Discover.DISCOVERY_PORT_2: - if self.http_client_generator: - cparams.http_client = self.http_client_generator() - device = Discover._get_device_instance(data, cparams) + config.uses_http = True + device = Discover._get_device_instance(data, config) else: return except UnsupportedDeviceException as udex: @@ -226,7 +219,6 @@ async def discover( credentials=None, port=None, timeout=None, - http_client_generator: Optional[Callable[[], httpx.AsyncClient]] = None, ) -> DeviceDict: """Discover supported devices. @@ -263,7 +255,6 @@ async def discover( credentials=credentials, timeout=timeout, port=port, - http_client_generator=http_client_generator, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) @@ -287,7 +278,6 @@ async def discover_single( port: Optional[int] = None, timeout: Optional[int] = None, credentials: Optional[Credentials] = None, - httpx_asyncclient: httpx.AsyncClient = None, ) -> SmartDevice: """Discover a single device by the given IP address. @@ -337,9 +327,6 @@ async def discover_single( discovered_event=event, credentials=credentials, timeout=timeout, - http_client_generator=lambda: httpx_asyncclient - if httpx_asyncclient - else None, ), local_addr=("0.0.0.0", 0), # noqa: S104 ) @@ -362,8 +349,6 @@ async def discover_single( if ip in protocol.discovered_devices: dev = protocol.discovered_devices[ip] dev.host = host - if httpx_asyncclient and hasattr(dev.protocol._transport, "http_client"): - dev.protocol._transport.http_client = httpx_asyncclient # type: ignore[union-attr] return dev elif ip in protocol.unsupported_device_exceptions: raise protocol.unsupported_device_exceptions[ip] @@ -388,86 +373,84 @@ def _get_device_class(info: dict) -> Type[SmartDevice]: return get_device_class_from_sys_info(info) @staticmethod - def _get_device_instance_legacy( - data: bytes, cparams: ConnectionParameters - ) -> SmartDevice: + def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevice: """Get SmartDevice from legacy 9999 response.""" try: info = json_loads(TPLinkSmartHomeProtocol.decrypt(data)) except Exception as ex: raise SmartDeviceException( - f"Unable to read response from device: {cparams.host}: {ex}" + f"Unable to read response from device: {config.host}: {ex}" ) from ex - _LOGGER.debug("[DISCOVERY] %s << %s", cparams.host, info) + _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) device_class = Discover._get_device_class(info) - device = device_class(cparams.host, port=cparams.port) + device = device_class(config.host, port=config.port) sys_info = info["system"]["get_sysinfo"] if (device_type := sys_info.get("mic_type")) or ( device_type := sys_info.get("type") ): - cparams.connection_type = ConnectionType.from_values( + config.connection_type = ConnectionType.from_values( device_family=device_type, encryption_type=EncryptType.Xor.value ) - device.protocol = get_protocol(cparams) # type: ignore[assignment] + device.protocol = get_protocol(config) # type: ignore[assignment] device.update_from_discover_info(info) return device @staticmethod def _get_device_instance( data: bytes, - cparams: ConnectionParameters, + config: DeviceConfig, ) -> SmartDevice: """Get SmartDevice from the new 20002 response.""" try: info = json_loads(data[16:]) except Exception as ex: - _LOGGER.debug("Got invalid response from device %s: %s", cparams.host, data) + _LOGGER.debug("Got invalid response from device %s: %s", config.host, data) raise SmartDeviceException( - f"Unable to read response from device: {cparams.host}: {ex}" + f"Unable to read response from device: {config.host}: {ex}" ) from ex try: discovery_result = DiscoveryResult(**info["result"]) except ValidationError as ex: _LOGGER.debug( - "Unable to parse discovery from device %s: %s", cparams.host, info + "Unable to parse discovery from device %s: %s", config.host, info ) raise UnsupportedDeviceException( - f"Unable to parse discovery from device: {cparams.host}: {ex}" + f"Unable to parse discovery from device: {config.host}: {ex}" ) from ex type_ = discovery_result.device_type try: - cparams.connection_type = ConnectionType.from_values( + config.connection_type = ConnectionType.from_values( type_, discovery_result.mgt_encrypt_schm.encrypt_type ) except SmartDeviceException as ex: raise UnsupportedDeviceException( - f"Unsupported device {cparams.host} of type {type_} " + f"Unsupported device {config.host} of type {type_} " + f"with encrypt_type {discovery_result.mgt_encrypt_schm.encrypt_type}", discovery_result=discovery_result.get_dict(), ) from ex if (device_class := get_device_class_from_family(type_)) is None: _LOGGER.warning("Got unsupported device type: %s", type_) raise UnsupportedDeviceException( - f"Unsupported device {cparams.host} of type {type_}: {info}", + f"Unsupported device {config.host} of type {type_}: {info}", discovery_result=discovery_result.get_dict(), ) - if (protocol := get_protocol(cparams)) is None: + if (protocol := get_protocol(config)) is None: _LOGGER.warning( - "Got unsupported connection type: %s", cparams.connection_type.to_dict() + "Got unsupported connection type: %s", config.connection_type.to_dict() ) raise UnsupportedDeviceException( - f"Unsupported encryption scheme {cparams.host} of " - + f"type {cparams.connection_type.to_dict()}: {info}", + f"Unsupported encryption scheme {config.host} of " + + f"type {config.connection_type.to_dict()}: {info}", discovery_result=discovery_result.get_dict(), ) - _LOGGER.debug("[DISCOVERY] %s << %s", cparams.host, info) + _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) device = device_class( - cparams.host, port=cparams.port, credentials=cparams.credentials + config.host, port=config.port, credentials=config.credentials ) device.protocol = protocol diff --git a/kasa/iotprotocol.py b/kasa/iotprotocol.py index e78d24e74..470f40552 100755 --- a/kasa/iotprotocol.py +++ b/kasa/iotprotocol.py @@ -40,16 +40,19 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: return await self._execute_query(request, retry) except httpx.ConnectError as sdex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" ) from sdex continue except TimeoutError as tex: + await self.close() raise SmartDeviceException( f"Unable to connect to the device, timed out: {self._host}: {tex}" ) from tex except AuthenticationException as auex: + await self.close() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) @@ -63,6 +66,7 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict: raise ex except Exception as ex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {ex}" diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 6856521b5..3704f0efb 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -53,8 +53,8 @@ from cryptography.hazmat.primitives import hashes, padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .connectionparams import ConnectionParameters from .credentials import Credentials +from .deviceconfig import DeviceConfig from .exceptions import AuthenticationException, SmartDeviceException from .json import loads as json_loads from .protocol import BaseTransport, md5 @@ -93,13 +93,11 @@ class KlapTransport(BaseTransport): def __init__( self, *, - cparams: ConnectionParameters, + config: DeviceConfig, ) -> None: - super().__init__(cparams=cparams) - self._port = cparams.port or self.DEFAULT_PORT - self._http_client: httpx.AsyncClient = ( - cparams.http_client or httpx.AsyncClient() - ) + super().__init__(config=config) + self._port = config.port or self.DEFAULT_PORT + self._default_http_client: Optional[httpx.AsyncClient] = None self._local_seed: Optional[bytes] = None self._local_auth_hash = self.generate_auth_hash(self._credentials) self._local_auth_owner = self.generate_owner_hash(self._credentials).hex() @@ -116,6 +114,14 @@ def __init__( _LOGGER.debug("Created KLAP transport for %s", self._host) + @property + def _http_client(self) -> httpx.AsyncClient: + if self._config.http_client: + return self._config.http_client + if not self._default_http_client: + self._default_http_client = httpx.AsyncClient() + return self._default_http_client + async def client_post(self, url, params=None, data=None): """Send an http post request to the device.""" response_data = None @@ -349,8 +355,8 @@ async def send(self, request: str): async def close(self) -> None: """Close the transport.""" - client = self._http_client - self._http_client = None + client = self._default_http_client + self._default_http_client = None self._handshake_done = False if client: await client.aclose() diff --git a/kasa/protocol.py b/kasa/protocol.py index 501b33fe5..4c9570637 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -24,7 +24,7 @@ from async_timeout import timeout as asyncio_timeout from cryptography.hazmat.primitives import hashes -from .connectionparams import ConnectionParameters +from .deviceconfig import DeviceConfig from .exceptions import SmartDeviceException from .json import dumps as json_dumps from .json import loads as json_loads @@ -49,14 +49,14 @@ class BaseTransport(ABC): def __init__( self, *, - cparams: ConnectionParameters, + config: DeviceConfig, ) -> None: """Create a protocol object.""" - self._cparams = cparams - self._host = cparams.host - self._port = cparams.port # Set by derived classes - self._credentials = cparams.credentials - self._timeout = cparams.timeout + self._config = config + self._host = config.host + self._port = config.port # Set by derived classes + self._credentials = config.credentials + self._timeout = config.timeout @abstractmethod async def send(self, request: str) -> Dict: @@ -83,9 +83,9 @@ def _host(self): return self._transport._host @property - def connection_parameters(self) -> ConnectionParameters: + def config(self) -> DeviceConfig: """Return the connection parameters the device is using.""" - return self._transport._cparams + return self._transport._config @abstractmethod async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: @@ -107,9 +107,9 @@ class _XorTransport(BaseTransport): DEFAULT_PORT = 9999 - def __init__(self, *, cparams: ConnectionParameters) -> None: - super().__init__(cparams=cparams) - self._port = cparams.port or self.DEFAULT_PORT + def __init__(self, *, config: DeviceConfig) -> None: + super().__init__(config=config) + self._port = config.port or self.DEFAULT_PORT async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" diff --git a/kasa/protocolfactory.py b/kasa/protocolfactory.py index 5d76bd3db..d30a986eb 100644 --- a/kasa/protocolfactory.py +++ b/kasa/protocolfactory.py @@ -2,7 +2,7 @@ from typing import Optional, Tuple, Type from .aestransport import AesTransport -from .connectionparams import ConnectionParameters +from .deviceconfig import DeviceConfig from .iotprotocol import IotProtocol from .klaptransport import KlapTransport, KlapTransportV2 from .protocol import ( @@ -15,12 +15,12 @@ def get_protocol( - cparams: ConnectionParameters, + config: DeviceConfig, ) -> Optional[TPLinkProtocol]: """Return the protocol from the connection name.""" - protocol_name = cparams.connection_type.device_family.value.split(".")[0] + protocol_name = config.connection_type.device_family.value.split(".")[0] protocol_transport_key = ( - protocol_name + "." + cparams.connection_type.encryption_type.value + protocol_name + "." + config.connection_type.encryption_type.value ) supported_device_protocols: dict[ str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] @@ -36,4 +36,4 @@ def get_protocol( protocol_class, transport_class = supported_device_protocols.get( protocol_transport_key ) # type: ignore - return protocol_class(transport=transport_class(cparams=cparams)) + return protocol_class(transport=transport_class(config=config)) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index bbc7e9a0c..d7a7be620 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -19,9 +19,9 @@ from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Set -from .connectionparams import ConnectionParameters from .credentials import Credentials from .device_type import DeviceType +from .deviceconfig import DeviceConfig from .emeterstatus import EmeterStatus from .exceptions import SmartDeviceException from .modules import Emeter, Module @@ -202,9 +202,9 @@ def __init__( """ self.host = host self.port = port - cparams = ConnectionParameters(host=host, port=port, timeout=timeout) + config = DeviceConfig(host=host, port=port, timeout=timeout) self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( - transport=_XorTransport(cparams=cparams), + transport=_XorTransport(config=config), ) self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) @@ -774,15 +774,15 @@ def __repr__(self): ) @property - def connection_parameters(self) -> ConnectionParameters: + def config(self) -> DeviceConfig: """Return the connection parameters the device is using.""" - return self.protocol.connection_parameters + return self.protocol.config @staticmethod async def connect( *, host: Optional[str] = None, - cparams: Optional[ConnectionParameters] = None, + config: Optional[DeviceConfig] = None, ) -> "SmartDevice": """Connect to a single device by the given hostname or connection parameters. @@ -797,17 +797,17 @@ async def connect( The device type is discovered by querying the device. :param host: Hostname of device to query - :param cparams: Connection parameters to ensure the correct protocol + :param config: Connection parameters to ensure the correct protocol and connection options are used. :rtype: SmartDevice :return: Object for querying/controlling found device. """ from .device_factory import connect # pylint: disable=import-outside-toplevel - if host and cparams or (not host and not cparams): + if host and config or (not host and not config): raise SmartDeviceException( - "One of host or cparams must be provded and not both" + "One of host or config must be provded and not both" ) if host: - cparams = ConnectionParameters(host=host) - return await connect(cparams=cparams) # type: ignore[arg-type] + config = DeviceConfig(host=host) + return await connect(config=config) # type: ignore[arg-type] diff --git a/kasa/smartprotocol.py b/kasa/smartprotocol.py index b15645645..97573d933 100644 --- a/kasa/smartprotocol.py +++ b/kasa/smartprotocol.py @@ -69,6 +69,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: return await self._execute_query(request, retry) except httpx.ConnectError as sdex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {sdex}" @@ -76,6 +77,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: continue except TimeoutError as tex: if retry >= retry_count: + await self.close() raise SmartDeviceException( "Unable to connect to the device, " + f"timed out: {self._host}: {tex}" @@ -83,17 +85,20 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) continue except AuthenticationException as auex: + await self.close() _LOGGER.debug( "Unable to authenticate with %s, not retrying", self._host ) raise auex except RetryableException as ex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex continue except TimeoutException as ex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise ex await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT) @@ -105,6 +110,7 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict: raise ex except Exception as ex: if retry >= retry_count: + await self.close() _LOGGER.debug("Giving up on %s after %s retries", self._host, retry) raise SmartDeviceException( f"Unable to connect to the device: {self._host}: {ex}" diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index c80542c4c..e42448595 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Optional, Set, cast from ..aestransport import AesTransport -from ..connectionparams import ConnectionParameters from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import AuthenticationException from ..smartdevice import SmartDevice from ..smartprotocol import SmartProtocol @@ -29,14 +29,14 @@ def __init__( self._components: Optional[Dict[str, Any]] = None self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - cparams = ConnectionParameters( + config = DeviceConfig( host=host, port=port, credentials=credentials, # type: ignore[arg-type] timeout=timeout, ) self.protocol = SmartProtocol( - transport=AesTransport(cparams=cparams), + transport=AesTransport(config=config), ) async def update(self, update_children: bool = True): diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 8c23e4c14..2320f1161 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -388,7 +388,7 @@ def load_file(): d = device_for_file(model, protocol)(host="127.0.0.123") if protocol == "SMART": d.protocol = FakeSmartProtocol(sysinfo) - d.credentials = Credentials("", "") + d.credentials = Credentials() else: d.protocol = FakeTransportProtocol(sysinfo) await _update_and_close(d) diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 4f9077ea4..53748700f 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -15,8 +15,8 @@ Schema, ) -from ..connectionparams import ConnectionParameters from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..protocol import BaseTransport, TPLinkSmartHomeProtocol from ..smartprotocol import SmartProtocol @@ -305,9 +305,7 @@ async def query(self, request, retry_count: int = 3): class FakeSmartTransport(BaseTransport): def __init__(self, info): super().__init__( - cparams=ConnectionParameters( - "127.0.0.123", credentials=Credentials("", "") - ), + config=DeviceConfig("127.0.0.123", credentials=Credentials()), ) self.info = info diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index f42181ae7..faf47a75e 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -11,8 +11,8 @@ from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding from ..aestransport import AesEncyptionSession, AesTransport -from ..connectionparams import ConnectionParameters from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, @@ -60,7 +60,7 @@ async def test_handshake( mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) transport = AesTransport( - cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) assert transport._encryption_session is None @@ -78,7 +78,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) transport = AesTransport( - cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 @@ -97,7 +97,7 @@ async def test_send(mocker, status_code, error_code, inner_error_code, expectati mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) transport = AesTransport( - cparams=ConnectionParameters(host, credentials=Credentials("foo", "bar")) + config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 @@ -125,7 +125,8 @@ async def test_passthrough_errors(mocker, error_code): mock_aes_device = MockAesDevice(host, 200, error_code, 0) mocker.patch.object(httpx.AsyncClient, "post", side_effect=mock_aes_device.post) - transport = AesTransport(host=host, credentials=Credentials("foo", "bar")) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + transport = AesTransport(config=config) transport._handshake_done = True transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session diff --git a/kasa/tests/test_connectionparams.py b/kasa/tests/test_connectionparams.py deleted file mode 100644 index baede2e95..000000000 --- a/kasa/tests/test_connectionparams.py +++ /dev/null @@ -1,21 +0,0 @@ -from json import dumps as json_dumps -from json import loads as json_loads - -import httpx - -from kasa.connectionparams import ( - ConnectionParameters, - ConnectionType, - DeviceFamilyType, - EncryptType, -) -from kasa.credentials import Credentials - - -def test_serialization(): - cp = ConnectionParameters(host="Foo", http_client=httpx.AsyncClient()) - cp_dict = cp.to_dict() - cp_json = json_dumps(cp_dict) - cp2_dict = json_loads(cp_json) - cp2 = ConnectionParameters.from_dict(cp2_dict) - assert cp == cp2 diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index aa00d2728..f8e2aa01d 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -16,13 +16,13 @@ SmartLightStrip, SmartPlug, ) -from kasa.connectionparams import ( - ConnectionParameters, +from kasa.device_factory import connect +from kasa.deviceconfig import ( ConnectionType, + DeviceConfig, DeviceFamilyType, EncryptType, ) -from kasa.device_factory import connect from kasa.discover import DiscoveryResult from kasa.protocolfactory import get_protocol @@ -57,18 +57,18 @@ async def test_connect( mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) - protocol_class = get_protocol(cparams).__class__ + protocol_class = get_protocol(config).__class__ dev = await connect( - cparams=cparams, + config=config, ) assert isinstance(dev, device_class) assert isinstance(dev.protocol, protocol_class) - assert dev.connection_parameters == cparams + assert dev.config == config @pytest.mark.parametrize("custom_port", [123, None]) @@ -77,14 +77,14 @@ async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): host = "127.0.0.1" ctype, _ = _get_connection_type_device_class(all_fixture_data) - cparams = ConnectionParameters(host=host, port=custom_port, connection_type=ctype) + config = DeviceConfig(host=host, port=custom_port, connection_type=ctype) default_port = 80 if "discovery_result" in all_fixture_data else 9999 ctype, _ = _get_connection_type_device_class(all_fixture_data) mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.IotProtocol.query", return_value=all_fixture_data) mocker.patch("kasa.SmartProtocol.query", return_value=all_fixture_data) - dev = await connect(cparams=cparams) + dev = await connect(config=config) assert issubclass(dev.__class__, SmartDevice) assert dev.port == custom_port or dev.port == default_port @@ -99,12 +99,12 @@ async def test_connect_logs_connect_time( mocker.patch("kasa.TPLinkSmartHomeProtocol.query", return_value=all_fixture_data) host = "127.0.0.1" - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) logging.getLogger("kasa").setLevel(logging.DEBUG) await connect( - cparams=cparams, + config=config, ) assert "seconds to update" in caplog.text @@ -117,11 +117,11 @@ async def test_connect_query_fails(all_fixture_data: dict, mocker): mocker.patch("kasa.SmartProtocol.query", side_effect=SmartDeviceException) ctype, _ = _get_connection_type_device_class(all_fixture_data) - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) with pytest.raises(SmartDeviceException): - await connect(cparams=cparams) + await connect(config=config) async def test_connect_http_client(all_fixture_data, mocker): @@ -136,19 +136,19 @@ async def test_connect_http_client(all_fixture_data, mocker): http_client = httpx.AsyncClient() - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype ) - dev = await connect(cparams=cparams) + dev = await connect(config=config) if ctype.encryption_type != EncryptType.Xor: assert dev.protocol._transport._http_client != http_client - cparams = ConnectionParameters( + config = DeviceConfig( host=host, credentials=Credentials("foor", "bar"), connection_type=ctype, http_client=http_client, ) - dev = await connect(cparams=cparams) + dev = await connect(config=config) if ctype.encryption_type != EncryptType.Xor: assert dev.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_deviceconfig.py b/kasa/tests/test_deviceconfig.py new file mode 100644 index 000000000..7970449dd --- /dev/null +++ b/kasa/tests/test_deviceconfig.py @@ -0,0 +1,21 @@ +from json import dumps as json_dumps +from json import loads as json_loads + +import httpx + +from kasa.credentials import Credentials +from kasa.deviceconfig import ( + ConnectionType, + DeviceConfig, + DeviceFamilyType, + EncryptType, +) + + +def test_serialization(): + config = DeviceConfig(host="Foo", http_client=httpx.AsyncClient()) + config_dict = config.to_dict() + config_json = json_dumps(config_dict) + config2_dict = json_loads(config_json) + config2 = DeviceConfig.from_dict(config2_dict) + assert config == config2 diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index 970b750ec..d9b4631cf 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -14,9 +14,9 @@ SmartDeviceException, protocol, ) -from kasa.connectionparams import ( - ConnectionParameters, +from kasa.deviceconfig import ( ConnectionType, + DeviceConfig, DeviceFamilyType, EncryptType, ) @@ -102,7 +102,7 @@ async def test_discover_single(discovery_mock, custom_port, mocker): update_mock = mocker.patch.object(device_class, "update") x = await Discover.discover_single( - host, port=custom_port, credentials=Credentials("", "") + host, port=custom_port, credentials=Credentials() ) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None @@ -112,8 +112,11 @@ async def test_discover_single(discovery_mock, custom_port, mocker): ct = ConnectionType.from_values( discovery_mock.device_type, discovery_mock.encrypt_type ) - cp = ConnectionParameters(host=host, port=custom_port, connection_type=ct) - assert x.connection_parameters == cp + uses_http = discovery_mock.default_port == 20002 + config = DeviceConfig( + host=host, port=custom_port, connection_type=ct, uses_http=uses_http + ) + assert x.config == config async def test_discover_single_hostname(discovery_mock, mocker): @@ -125,7 +128,7 @@ async def test_discover_single_hostname(discovery_mock, mocker): device_class = Discover._get_device_class(discovery_mock.discovery_data) update_mock = mocker.patch.object(device_class, "update") - x = await Discover.discover_single(host, credentials=Credentials("", "")) + x = await Discover.discover_single(host, credentials=Credentials()) assert issubclass(x.__class__, SmartDevice) assert x._discovery_info is not None assert x.host == host @@ -133,7 +136,7 @@ async def test_discover_single_hostname(discovery_mock, mocker): mocker.patch("socket.getaddrinfo", side_effect=socket.gaierror()) with pytest.raises(SmartDeviceException): - x = await Discover.discover_single(host, credentials=Credentials("", "")) + x = await Discover.discover_single(host, credentials=Credentials()) async def test_discover_single_unsupported(unsupported_device_info, mocker): @@ -305,12 +308,13 @@ async def test_discover_single_http_client(discovery_mock, mocker): http_client = httpx.AsyncClient() - x = await Discover.discover_single(host) - if discovery_mock.default_port == 20002: - assert x.protocol._transport._http_client != http_client + x: SmartDevice = await Discover.discover_single(host) + + assert x.config.uses_http == (discovery_mock.default_port == 20002) - x = await Discover.discover_single(host, httpx_asyncclient=http_client) if discovery_mock.default_port == 20002: + assert x.protocol._transport._http_client != http_client + x.config.http_client = http_client assert x.protocol._transport._http_client == http_client @@ -321,13 +325,11 @@ async def test_discover_http_client(discovery_mock, mocker): http_client = httpx.AsyncClient() - def gen(): - return http_client + devices = await Discover.discover(discovery_timeout=0) + x: SmartDevice = devices[host] + assert x.config.uses_http == (discovery_mock.default_port == 20002) - devs = await Discover.discover(discovery_timeout=0) if discovery_mock.default_port == 20002: - assert devs[host].protocol._transport._http_client != http_client - - devs = await Discover.discover(discovery_timeout=0, http_client_generator=gen) - if discovery_mock.default_port == 20002: - assert devs[host].protocol._transport._http_client == http_client + assert x.protocol._transport._http_client != http_client + x.config.http_client = http_client + assert x.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_klapprotocol.py b/kasa/tests/test_klapprotocol.py index 6263757d6..5108fef05 100644 --- a/kasa/tests/test_klapprotocol.py +++ b/kasa/tests/test_klapprotocol.py @@ -11,8 +11,8 @@ import pytest from ..aestransport import AesTransport -from ..connectionparams import ConnectionParameters from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import AuthenticationException, SmartDeviceException from ..iotprotocol import IotProtocol from ..klaptransport import ( @@ -37,8 +37,9 @@ def __init__(self, status_code, content: bytes): [ (Exception("dummy exception"), True), (SmartDeviceException("dummy exception"), False), + (httpx.ConnectError("dummy exception"), True), ], - ids=("Exception", "SmartDeviceException"), + ids=("Exception", "SmartDeviceException", "httpx.ConnectError"), ) @pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport]) @pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol]) @@ -49,9 +50,9 @@ async def test_protocol_retries( host = "127.0.0.1" conn = mocker.patch.object(httpx.AsyncClient, "post", side_effect=error) - cparams = ConnectionParameters(host) + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(transport=transport_class(cparams=cparams)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=retry_count ) @@ -70,9 +71,9 @@ async def test_protocol_no_retry_on_connection_error( "post", side_effect=AuthenticationException("foo"), ) - cparams = ConnectionParameters(host) + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(transport=transport_class(cparams=cparams)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=5 ) @@ -90,9 +91,9 @@ async def test_protocol_retry_recoverable_error( "post", side_effect=httpx.CloseError("foo"), ) - cparams = ConnectionParameters(host) + config = DeviceConfig(host) with pytest.raises(SmartDeviceException): - await protocol_class(transport=transport_class(cparams=cparams)).query( + await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=5 ) @@ -125,8 +126,8 @@ def _fail_one_less_than_retry_count(*_, **__): side_effect=_fail_one_less_than_retry_count, ) - cparams = ConnectionParameters(host) - response = await protocol_class(transport=transport_class(cparams=cparams)).query( + config = DeviceConfig(host) + response = await protocol_class(transport=transport_class(config=config)).query( DUMMY_QUERY, retry_count=retry_count ) assert "result" in response or "foobar" in response @@ -148,8 +149,8 @@ def _return_encrypted(*_, **__): auth_hash = KlapTransport.generate_auth_hash(Credentials("foo", "bar")) encryption_session = KlapEncryptionSession(seed, seed, auth_hash) - cparams = ConnectionParameters("127.0.0.1") - protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1") + protocol = IotProtocol(transport=KlapTransport(config=config)) protocol._transport._handshake_done = True protocol._transport._session_expire_at = time.time() + 86400 @@ -194,7 +195,7 @@ def test_encrypt_unicode(): "device_credentials, expectation", [ (Credentials("foo", "bar"), does_not_raise()), - (Credentials("", ""), does_not_raise()), + (Credentials(), does_not_raise()), ( Credentials( KlapTransport.KASA_SETUP_EMAIL, @@ -237,8 +238,8 @@ async def _return_handshake1_response(url, params=None, data=None, *_, **__): httpx.AsyncClient, "post", side_effect=_return_handshake1_response ) - cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) - protocol = IotProtocol(transport=transport_class(cparams=cparams)) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(config=config)) with expectation: ( @@ -296,8 +297,8 @@ async def _return_handshake_response(url, params=None, data=None, *_, **__): httpx.AsyncClient, "post", side_effect=_return_handshake_response ) - cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) - protocol = IotProtocol(transport=transport_class(cparams=cparams)) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=transport_class(config=config)) protocol._transport.http_client = httpx.AsyncClient() response_status = 200 @@ -343,8 +344,8 @@ async def _return_response(url, params=None, data=None, *_, **__): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) - protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(config=config)) for _ in range(10): resp = await protocol.query({}) @@ -388,8 +389,8 @@ async def _return_response(url, params=None, data=None, *_, **__): mocker.patch.object(httpx.AsyncClient, "post", side_effect=_return_response) - cparams = ConnectionParameters("127.0.0.1", credentials=client_credentials) - protocol = IotProtocol(transport=KlapTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1", credentials=client_credentials) + protocol = IotProtocol(transport=KlapTransport(config=config)) with expectation: await protocol.query({}) diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 550e2e758..80289ccbc 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -9,7 +9,7 @@ import pytest -from ..connectionparams import ConnectionParameters +from ..deviceconfig import DeviceConfig from ..exceptions import SmartDeviceException from ..protocol import ( BaseTransport, @@ -32,9 +32,9 @@ def aio_mock_writer(_, __): return reader, writer conn = mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) - cparams = ConnectionParameters("127.0.0.1") + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( {}, retry_count=retry_count ) @@ -46,9 +46,9 @@ async def test_protocol_no_retry_on_unreachable(mocker): "asyncio.open_connection", side_effect=OSError(errno.EHOSTUNREACH, "No route to host"), ) - cparams = ConnectionParameters("127.0.0.1") + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( {}, retry_count=5 ) @@ -60,9 +60,9 @@ async def test_protocol_no_retry_connection_refused(mocker): "asyncio.open_connection", side_effect=ConnectionRefusedError, ) - cparams = ConnectionParameters("127.0.0.1") + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( {}, retry_count=5 ) @@ -74,9 +74,9 @@ async def test_protocol_retry_recoverable_error(mocker): "asyncio.open_connection", side_effect=OSError(errno.ECONNRESET, "Connection reset by peer"), ) - cparams = ConnectionParameters("127.0.0.1") + config = DeviceConfig("127.0.0.1") with pytest.raises(SmartDeviceException): - await TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)).query( + await TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)).query( {}, retry_count=5 ) @@ -112,8 +112,8 @@ def aio_mock_writer(_, __): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - cparams = ConnectionParameters("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}, retry_count=retry_count) assert response == {"great": "success"} @@ -141,8 +141,8 @@ def aio_mock_writer(_, __): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - cparams = ConnectionParameters("127.0.0.1") - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1") + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -176,8 +176,8 @@ def aio_mock_writer(_, port): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - cparams = ConnectionParameters("127.0.0.1", port=custom_port) - protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(cparams=cparams)) + config = DeviceConfig("127.0.0.1", port=custom_port) + protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) assert response == {"great": "success"} @@ -296,5 +296,5 @@ def test_transport_init_signature(class_name_obj): and params[0].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) assert ( - params[1].name == "cparams" and params[1].kind == inspect.Parameter.KEYWORD_ONLY + params[1].name == "config" and params[1].kind == inspect.Parameter.KEYWORD_ONLY ) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 46b9220c2..0cc8925c6 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -5,7 +5,7 @@ import pytest # type: ignore # https://github.com/pytest-dev/pytest/issues/3342 import kasa -from kasa import ConnectionParameters, Credentials, SmartDevice, SmartDeviceException +from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException from kasa.smartdevice import DeviceType from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on @@ -238,18 +238,18 @@ async def test_create_smart_device_with_timeout(): async def test_create_thin_wrapper(): """Make sure thin wrapper is created with the correct device type.""" mock = Mock() - cparams = ConnectionParameters( + config = DeviceConfig( host="test_host", port=1234, timeout=100, credentials=Credentials("username", "password"), ) with patch("kasa.device_factory.connect", return_value=mock) as connect: - dev = await SmartDevice.connect(cparams=cparams) + dev = await SmartDevice.connect(config=config) assert dev is mock connect.assert_called_once_with( - cparams=cparams, + config=config, ) diff --git a/kasa/tests/test_smartprotocol.py b/kasa/tests/test_smartprotocol.py index 5dbbed279..301e367f5 100644 --- a/kasa/tests/test_smartprotocol.py +++ b/kasa/tests/test_smartprotocol.py @@ -13,6 +13,7 @@ from ..aestransport import AesTransport from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..exceptions import ( SMART_RETRYABLE_ERRORS, SMART_TIMEOUT_ERRORS, @@ -37,7 +38,8 @@ async def test_smart_device_errors(mocker, error_code): send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - protocol = SmartProtocol(host, transport=AesTransport(host)) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): await protocol.query(DUMMY_QUERY, retry_count=2) @@ -70,8 +72,8 @@ async def test_smart_device_errors_in_multiple_request(mocker, error_code): mocker.patch.object(AesTransport, "perform_login") send_mock = mocker.patch.object(AesTransport, "send", return_value=mock_response) - - protocol = SmartProtocol(host, transport=AesTransport(host)) + config = DeviceConfig(host, credentials=Credentials("foo", "bar")) + protocol = SmartProtocol(transport=AesTransport(config=config)) with pytest.raises(SmartDeviceException): await protocol.query(DUMMY_QUERY, retry_count=2) if error_code in chain(SMART_TIMEOUT_ERRORS, SMART_RETRYABLE_ERRORS): From 5896e2064a4e888140fcbff175ca74790ce11287 Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Thu, 28 Dec 2023 11:56:25 +0000 Subject: [PATCH 3/5] Further update post latest review --- kasa/aestransport.py | 8 +++++-- kasa/cli.py | 6 ++--- kasa/device_factory.py | 11 ++------- kasa/deviceconfig.py | 2 +- kasa/discover.py | 9 +++---- kasa/klaptransport.py | 9 +++++-- kasa/protocol.py | 15 +++++++++--- kasa/smartbulb.py | 10 ++++---- kasa/smartdevice.py | 39 ++++++++++++++++++++++++------- kasa/smartdimmer.py | 10 ++++---- kasa/smartlightstrip.py | 10 ++++---- kasa/smartplug.py | 10 ++++---- kasa/smartstrip.py | 10 ++++---- kasa/tapo/tapodevice.py | 21 ++++++----------- kasa/tapo/tapoplug.py | 10 ++++---- kasa/tests/conftest.py | 8 ++++--- kasa/tests/newfakes.py | 12 +++++++++- kasa/tests/test_device_factory.py | 2 +- kasa/tests/test_discovery.py | 12 +++++----- kasa/tests/test_protocol.py | 2 +- kasa/tests/test_smartdevice.py | 9 +++---- 21 files changed, 130 insertions(+), 95 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 117af63bc..b6fa34723 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -47,7 +47,7 @@ class AesTransport(BaseTransport): protocol, sometimes used by newer firmware versions on kasa devices. """ - DEFAULT_PORT = 80 + DEFAULT_PORT: int = 80 SESSION_COOKIE_NAME = "TP_SESSIONID" COMMON_HEADERS = { "Content-Type": "application/json", @@ -61,7 +61,6 @@ def __init__( config: DeviceConfig, ) -> None: super().__init__(config=config) - self._port = config.port or self.DEFAULT_PORT self._default_http_client: Optional[httpx.AsyncClient] = None @@ -76,6 +75,11 @@ def __init__( _LOGGER.debug("Created AES transport for %s", self._host) + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT + @property def _http_client(self) -> httpx.AsyncClient: if self._config.http_client: diff --git a/kasa/cli.py b/kasa/cli.py index 821e78693..13458b0e0 100755 --- a/kasa/cli.py +++ b/kasa/cli.py @@ -65,8 +65,8 @@ def wrapper(message=None, *args, **kwargs): ENCRYPT_TYPES = [encrypt_type.value for encrypt_type in EncryptType] -TPLINK_DEVICE_TYPES = [ - tplink_device_type.value for tplink_device_type in DeviceFamilyType +DEVICE_FAMILY_TYPES = [ + device_family_type.value for device_family_type in DeviceFamilyType ] click.anyio_backend = "asyncio" @@ -182,7 +182,7 @@ def _device_to_serializable(val: SmartDevice): "--device-family", envvar="KASA_DEVICE_FAMILY", default=None, - type=click.Choice(TPLINK_DEVICE_TYPES, case_sensitive=False), + type=click.Choice(DEVICE_FAMILY_TYPES, case_sensitive=False), ) @click.option( "--timeout", diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 254c4509c..cd6fe8bb0 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -54,22 +54,15 @@ def _perf_log(has_params, perf_type): info = await protocol.query(GET_SYSINFO_QUERY) _perf_log(True, "get_sysinfo") device_class = get_device_class_from_sys_info(info) - device = device_class(config.host, port=config.port, timeout=config.timeout) + device = device_class(config.host, protocol=protocol) device.update_from_discover_info(info) - device.protocol = protocol await device.update() _perf_log(True, "update") return device elif device_class := get_device_class_from_family( config.connection_type.device_family.value ): - device = device_class( - config.host, - port=config.port, - timeout=config.timeout, - credentials=config.credentials, - ) - device.protocol = protocol + device = device_class(host=config.host, protocol=protocol) await device.update() _perf_log(True, "update") return device diff --git a/kasa/deviceconfig.py b/kasa/deviceconfig.py index c6994286b..7a774b2ea 100644 --- a/kasa/deviceconfig.py +++ b/kasa/deviceconfig.py @@ -116,7 +116,7 @@ class DeviceConfig: host: str timeout: Optional[int] = DEFAULT_TIMEOUT - port: Optional[int] = None + port_override: Optional[int] = None credentials: Credentials = field( default_factory=lambda: Credentials(username="", password="") ) diff --git a/kasa/discover.py b/kasa/discover.py index 21a681752..f6189dd74 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -118,7 +118,7 @@ def datagram_received(self, data, addr) -> None: device = None - config = DeviceConfig(host=ip, port=self.port) + config = DeviceConfig(host=ip, port_override=self.port) if self.credentials: config.credentials = self.credentials if self.timeout: @@ -385,7 +385,7 @@ def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevic _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) device_class = Discover._get_device_class(info) - device = device_class(config.host, port=config.port) + device = device_class(config.host, config=config) sys_info = info["system"]["get_sysinfo"] if (device_type := sys_info.get("mic_type")) or ( device_type := sys_info.get("type") @@ -449,10 +449,7 @@ def _get_device_instance( ) _LOGGER.debug("[DISCOVERY] %s << %s", config.host, info) - device = device_class( - config.host, port=config.port, credentials=config.credentials - ) - device.protocol = protocol + device = device_class(config.host, protocol=protocol) di = discovery_result.get_dict() di["model"] = discovery_result.device_model diff --git a/kasa/klaptransport.py b/kasa/klaptransport.py index 3704f0efb..0e7ef565a 100644 --- a/kasa/klaptransport.py +++ b/kasa/klaptransport.py @@ -83,7 +83,7 @@ class KlapTransport(BaseTransport): protocol, used by newer firmware versions. """ - DEFAULT_PORT = 80 + DEFAULT_PORT: int = 80 DISCOVERY_QUERY = {"system": {"get_sysinfo": None}} KASA_SETUP_EMAIL = "kasa@tp-link.net" @@ -96,7 +96,7 @@ def __init__( config: DeviceConfig, ) -> None: super().__init__(config=config) - self._port = config.port or self.DEFAULT_PORT + self._default_http_client: Optional[httpx.AsyncClient] = None self._local_seed: Optional[bytes] = None self._local_auth_hash = self.generate_auth_hash(self._credentials) @@ -114,6 +114,11 @@ def __init__( _LOGGER.debug("Created KLAP transport for %s", self._host) + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT + @property def _http_client(self) -> httpx.AsyncClient: if self._config.http_client: diff --git a/kasa/protocol.py b/kasa/protocol.py index 4c9570637..c998807c5 100755 --- a/kasa/protocol.py +++ b/kasa/protocol.py @@ -54,10 +54,15 @@ def __init__( """Create a protocol object.""" self._config = config self._host = config.host - self._port = config.port # Set by derived classes + self._port = config.port_override or self.default_port self._credentials = config.credentials self._timeout = config.timeout + @property + @abstractmethod + def default_port(self) -> int: + """The default port for the transport.""" + @abstractmethod async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" @@ -105,11 +110,15 @@ class _XorTransport(BaseTransport): class. """ - DEFAULT_PORT = 9999 + DEFAULT_PORT: int = 9999 def __init__(self, *, config: DeviceConfig) -> None: super().__init__(config=config) - self._port = config.port or self.DEFAULT_PORT + + @property + def default_port(self): + """Default port for the transport.""" + return self.DEFAULT_PORT async def send(self, request: str) -> Dict: """Send a message to the device and return a response.""" diff --git a/kasa/smartbulb.py b/kasa/smartbulb.py index 6dd4513c6..8897ceceb 100644 --- a/kasa/smartbulb.py +++ b/kasa/smartbulb.py @@ -9,8 +9,9 @@ except ImportError: from pydantic import BaseModel, Field, root_validator -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .modules import Antitheft, Cloud, Countdown, Emeter, Schedule, Time, Usage +from .protocol import TPLinkProtocol from .smartdevice import DeviceType, SmartDevice, SmartDeviceException, requires_update @@ -220,11 +221,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Bulb self.add_module("schedule", Schedule(self, "smartlife.iot.common.schedule")) self.add_module("usage", Usage(self, "smartlife.iot.common.schedule")) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index d7a7be620..9c8ae27c5 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -192,21 +192,18 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: """Create a new SmartDevice instance. :param str host: host name or ip address on which the device listens """ - self.host = host - self.port = port - config = DeviceConfig(host=host, port=port, timeout=timeout) - self.protocol: TPLinkProtocol = TPLinkSmartHomeProtocol( - transport=_XorTransport(config=config), + if config and protocol: + protocol._transport._config = config + self.protocol: TPLinkProtocol = protocol or TPLinkSmartHomeProtocol( + transport=_XorTransport(config=config or DeviceConfig(host=host)), ) - self.credentials = credentials _LOGGER.debug("Initializing %s of type %s", self.host, type(self)) self._device_type = DeviceType.Unknown # TODO: typing Any is just as using Optional[Dict] would require separate @@ -221,6 +218,30 @@ def __init__( self.children: List["SmartDevice"] = [] + @property + def host(self) -> str: + """The device host.""" + return self.protocol._transport._host + + @host.setter + def host(self, value): + """Set the device host. + + Generally used by discovery to set the hostname after ip discovery. + """ + self.protocol._transport._host = value + self.protocol._transport._config.host = value + + @property + def port(self) -> int: + """The device port.""" + return self.protocol._transport._port + + @property + def credentials(self) -> Optional[Credentials]: + """The device credentials.""" + return self.protocol._transport._credentials + def add_module(self, name: str, module: Module): """Register a module.""" if name in self.modules: diff --git a/kasa/smartdimmer.py b/kasa/smartdimmer.py index 7980319c7..ca0960f11 100644 --- a/kasa/smartdimmer.py +++ b/kasa/smartdimmer.py @@ -2,8 +2,9 @@ from enum import Enum from typing import Any, Dict, Optional -from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig from kasa.modules import AmbientLight, Motion +from kasa.protocol import TPLinkProtocol from kasa.smartdevice import DeviceType, SmartDeviceException, requires_update from kasa.smartplug import SmartPlug @@ -68,11 +69,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Dimmer # TODO: need to be verified if it's okay to call these on HS220 w/o these # TODO: need to be figured out what's the best approach to detect support diff --git a/kasa/smartlightstrip.py b/kasa/smartlightstrip.py index 2990e1fa4..27ebf8381 100644 --- a/kasa/smartlightstrip.py +++ b/kasa/smartlightstrip.py @@ -1,8 +1,9 @@ """Module for light strips (KL430).""" from typing import Any, Dict, List, Optional -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .effects import EFFECT_MAPPING_V1, EFFECT_NAMES_V1 +from .protocol import TPLinkProtocol from .smartbulb import SmartBulb from .smartdevice import DeviceType, SmartDeviceException, requires_update @@ -46,11 +47,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.LightStrip @property # type: ignore diff --git a/kasa/smartplug.py b/kasa/smartplug.py index 4ba230b49..d9ac0c863 100644 --- a/kasa/smartplug.py +++ b/kasa/smartplug.py @@ -2,8 +2,9 @@ import logging from typing import Any, Dict, Optional -from kasa.credentials import Credentials +from kasa.deviceconfig import DeviceConfig from kasa.modules import Antitheft, Cloud, Schedule, Time, Usage +from kasa.protocol import TPLinkProtocol from kasa.smartdevice import DeviceType, SmartDevice, requires_update _LOGGER = logging.getLogger(__name__) @@ -43,11 +44,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug self.add_module("schedule", Schedule(self, "schedule")) self.add_module("usage", Usage(self, "schedule")) diff --git a/kasa/smartstrip.py b/kasa/smartstrip.py index 80aa27d1b..793931325 100755 --- a/kasa/smartstrip.py +++ b/kasa/smartstrip.py @@ -14,8 +14,9 @@ ) from kasa.smartplug import SmartPlug -from .credentials import Credentials +from .deviceconfig import DeviceConfig from .modules import Antitheft, Countdown, Emeter, Schedule, Time, Usage +from .protocol import TPLinkProtocol _LOGGER = logging.getLogger(__name__) @@ -85,11 +86,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host=host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self.emeter_type = "emeter" self._device_type = DeviceType.Strip self.add_module("antitheft", Antitheft(self, "anti_theft")) diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index e42448595..0678bef16 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -5,9 +5,9 @@ from typing import Any, Dict, Optional, Set, cast from ..aestransport import AesTransport -from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import AuthenticationException +from ..protocol import TPLinkProtocol from ..smartdevice import SmartDevice from ..smartprotocol import SmartProtocol @@ -21,23 +21,16 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + _protocol = protocol or SmartProtocol( + transport=AesTransport(config=config or DeviceConfig(host=host)), + ) + super().__init__(host=host, config=config, protocol=_protocol) self._components: Optional[Dict[str, Any]] = None self._state_information: Dict[str, Any] = {} self._discovery_info: Optional[Dict[str, Any]] = None - config = DeviceConfig( - host=host, - port=port, - credentials=credentials, # type: ignore[arg-type] - timeout=timeout, - ) - self.protocol = SmartProtocol( - transport=AesTransport(config=config), - ) async def update(self, update_children: bool = True): """Update the device.""" diff --git a/kasa/tapo/tapoplug.py b/kasa/tapo/tapoplug.py index 9d868253e..67aed565a 100644 --- a/kasa/tapo/tapoplug.py +++ b/kasa/tapo/tapoplug.py @@ -3,9 +3,10 @@ from datetime import datetime, timedelta from typing import Any, Dict, Optional, cast -from ..credentials import Credentials +from ..deviceconfig import DeviceConfig from ..emeterstatus import EmeterStatus from ..modules import Emeter +from ..protocol import TPLinkProtocol from ..smartdevice import DeviceType, requires_update from .tapodevice import TapoDevice @@ -19,11 +20,10 @@ def __init__( self, host: str, *, - port: Optional[int] = None, - credentials: Optional[Credentials] = None, - timeout: Optional[int] = None, + config: Optional[DeviceConfig] = None, + protocol: Optional[TPLinkProtocol] = None, ) -> None: - super().__init__(host, port=port, credentials=credentials, timeout=timeout) + super().__init__(host=host, config=config, protocol=protocol) self._device_type = DeviceType.Plug self.modules: Dict[str, Any] = {} self.emeter_type = "emeter" diff --git a/kasa/tests/conftest.py b/kasa/tests/conftest.py index 2320f1161..11efe6937 100644 --- a/kasa/tests/conftest.py +++ b/kasa/tests/conftest.py @@ -388,7 +388,6 @@ def load_file(): d = device_for_file(model, protocol)(host="127.0.0.123") if protocol == "SMART": d.protocol = FakeSmartProtocol(sysinfo) - d.credentials = Credentials() else: d.protocol = FakeTransportProtocol(sysinfo) await _update_and_close(d) @@ -426,6 +425,7 @@ def discovery_mock(all_fixture_data, mocker): class _DiscoveryMock: ip: str default_port: int + discovery_port: int discovery_data: dict query_data: dict device_type: str @@ -444,6 +444,7 @@ class _DiscoveryMock: ) dm = _DiscoveryMock( "127.0.0.123", + 80, 20002, discovery_data, all_fixture_data, @@ -459,6 +460,7 @@ class _DiscoveryMock: dm = _DiscoveryMock( "127.0.0.123", 9999, + 9999, discovery_data, all_fixture_data, device_type, @@ -468,8 +470,8 @@ class _DiscoveryMock: def mock_discover(self): port = ( dm.port_override - if dm.port_override and dm.default_port != 20002 - else dm.default_port + if dm.port_override and dm.discovery_port != 20002 + else dm.discovery_port ) self.datagram_received( datagram, diff --git a/kasa/tests/newfakes.py b/kasa/tests/newfakes.py index 53748700f..13d11d3d9 100644 --- a/kasa/tests/newfakes.py +++ b/kasa/tests/newfakes.py @@ -17,7 +17,7 @@ from ..credentials import Credentials from ..deviceconfig import DeviceConfig -from ..protocol import BaseTransport, TPLinkSmartHomeProtocol +from ..protocol import BaseTransport, TPLinkSmartHomeProtocol, _XorTransport from ..smartprotocol import SmartProtocol _LOGGER = logging.getLogger(__name__) @@ -309,6 +309,11 @@ def __init__(self, info): ) self.info = info + @property + def default_port(self): + """Default port for the transport.""" + return 80 + async def send(self, request: str): request_dict = json_loads(request) method = request_dict["method"] @@ -348,6 +353,11 @@ async def close(self) -> None: class FakeTransportProtocol(TPLinkSmartHomeProtocol): def __init__(self, info): + super().__init__( + transport=_XorTransport( + config=DeviceConfig("127.0.0.123"), + ) + ) self.discovery_data = info self.writer = None self.reader = None diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index f8e2aa01d..835dcd3c8 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -77,7 +77,7 @@ async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port): host = "127.0.0.1" ctype, _ = _get_connection_type_device_class(all_fixture_data) - config = DeviceConfig(host=host, port=custom_port, connection_type=ctype) + config = DeviceConfig(host=host, port_override=custom_port, connection_type=ctype) default_port = 80 if "discovery_result" in all_fixture_data else 9999 ctype, _ = _get_connection_type_device_class(all_fixture_data) diff --git a/kasa/tests/test_discovery.py b/kasa/tests/test_discovery.py index d9b4631cf..396ef2f2e 100644 --- a/kasa/tests/test_discovery.py +++ b/kasa/tests/test_discovery.py @@ -112,9 +112,9 @@ async def test_discover_single(discovery_mock, custom_port, mocker): ct = ConnectionType.from_values( discovery_mock.device_type, discovery_mock.encrypt_type ) - uses_http = discovery_mock.default_port == 20002 + uses_http = discovery_mock.default_port == 80 config = DeviceConfig( - host=host, port=custom_port, connection_type=ct, uses_http=uses_http + host=host, port_override=custom_port, connection_type=ct, uses_http=uses_http ) assert x.config == config @@ -310,9 +310,9 @@ async def test_discover_single_http_client(discovery_mock, mocker): x: SmartDevice = await Discover.discover_single(host) - assert x.config.uses_http == (discovery_mock.default_port == 20002) + assert x.config.uses_http == (discovery_mock.default_port == 80) - if discovery_mock.default_port == 20002: + if discovery_mock.default_port == 80: assert x.protocol._transport._http_client != http_client x.config.http_client = http_client assert x.protocol._transport._http_client == http_client @@ -327,9 +327,9 @@ async def test_discover_http_client(discovery_mock, mocker): devices = await Discover.discover(discovery_timeout=0) x: SmartDevice = devices[host] - assert x.config.uses_http == (discovery_mock.default_port == 20002) + assert x.config.uses_http == (discovery_mock.default_port == 80) - if discovery_mock.default_port == 20002: + if discovery_mock.default_port == 80: assert x.protocol._transport._http_client != http_client x.config.http_client = http_client assert x.protocol._transport._http_client == http_client diff --git a/kasa/tests/test_protocol.py b/kasa/tests/test_protocol.py index 80289ccbc..0e74da3b8 100644 --- a/kasa/tests/test_protocol.py +++ b/kasa/tests/test_protocol.py @@ -176,7 +176,7 @@ def aio_mock_writer(_, port): mocker.patch.object(reader, "readexactly", _mock_read) return reader, writer - config = DeviceConfig("127.0.0.1", port=custom_port) + config = DeviceConfig("127.0.0.1", port_override=custom_port) protocol = TPLinkSmartHomeProtocol(transport=_XorTransport(config=config)) mocker.patch("asyncio.open_connection", side_effect=aio_mock_writer) response = await protocol.query({}) diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 0cc8925c6..501ef6241 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -6,7 +6,6 @@ import kasa from kasa import Credentials, DeviceConfig, SmartDevice, SmartDeviceException -from kasa.smartdevice import DeviceType from .conftest import device_iot, handle_turn_on, has_emeter, no_emeter_iot, turn_on from .newfakes import PLUG_SCHEMA, TZ_SCHEMA, FakeTransportProtocol @@ -215,7 +214,8 @@ def test_device_class_ctors(device_class): host = "127.0.0.2" port = 1234 credentials = Credentials("foo", "bar") - dev = device_class(host, port=port, credentials=credentials) + config = DeviceConfig(host, port_override=port, credentials=credentials) + dev = device_class(host, config=config) assert dev.host == host assert dev.port == port assert dev.credentials == credentials @@ -231,7 +231,8 @@ async def test_modules_preserved(dev: SmartDevice): async def test_create_smart_device_with_timeout(): """Make sure timeout is passed to the protocol.""" - dev = SmartDevice(host="127.0.0.1", timeout=100) + host = "127.0.0.1" + dev = SmartDevice(host, config=DeviceConfig(host, timeout=100)) assert dev.protocol._transport._timeout == 100 @@ -240,7 +241,7 @@ async def test_create_thin_wrapper(): mock = Mock() config = DeviceConfig( host="test_host", - port=1234, + port_override=1234, timeout=100, credentials=Credentials("username", "password"), ) From d742513d22e6da8b70fb1fc6bbcb8ca71c61247c Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 29 Dec 2023 09:14:13 +0000 Subject: [PATCH 4/5] Update following latest review --- kasa/device_factory.py | 63 ++++++++++++++++++++++++------- kasa/discover.py | 6 +-- kasa/protocolfactory.py | 39 ------------------- kasa/smartdevice.py | 10 +---- kasa/tapo/tapodevice.py | 2 +- kasa/tests/test_device_factory.py | 3 +- kasa/tests/test_smartdevice.py | 1 + 7 files changed, 57 insertions(+), 67 deletions(-) delete mode 100644 kasa/protocolfactory.py diff --git a/kasa/device_factory.py b/kasa/device_factory.py index cd6fe8bb0..0a34d9ca5 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -1,20 +1,27 @@ """Device creation via DeviceConfig.""" import logging import time -from typing import Any, Dict, Optional, Type - -from kasa.deviceconfig import DeviceConfig -from kasa.protocol import TPLinkSmartHomeProtocol -from kasa.smartbulb import SmartBulb -from kasa.smartdevice import SmartDevice -from kasa.smartdimmer import SmartDimmer -from kasa.smartlightstrip import SmartLightStrip -from kasa.smartplug import SmartPlug -from kasa.smartstrip import SmartStrip -from kasa.tapo import TapoBulb, TapoPlug +from typing import Any, Dict, Optional, Tuple, Type +from .aestransport import AesTransport +from .deviceconfig import DeviceConfig from .exceptions import SmartDeviceException, UnsupportedDeviceException -from .protocolfactory import get_protocol +from .iotprotocol import IotProtocol +from .klaptransport import KlapTransport, KlapTransportV2 +from .protocol import ( + BaseTransport, + TPLinkProtocol, + TPLinkSmartHomeProtocol, + _XorTransport, +) +from .smartbulb import SmartBulb +from .smartdevice import SmartDevice +from .smartdimmer import SmartDimmer +from .smartlightstrip import SmartLightStrip +from .smartplug import SmartPlug +from .smartprotocol import SmartProtocol +from .smartstrip import SmartStrip +from .tapo import TapoBulb, TapoPlug _LOGGER = logging.getLogger(__name__) @@ -23,11 +30,16 @@ } -async def connect(*, config: DeviceConfig) -> "SmartDevice": +async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice": """Connect to a single device by the given connection parameters. Do not use this function directly, use SmartDevice.Connect() """ + if host and config or (not host and not config): + raise SmartDeviceException("One of host or config must be provded and not both") + if host: + config = DeviceConfig(host=host) + debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG) if debug_enabled: start_time = time.perf_counter() @@ -110,3 +122,28 @@ def get_device_class_from_family(device_type: str) -> Optional[Type[SmartDevice] "IOT.SMARTBULB": SmartBulb, } return supported_device_types.get(device_type) + + +def get_protocol( + config: DeviceConfig, +) -> Optional[TPLinkProtocol]: + """Return the protocol from the connection name.""" + protocol_name = config.connection_type.device_family.value.split(".")[0] + protocol_transport_key = ( + protocol_name + "." + config.connection_type.encryption_type.value + ) + supported_device_protocols: dict[ + str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] + ] = { + "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), + "IOT.KLAP": (IotProtocol, KlapTransport), + "SMART.AES": (SmartProtocol, AesTransport), + "SMART.KLAP": (SmartProtocol, KlapTransportV2), + } + if protocol_transport_key not in supported_device_protocols: + return None + + protocol_class, transport_class = supported_device_protocols.get( + protocol_transport_key + ) # type: ignore + return protocol_class(transport=transport_class(config=config)) diff --git a/kasa/discover.py b/kasa/discover.py index f6189dd74..e39122f3b 100755 --- a/kasa/discover.py +++ b/kasa/discover.py @@ -20,13 +20,13 @@ from kasa.device_factory import ( get_device_class_from_family, get_device_class_from_sys_info, + get_protocol, ) from kasa.deviceconfig import ConnectionType, DeviceConfig, EncryptType from kasa.exceptions import UnsupportedDeviceException from kasa.json import dumps as json_dumps from kasa.json import loads as json_loads from kasa.protocol import TPLinkSmartHomeProtocol -from kasa.protocolfactory import get_protocol from kasa.smartdevice import SmartDevice, SmartDeviceException _LOGGER = logging.getLogger(__name__) @@ -387,9 +387,7 @@ def _get_device_instance_legacy(data: bytes, config: DeviceConfig) -> SmartDevic device_class = Discover._get_device_class(info) device = device_class(config.host, config=config) sys_info = info["system"]["get_sysinfo"] - if (device_type := sys_info.get("mic_type")) or ( - device_type := sys_info.get("type") - ): + if device_type := sys_info.get("mic_type", sys_info.get("type")): config.connection_type = ConnectionType.from_values( device_family=device_type, encryption_type=EncryptType.Xor.value ) diff --git a/kasa/protocolfactory.py b/kasa/protocolfactory.py deleted file mode 100644 index d30a986eb..000000000 --- a/kasa/protocolfactory.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Module for protocol factory class.""" -from typing import Optional, Tuple, Type - -from .aestransport import AesTransport -from .deviceconfig import DeviceConfig -from .iotprotocol import IotProtocol -from .klaptransport import KlapTransport, KlapTransportV2 -from .protocol import ( - BaseTransport, - TPLinkProtocol, - TPLinkSmartHomeProtocol, - _XorTransport, -) -from .smartprotocol import SmartProtocol - - -def get_protocol( - config: DeviceConfig, -) -> Optional[TPLinkProtocol]: - """Return the protocol from the connection name.""" - protocol_name = config.connection_type.device_family.value.split(".")[0] - protocol_transport_key = ( - protocol_name + "." + config.connection_type.encryption_type.value - ) - supported_device_protocols: dict[ - str, Tuple[Type[TPLinkProtocol], Type[BaseTransport]] - ] = { - "IOT.XOR": (TPLinkSmartHomeProtocol, _XorTransport), - "IOT.KLAP": (IotProtocol, KlapTransport), - "SMART.AES": (SmartProtocol, AesTransport), - "SMART.KLAP": (SmartProtocol, KlapTransportV2), - } - if protocol_transport_key not in supported_device_protocols: - return None - - protocol_class, transport_class = supported_device_protocols.get( - protocol_transport_key - ) # type: ignore - return protocol_class(transport=transport_class(config=config)) diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 9c8ae27c5..016c99417 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -417,7 +417,7 @@ class itself as @requires_update will be affected for other properties. """ return self._sys_info # type: ignore - @property + @property # type: ignore @requires_update def model(self) -> str: """Return device model.""" @@ -825,10 +825,4 @@ async def connect( """ from .device_factory import connect # pylint: disable=import-outside-toplevel - if host and config or (not host and not config): - raise SmartDeviceException( - "One of host or config must be provded and not both" - ) - if host: - config = DeviceConfig(host=host) - return await connect(config=config) # type: ignore[arg-type] + return await connect(host=host, config=config) # type: ignore[arg-type] diff --git a/kasa/tapo/tapodevice.py b/kasa/tapo/tapodevice.py index 0678bef16..717de7ef4 100644 --- a/kasa/tapo/tapodevice.py +++ b/kasa/tapo/tapodevice.py @@ -132,7 +132,7 @@ def mac(self) -> str: @property def device_id(self) -> str: """Return the device id.""" - return str(self._info.get("device_id")) # type: ignore + return str(self._info.get("device_id")) @property def internal_state(self) -> Any: diff --git a/kasa/tests/test_device_factory.py b/kasa/tests/test_device_factory.py index 835dcd3c8..666bd9e95 100644 --- a/kasa/tests/test_device_factory.py +++ b/kasa/tests/test_device_factory.py @@ -16,7 +16,7 @@ SmartLightStrip, SmartPlug, ) -from kasa.device_factory import connect +from kasa.device_factory import connect, get_protocol from kasa.deviceconfig import ( ConnectionType, DeviceConfig, @@ -24,7 +24,6 @@ EncryptType, ) from kasa.discover import DiscoveryResult -from kasa.protocolfactory import get_protocol def _get_connection_type_device_class(the_fixture_data): diff --git a/kasa/tests/test_smartdevice.py b/kasa/tests/test_smartdevice.py index 501ef6241..a3019bff6 100644 --- a/kasa/tests/test_smartdevice.py +++ b/kasa/tests/test_smartdevice.py @@ -250,6 +250,7 @@ async def test_create_thin_wrapper(): assert dev is mock connect.assert_called_once_with( + host=None, config=config, ) From 9325ba5ecd347532ca041398fb9e48578f92039b Mon Sep 17 00:00:00 2001 From: sdb9696 Date: Fri, 29 Dec 2023 17:28:45 +0000 Subject: [PATCH 5/5] Update docstrings and docs --- docs/source/design.rst | 9 ++++++--- docs/source/deviceconfig.rst | 18 ++++++++++++++++++ docs/source/index.rst | 1 + kasa/device_factory.py | 18 ++++++++++++++++-- kasa/smartdevice.py | 4 +--- 5 files changed, 42 insertions(+), 8 deletions(-) create mode 100644 docs/source/deviceconfig.rst diff --git a/docs/source/design.rst b/docs/source/design.rst index 5679943d2..6538c8b80 100644 --- a/docs/source/design.rst +++ b/docs/source/design.rst @@ -23,9 +23,12 @@ This will return you a list of device instances based on the discovery replies. If the device's host is already known, you can use to construct a device instance with :meth:`~kasa.SmartDevice.connect()`. -When connecting a device with the :meth:`~kasa.SmartDevice.connect()` method, it is recommended to -pass the device type as well as this allows the library to use the correct device class for the -device without having to query the device. +The :meth:`~kasa.SmartDevice.connect()` also enables support for connecting to new +KASA SMART protocol and TAPO devices directly using the parameter :class:`~kasa.DeviceConfig`. +Simply serialize the :attr:`~kasa.SmartDevice.config` property via :meth:`~kasa.DeviceConfig.to_dict()` +and then deserialize it later with :func:`~kasa.DeviceConfig.from_dict()` +and then pass it into :meth:`~kasa.SmartDevice.connect()`. + .. _update_cycle: diff --git a/docs/source/deviceconfig.rst b/docs/source/deviceconfig.rst new file mode 100644 index 000000000..25bf077ba --- /dev/null +++ b/docs/source/deviceconfig.rst @@ -0,0 +1,18 @@ +DeviceConfig +============ + +.. contents:: Contents + :local: + +.. note:: + + Feel free to open a pull request to improve the documentation! + + +API documentation +***************** + +.. autoclass:: kasa.DeviceConfig + :members: + :inherited-members: + :undoc-members: diff --git a/docs/source/index.rst b/docs/source/index.rst index 346c53d08..16e7cbd07 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -15,3 +15,4 @@ smartdimmer smartstrip smartlightstrip + deviceconfig diff --git a/kasa/device_factory.py b/kasa/device_factory.py index 0a34d9ca5..505b64870 100755 --- a/kasa/device_factory.py +++ b/kasa/device_factory.py @@ -31,9 +31,23 @@ async def connect(*, host: Optional[str] = None, config: DeviceConfig) -> "SmartDevice": - """Connect to a single device by the given connection parameters. + """Connect to a single device by the given hostname or device configuration. - Do not use this function directly, use SmartDevice.Connect() + This method avoids the UDP based discovery process and + will connect directly to the device. + + It is generally preferred to avoid :func:`discover_single()` and + use this function instead as it should perform better when + the WiFi network is congested or the device is not responding + to discovery requests. + + Do not use this function directly, use SmartDevice.connect() + + :param host: Hostname of device to query + :param config: Connection parameters to ensure the correct protocol + and connection options are used. + :rtype: SmartDevice + :return: Object for querying/controlling found device. """ if host and config or (not host and not config): raise SmartDeviceException("One of host or config must be provded and not both") diff --git a/kasa/smartdevice.py b/kasa/smartdevice.py index 016c99417..97b46ddca 100755 --- a/kasa/smartdevice.py +++ b/kasa/smartdevice.py @@ -805,7 +805,7 @@ async def connect( host: Optional[str] = None, config: Optional[DeviceConfig] = None, ) -> "SmartDevice": - """Connect to a single device by the given hostname or connection parameters. + """Connect to a single device by the given hostname or device configuration. This method avoids the UDP based discovery process and will connect directly to the device. @@ -815,8 +815,6 @@ async def connect( the WiFi network is congested or the device is not responding to discovery requests. - The device type is discovered by querying the device. - :param host: Hostname of device to query :param config: Connection parameters to ensure the correct protocol and connection options are used.