Skip to content

Migrate smart firmware module to mashumaro #1276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions kasa/smart/modules/firmware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
import logging
from asyncio import timeout as asyncio_timeout
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import date
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated

from pydantic.v1 import BaseModel, Field, validator
from mashumaro import DataClassDictMixin, field_options
from mashumaro.types import Alias

from ...exceptions import KasaException
from ...feature import Feature
Expand All @@ -22,36 +24,36 @@
_LOGGER = logging.getLogger(__name__)


class DownloadState(BaseModel):
@dataclass
class DownloadState(DataClassDictMixin):
"""Download state."""

# Example:
# {'status': 0, 'download_progress': 0, 'reboot_time': 5,
# 'upgrade_time': 5, 'auto_upgrade': False}
status: int
progress: int = Field(alias="download_progress")
progress: Annotated[int, Alias("download_progress")]
reboot_time: int
upgrade_time: int
auto_upgrade: bool


class UpdateInfo(BaseModel):
@dataclass
class UpdateInfo(DataClassDictMixin):
"""Update info status object."""

status: int = Field(alias="type")
version: str | None = Field(alias="fw_ver", default=None)
release_date: date | None = None
release_notes: str | None = Field(alias="release_note", default=None)
status: Annotated[int, Alias("type")]
needs_upgrade: Annotated[bool, Alias("need_to_upgrade")]
version: Annotated[str | None, Alias("fw_ver")] = None
release_date: date | None = field(
default=None,
metadata=field_options(
deserialize=lambda x: date.fromisoformat(x) if x else None
),
)
release_notes: Annotated[str | None, Alias("release_note")] = None
fw_size: int | None = None
oem_id: str | None = None
needs_upgrade: bool = Field(alias="need_to_upgrade")

@validator("release_date", pre=True)
def _release_date_optional(cls, v: str) -> str | None:
if not v:
return None

return v

@property
def update_available(self) -> bool:
Expand Down Expand Up @@ -139,7 +141,7 @@
"""Check for the latest firmware for the device."""
try:
fw = await self.call("get_latest_fw")
self._firmware_update_info = UpdateInfo.parse_obj(fw["get_latest_fw"])
self._firmware_update_info = UpdateInfo.from_dict(fw["get_latest_fw"])
return self._firmware_update_info
except Exception:
_LOGGER.exception("Error getting latest firmware for %s:", self._device)
Expand Down Expand Up @@ -174,7 +176,7 @@
"""Return update state."""
resp = await self.call("get_fw_download_state")
state = resp["get_fw_download_state"]
return DownloadState(**state)
return DownloadState.from_dict(state)

Check warning on line 179 in kasa/smart/modules/firmware.py

View check run for this annotation

Codecov / codecov/patch

kasa/smart/modules/firmware.py#L179

Added line #L179 was not covered by tests

@allow_update_after
async def update(
Expand Down Expand Up @@ -232,7 +234,7 @@
else:
_LOGGER.warning("Unhandled state code: %s", state)

return state.dict()
return state.to_dict()

@property
def auto_update_enabled(self) -> bool:
Expand Down
27 changes: 21 additions & 6 deletions tests/smart/modules/test_firmware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
from contextlib import nullcontext
from datetime import date
from typing import TypedDict

import pytest
Expand Down Expand Up @@ -52,6 +53,20 @@ async def test_firmware_features(
assert isinstance(feat.value, type)


@firmware
async def test_firmware_update_info(dev: SmartDevice):
"""Test that the firmware UpdateInfo object deserializes correctly."""
fw = dev.modules.get(Module.Firmware)
assert fw

if not dev.is_cloud_connected:
pytest.skip("Device is not cloud connected, skipping test")
assert fw.firmware_update_info is None
await fw.check_latest_firmware()
assert fw.firmware_update_info is not None
assert isinstance(fw.firmware_update_info.release_date, date | None)


@firmware
async def test_update_available_without_cloud(dev: SmartDevice):
"""Test that update_available returns None when disconnected."""
Expand Down Expand Up @@ -105,15 +120,15 @@ class Extras(TypedDict):
}
update_states = [
# Unknown 1
DownloadState(status=1, download_progress=0, **extras),
DownloadState(status=1, progress=0, **extras),
# Downloading
DownloadState(status=2, download_progress=10, **extras),
DownloadState(status=2, download_progress=100, **extras),
DownloadState(status=2, progress=10, **extras),
DownloadState(status=2, progress=100, **extras),
# Flashing
DownloadState(status=3, download_progress=100, **extras),
DownloadState(status=3, download_progress=100, **extras),
DownloadState(status=3, progress=100, **extras),
DownloadState(status=3, progress=100, **extras),
# Done
DownloadState(status=0, download_progress=100, **extras),
DownloadState(status=0, progress=100, **extras),
]

asyncio_sleep = asyncio.sleep
Expand Down
Loading