From 446b4bf7f0c3806fdb74b3f21b6483cfabc1f4a2 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:28:49 +0000 Subject: [PATCH 1/3] Migrate smart firmware module to mashumaru --- kasa/smart/modules/firmware.py | 38 +++++++++++++++------------- tests/smart/modules/test_firmware.py | 12 ++++----- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 5956a3575..90a5b9860 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -6,10 +6,11 @@ import logging from asyncio import timeout as asyncio_timeout from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field from datetime import date from typing import TYPE_CHECKING -from pydantic.v1 import BaseModel, Field, validator +from mashumaro import DataClassDictMixin, field_options from ...exceptions import KasaException from ...feature import Feature @@ -22,36 +23,39 @@ _LOGGER = logging.getLogger(__name__) -class DownloadState(BaseModel): +@dataclass +class DownloadState(DataClassDictMixin): """Download state.""" # Example: # {'status': 0, 'download_progress': 0, 'reboot_time': 5, # 'upgrade_time': 5, 'auto_upgrade': False} status: int - progress: int = Field(alias="download_progress") + progress: int = field(metadata=field_options(alias="download_progress")) reboot_time: int upgrade_time: int auto_upgrade: bool -class UpdateInfo(BaseModel): +@dataclass +class UpdateInfo(DataClassDictMixin): """Update info status object.""" - status: int = Field(alias="type") - version: str | None = Field(alias="fw_ver", default=None) + status: int = field(metadata=field_options(alias="type")) + needs_upgrade: bool = field(metadata=field_options(alias="need_to_upgrade")) + version: str | None = field(metadata=field_options(alias="fw_ver"), default=None) release_date: date | None = None - release_notes: str | None = Field(alias="release_note", default=None) + release_notes: str | None = field( + metadata=field_options(alias="release_note"), default=None + ) fw_size: int | None = None oem_id: str | None = None - needs_upgrade: bool = Field(alias="need_to_upgrade") - @validator("release_date", pre=True) - def _release_date_optional(cls, v: str) -> str | None: - if not v: - return None - - return v + @classmethod + def __pre_deserialize__(cls, d: dict) -> dict: + if d.get("release_date") == "": + return {**d, "release_date": None} + return d @property def update_available(self) -> bool: @@ -139,7 +143,7 @@ async def check_latest_firmware(self) -> UpdateInfo | None: """Check for the latest firmware for the device.""" try: fw = await self.call("get_latest_fw") - self._firmware_update_info = UpdateInfo.parse_obj(fw["get_latest_fw"]) + self._firmware_update_info = UpdateInfo.from_dict(fw["get_latest_fw"]) return self._firmware_update_info except Exception: _LOGGER.exception("Error getting latest firmware for %s:", self._device) @@ -174,7 +178,7 @@ async def get_update_state(self) -> DownloadState: """Return update state.""" resp = await self.call("get_fw_download_state") state = resp["get_fw_download_state"] - return DownloadState(**state) + return DownloadState.from_dict(state) @allow_update_after async def update( @@ -232,7 +236,7 @@ async def update( else: _LOGGER.warning("Unhandled state code: %s", state) - return state.dict() + return state.to_dict() @property def auto_update_enabled(self) -> bool: diff --git a/tests/smart/modules/test_firmware.py b/tests/smart/modules/test_firmware.py index 3115c56f1..f4d7108cf 100644 --- a/tests/smart/modules/test_firmware.py +++ b/tests/smart/modules/test_firmware.py @@ -105,15 +105,15 @@ class Extras(TypedDict): } update_states = [ # Unknown 1 - DownloadState(status=1, download_progress=0, **extras), + DownloadState(status=1, progress=0, **extras), # Downloading - DownloadState(status=2, download_progress=10, **extras), - DownloadState(status=2, download_progress=100, **extras), + DownloadState(status=2, progress=10, **extras), + DownloadState(status=2, progress=100, **extras), # Flashing - DownloadState(status=3, download_progress=100, **extras), - DownloadState(status=3, download_progress=100, **extras), + DownloadState(status=3, progress=100, **extras), + DownloadState(status=3, progress=100, **extras), # Done - DownloadState(status=0, download_progress=100, **extras), + DownloadState(status=0, progress=100, **extras), ] asyncio_sleep = asyncio.sleep From 9263f7ec874f81c07d2db44a6ac6aee15a35c9ce Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:15:30 +0000 Subject: [PATCH 2/3] Use Annotated Alias --- kasa/smart/modules/firmware.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index 90a5b9860..ea9dc2430 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -6,11 +6,12 @@ import logging from asyncio import timeout as asyncio_timeout from collections.abc import Callable, Coroutine -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import date -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated -from mashumaro import DataClassDictMixin, field_options +from mashumaro import DataClassDictMixin +from mashumaro.types import Alias from ...exceptions import KasaException from ...feature import Feature @@ -31,7 +32,7 @@ class DownloadState(DataClassDictMixin): # {'status': 0, 'download_progress': 0, 'reboot_time': 5, # 'upgrade_time': 5, 'auto_upgrade': False} status: int - progress: int = field(metadata=field_options(alias="download_progress")) + progress: Annotated[int, Alias("download_progress")] reboot_time: int upgrade_time: int auto_upgrade: bool @@ -41,13 +42,11 @@ class DownloadState(DataClassDictMixin): class UpdateInfo(DataClassDictMixin): """Update info status object.""" - status: int = field(metadata=field_options(alias="type")) - needs_upgrade: bool = field(metadata=field_options(alias="need_to_upgrade")) - version: str | None = field(metadata=field_options(alias="fw_ver"), default=None) + status: Annotated[int, Alias("type")] + needs_upgrade: Annotated[bool, Alias("need_to_upgrade")] + version: Annotated[str | None, Alias("fw_ver")] = None release_date: date | None = None - release_notes: str | None = field( - metadata=field_options(alias="release_note"), default=None - ) + release_notes: Annotated[str | None, Alias("release_note")] = None fw_size: int | None = None oem_id: str | None = None From d33ba679b5e33552038f6d61482a8f0f1fce5b17 Mon Sep 17 00:00:00 2001 From: Steven B <51370195+sdb9696@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:46:42 +0000 Subject: [PATCH 3/3] Use deserialize field_option for release_date --- kasa/smart/modules/firmware.py | 17 ++++++++--------- tests/smart/modules/test_firmware.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/kasa/smart/modules/firmware.py b/kasa/smart/modules/firmware.py index ea9dc2430..8dd3a6b32 100644 --- a/kasa/smart/modules/firmware.py +++ b/kasa/smart/modules/firmware.py @@ -6,11 +6,11 @@ import logging from asyncio import timeout as asyncio_timeout from collections.abc import Callable, Coroutine -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import date from typing import TYPE_CHECKING, Annotated -from mashumaro import DataClassDictMixin +from mashumaro import DataClassDictMixin, field_options from mashumaro.types import Alias from ...exceptions import KasaException @@ -45,17 +45,16 @@ class UpdateInfo(DataClassDictMixin): status: Annotated[int, Alias("type")] needs_upgrade: Annotated[bool, Alias("need_to_upgrade")] version: Annotated[str | None, Alias("fw_ver")] = None - release_date: date | None = None + release_date: date | None = field( + default=None, + metadata=field_options( + deserialize=lambda x: date.fromisoformat(x) if x else None + ), + ) release_notes: Annotated[str | None, Alias("release_note")] = None fw_size: int | None = None oem_id: str | None = None - @classmethod - def __pre_deserialize__(cls, d: dict) -> dict: - if d.get("release_date") == "": - return {**d, "release_date": None} - return d - @property def update_available(self) -> bool: """Return True if update available.""" diff --git a/tests/smart/modules/test_firmware.py b/tests/smart/modules/test_firmware.py index f4d7108cf..0bc0a4eab 100644 --- a/tests/smart/modules/test_firmware.py +++ b/tests/smart/modules/test_firmware.py @@ -3,6 +3,7 @@ import asyncio import logging from contextlib import nullcontext +from datetime import date from typing import TypedDict import pytest @@ -52,6 +53,20 @@ async def test_firmware_features( assert isinstance(feat.value, type) +@firmware +async def test_firmware_update_info(dev: SmartDevice): + """Test that the firmware UpdateInfo object deserializes correctly.""" + fw = dev.modules.get(Module.Firmware) + assert fw + + if not dev.is_cloud_connected: + pytest.skip("Device is not cloud connected, skipping test") + assert fw.firmware_update_info is None + await fw.check_latest_firmware() + assert fw.firmware_update_info is not None + assert isinstance(fw.firmware_update_info.release_date, date | None) + + @firmware async def test_update_available_without_cloud(dev: SmartDevice): """Test that update_available returns None when disconnected."""