Skip to content

Refactor aestransport to use a state enum #691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jan 24, 2024
80 changes: 44 additions & 36 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import hashlib
import logging
import time
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast

from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
Expand Down Expand Up @@ -41,6 +42,14 @@ def _sha1(payload: bytes) -> str:
return sha1_algo.hexdigest()


class TransportState(Enum):
"""Enum for AES state."""

HANDSHAKE_REQUIRED = auto() # Handshake needed
LOGIN_REQUIRED = auto() # Login needed
ESTABLISHED = auto() # Ready to send requests


class AesTransport(BaseTransport):
"""Implementation of the AES encryption protocol.

Expand Down Expand Up @@ -79,21 +88,21 @@ def __init__(
self._default_credentials: Optional[Credentials] = None
self._http_client: HttpClient = HttpClient(config)

self._handshake_done = False
self._state = TransportState.HANDSHAKE_REQUIRED

self._encryption_session: Optional[AesEncyptionSession] = None
self._session_expire_at: Optional[float] = None

self._session_cookie: Optional[Dict[str, str]] = None

self._login_token = None
self._login_token: Optional[str] = None

self._key_pair: Optional[KeyPair] = None

_LOGGER.debug("Created AES transport for %s", self._host)

@property
def default_port(self):
def default_port(self) -> int:
"""Default port for the transport."""
return self.DEFAULT_PORT

Expand All @@ -102,30 +111,25 @@ def credentials_hash(self) -> str:
"""The hashed credentials used by the transport."""
return base64.b64encode(json_dumps(self._login_params).encode()).decode()

def _get_login_params(self, credentials):
def _get_login_params(self, credentials: Credentials) -> Dict[str, str]:
"""Get the login parameters based on the login_version."""
un, pw = self.hash_credentials(self._login_version == 2, credentials)
password_field_name = "password2" if self._login_version == 2 else "password"
return {password_field_name: pw, "username": un}

@staticmethod
def hash_credentials(login_v2, credentials):
def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]:
"""Hash the credentials."""
un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode()
if login_v2:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(
_sha1(credentials.password.encode()).encode()
).decode()
else:
un = base64.b64encode(
_sha1(credentials.username.encode()).encode()
).decode()
pw = base64.b64encode(credentials.password.encode()).decode()
return un, pw

def _handle_response_error_code(self, resp_dict: dict, msg: str):
def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None:
error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
if error_code == SmartErrorCode.SUCCESS:
return
Expand All @@ -135,12 +139,11 @@ def _handle_response_error_code(self, resp_dict: dict, msg: str):
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg, error_code=error_code)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
self._state = TransportState.HANDSHAKE_REQUIRED
raise AuthenticationException(msg, error_code=error_code)
raise SmartDeviceException(msg, error_code=error_code)

async def send_secure_passthrough(self, request: str):
async def send_secure_passthrough(self, request: str) -> Dict[str, Any]:
"""Send encrypted message as passthrough."""
url = f"http://{self._host}/app"
if self._login_token:
Expand All @@ -165,24 +168,25 @@ async def send_secure_passthrough(self, request: str):
+ f"status code {status_code} to passthrough"
)

resp_dict = cast(Dict, resp_dict)
self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)

response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
return resp_dict
if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)
assert self._encryption_session is not None

raw_response: str = resp_dict["result"]["response"]
response = self._encryption_session.decrypt(raw_response.encode())
return json_loads(response) # type: ignore[return-value]

async def perform_login(self):
"""Login to the device."""
try:
await self.try_login(self._login_params)
except AuthenticationException as aex:
try:
if aex.error_code != SmartErrorCode.LOGIN_ERROR:
if aex.error_code is not SmartErrorCode.LOGIN_ERROR:
raise aex
if self._default_credentials is None:
self._default_credentials = get_default_credentials(
Expand All @@ -203,9 +207,8 @@ async def perform_login(self):
ex,
) from ex

async def try_login(self, login_params):
async def try_login(self, login_params: Dict[str, Any]) -> None:
"""Try to login with supplied login_params."""
self._login_token = None
login_request = {
"method": "login_device",
"params": login_params,
Expand All @@ -216,6 +219,7 @@ async def try_login(self, login_params):
resp_dict = await self.send_secure_passthrough(request)
self._handle_response_error_code(resp_dict, "Error logging in")
self._login_token = resp_dict["result"]["token"]
self._state = TransportState.ESTABLISHED

async def _generate_key_pair_payload(self) -> AsyncGenerator:
"""Generate the request body and return an ascyn_generator.
Expand All @@ -236,12 +240,11 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator:
_LOGGER.debug(f"Request {request_body}")
yield json_dumps(request_body).encode()

async def perform_handshake(self):
async def perform_handshake(self) -> None:
"""Perform the handshake."""
_LOGGER.debug("Will perform handshaking...")

self._key_pair = None
self._handshake_done = False
self._session_expire_at = None
self._session_cookie = None

Expand All @@ -258,7 +261,7 @@ async def perform_handshake(self):
cookies_dict=self._session_cookie,
)

_LOGGER.debug(f"Device responded with: {resp_dict}")
_LOGGER.debug("Device responded with: %s", resp_dict)

if status_code != 200:
raise SmartDeviceException(
Expand All @@ -268,6 +271,9 @@ async def perform_handshake(self):

self._handle_response_error_code(resp_dict, "Unable to complete handshake")

if TYPE_CHECKING:
resp_dict = cast(Dict[str, Any], resp_dict)

handshake_key = resp_dict["result"]["key"]

if (
Expand All @@ -283,12 +289,12 @@ async def perform_handshake(self):

self._session_expire_at = time.time() + 86400
if TYPE_CHECKING:
assert self._key_pair is not None # pragma: no cover
assert self._key_pair is not None
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, self._key_pair
)

self._handshake_done = True
self._state = TransportState.LOGIN_REQUIRED

_LOGGER.debug("Handshake with %s complete", self._host)

Expand All @@ -299,17 +305,20 @@ def _handshake_session_expired(self):
or self._session_expire_at - time.time() <= 0
)

async def send(self, request: str):
async def send(self, request: str) -> Dict[str, Any]:
"""Send the request."""
if not self._handshake_done or self._handshake_session_expired():
if (
self._state is TransportState.HANDSHAKE_REQUIRED
or self._handshake_session_expired()
):
await self.perform_handshake()
if not self._login_token:
if self._state is not TransportState.ESTABLISHED:
try:
await self.perform_login()
# After a login failure handshake needs to
# be redone or a 9999 error is received.
except AuthenticationException as ex:
self._handshake_done = False
self._state = TransportState.HANDSHAKE_REQUIRED
raise ex

return await self.send_secure_passthrough(request)
Expand All @@ -321,8 +330,7 @@ async def close(self) -> None:

async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._handshake_done = False
self._login_token = None
self._state = TransportState.HANDSHAKE_REQUIRED


class AesEncyptionSession:
Expand Down
10 changes: 5 additions & 5 deletions kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding

from ..aestransport import AesEncyptionSession, AesTransport
from ..aestransport import AesEncyptionSession, AesTransport, TransportState
from ..credentials import Credentials
from ..deviceconfig import DeviceConfig
from ..exceptions import (
Expand Down Expand Up @@ -66,11 +66,11 @@ async def test_handshake(
)

assert transport._encryption_session is None
assert transport._handshake_done is False
assert transport._state is TransportState.HANDSHAKE_REQUIRED
with expectation:
await transport.perform_handshake()
assert transport._encryption_session is not None
assert transport._handshake_done is True
assert transport._state is TransportState.LOGIN_REQUIRED


@status_parameters
Expand All @@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

Expand Down Expand Up @@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
transport = AesTransport(
config=DeviceConfig(host, credentials=Credentials("foo", "bar"))
)
transport._handshake_done = True
transport._state = TransportState.LOGIN_REQUIRED
transport._session_expire_at = time.time() + 86400
transport._encryption_session = mock_aes_device.encryption_session

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,16 @@ omit = ["kasa/tests/*"]

[tool.coverage.report]
exclude_lines = [
# ignore abstract methods
# Don't complain if tests don't hit defensive assertion code:
"raise AssertionError",
"raise NotImplementedError",
"def __repr__"
# Don't complain about missing debug-only code:
"def __repr__",
# Have to re-enable the standard pragma
"pragma: no cover",
# TYPE_CHECKING and @overload blocks are never executed during pytest run
"if TYPE_CHECKING:",
"@overload"
]

[tool.pytest.ini_options]
Expand Down