From d2ffbbe82a90b7b20c88f369b8697c8e9e3d43e5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 12:51:21 -1000 Subject: [PATCH 01/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 69 +++++++++++++++++++-------------- kasa/tests/test_aestransport.py | 10 ++--- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 73d02b0ee..e3dc3e95e 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -8,7 +8,8 @@ import hashlib import logging import time -from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional, cast +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Optional, Tuple, cast from cryptography.hazmat.primitives import padding, serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding @@ -41,6 +42,14 @@ def _sha1(payload: bytes) -> str: return sha1_algo.hexdigest() +class AesState(Enum): + """Enum for AES state.""" + + HANDSHAKE = auto() + LOGIN = auto() + ESTABLISHED = auto() + + class AesTransport(BaseTransport): """Implementation of the AES encryption protocol. @@ -79,21 +88,21 @@ def __init__( self._default_credentials: Optional[Credentials] = None self._http_client: HttpClient = HttpClient(config) - self._handshake_done = False + self._state = AesState.HANDSHAKE self._encryption_session: Optional[AesEncyptionSession] = None self._session_expire_at: Optional[float] = None self._session_cookie: Optional[Dict[str, str]] = None - self._login_token = None + self._login_token: Optional[str] = None self._key_pair: Optional[KeyPair] = None _LOGGER.debug("Created AES transport for %s", self._host) @property - def default_port(self): + def default_port(self) -> int: """Default port for the transport.""" return self.DEFAULT_PORT @@ -102,14 +111,14 @@ def credentials_hash(self) -> str: """The hashed credentials used by the transport.""" return base64.b64encode(json_dumps(self._login_params).encode()).decode() - def _get_login_params(self, credentials): + def _get_login_params(self, credentials: Credentials) -> Dict[str, str]: """Get the login parameters based on the login_version.""" un, pw = self.hash_credentials(self._login_version == 2, credentials) password_field_name = "password2" if self._login_version == 2 else "password" return {password_field_name: pw, "username": un} @staticmethod - def hash_credentials(login_v2, credentials): + def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]: """Hash the credentials.""" if login_v2: un = base64.b64encode( @@ -125,8 +134,8 @@ def hash_credentials(login_v2, credentials): pw = base64.b64encode(credentials.password.encode()).decode() return un, pw - def _handle_response_error_code(self, resp_dict: dict, msg: str): - error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] + def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: + error_code = SmartErrorCode((resp_dict or {}).get("error_code", -1001)) if error_code == SmartErrorCode.SUCCESS: return msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" @@ -135,12 +144,11 @@ def _handle_response_error_code(self, resp_dict: dict, msg: str): if error_code in SMART_RETRYABLE_ERRORS: raise RetryableException(msg, error_code=error_code) if error_code in SMART_AUTHENTICATION_ERRORS: - self._handshake_done = False - self._login_token = None + self._state = AesState.HANDSHAKE raise AuthenticationException(msg, error_code=error_code) raise SmartDeviceException(msg, error_code=error_code) - async def send_secure_passthrough(self, request: str): + async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: """Send encrypted message as passthrough.""" url = f"http://{self._host}/app" if self._login_token: @@ -165,16 +173,17 @@ async def send_secure_passthrough(self, request: str): + f"status code {status_code} to passthrough" ) - resp_dict = cast(Dict, resp_dict) self._handle_response_error_code( resp_dict, "Error sending secure_passthrough message" ) - response = self._encryption_session.decrypt( # type: ignore - resp_dict["result"]["response"].encode() - ) - resp_dict = json_loads(response) - return resp_dict + if TYPE_CHECKING: + resp_dict = cast(Dict[str, Any], resp_dict) # pragma: no cover + assert self._encryption_session is not None # pragma: no cover + + raw_response: str = resp_dict["result"]["response"] + response = self._encryption_session.decrypt(raw_response.encode()) + return json_loads(response) # type: ignore[return-value] async def perform_login(self): """Login to the device.""" @@ -182,7 +191,7 @@ async def perform_login(self): await self.try_login(self._login_params) except AuthenticationException as aex: try: - if aex.error_code != SmartErrorCode.LOGIN_ERROR: + if aex.error_code is not SmartErrorCode.LOGIN_ERROR: raise aex if self._default_credentials is None: self._default_credentials = get_default_credentials( @@ -203,9 +212,8 @@ async def perform_login(self): ex, ) from ex - async def try_login(self, login_params): + async def try_login(self, login_params: Dict[str, Any]) -> None: """Try to login with supplied login_params.""" - self._login_token = None login_request = { "method": "login_device", "params": login_params, @@ -236,12 +244,11 @@ async def _generate_key_pair_payload(self) -> AsyncGenerator: _LOGGER.debug(f"Request {request_body}") yield json_dumps(request_body).encode() - async def perform_handshake(self): + async def perform_handshake(self) -> None: """Perform the handshake.""" _LOGGER.debug("Will perform handshaking...") self._key_pair = None - self._handshake_done = False self._session_expire_at = None self._session_cookie = None @@ -258,7 +265,7 @@ async def perform_handshake(self): cookies_dict=self._session_cookie, ) - _LOGGER.debug(f"Device responded with: {resp_dict}") + _LOGGER.debug("Device responded with: %s", resp_dict) if status_code != 200: raise SmartDeviceException( @@ -268,6 +275,9 @@ async def perform_handshake(self): self._handle_response_error_code(resp_dict, "Unable to complete handshake") + if TYPE_CHECKING: + resp_dict = cast(Dict[str, Any], resp_dict) # pragma: no cover + handshake_key = resp_dict["result"]["key"] if ( @@ -288,7 +298,7 @@ async def perform_handshake(self): handshake_key, self._key_pair ) - self._handshake_done = True + self._state = AesState.LOGIN _LOGGER.debug("Handshake with %s complete", self._host) @@ -299,17 +309,17 @@ def _handshake_session_expired(self): or self._session_expire_at - time.time() <= 0 ) - async def send(self, request: str): + async def send(self, request: str) -> Dict[str, Any]: """Send the request.""" - if not self._handshake_done or self._handshake_session_expired(): + if self._state is AesState.HANDSHAKE or self._handshake_session_expired(): await self.perform_handshake() - if not self._login_token: + if self._state is not AesState.ESTABLISHED: try: await self.perform_login() # After a login failure handshake needs to # be redone or a 9999 error is received. except AuthenticationException as ex: - self._handshake_done = False + self._state = AesState.HANDSHAKE raise ex return await self.send_secure_passthrough(request) @@ -321,8 +331,7 @@ async def close(self) -> None: async def reset(self) -> None: """Reset internal handshake and login state.""" - self._handshake_done = False - self._login_token = None + self._state = AesState.HANDSHAKE class AesEncyptionSession: diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index cfd292845..d9f7cb29d 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding -from ..aestransport import AesEncyptionSession, AesTransport +from ..aestransport import AesEncyptionSession, AesState, AesTransport from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import ( @@ -66,11 +66,11 @@ async def test_handshake( ) assert transport._encryption_session is None - assert transport._handshake_done is False + assert transport._state is AesState.HANDSHAKE with expectation: await transport.perform_handshake() assert transport._encryption_session is not None - assert transport._handshake_done is True + assert transport._state is AesState.LOGIN @status_parameters @@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._handshake_done = True + transport._state = AesState.LOGIN transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session @@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._handshake_done = True + transport._state = AesState.LOGIN transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session From d98ee3e235fdd5dfd29fdef21936cb0a684f478c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 12:53:03 -1000 Subject: [PATCH 02/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index e3dc3e95e..297324a82 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -135,7 +135,9 @@ def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str return un, pw def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: - error_code = SmartErrorCode((resp_dict or {}).get("error_code", -1001)) + error_code = SmartErrorCode( + (resp_dict or {}).get("error_code", SmartErrorCode.UNSPECIFIC_ERROR.value) + ) if error_code == SmartErrorCode.SUCCESS: return msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" From fa313097961429e3fd80a89e3fc6ea70b45df62e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 12:56:18 -1000 Subject: [PATCH 03/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 297324a82..32444e5a7 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -120,17 +120,12 @@ def _get_login_params(self, credentials: Credentials) -> Dict[str, str]: @staticmethod def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str]: """Hash the credentials.""" + un = base64.b64encode(_sha1(credentials.username.encode()).encode()).decode() if login_v2: - un = base64.b64encode( - _sha1(credentials.username.encode()).encode() - ).decode() pw = base64.b64encode( _sha1(credentials.password.encode()).encode() ).decode() else: - un = base64.b64encode( - _sha1(credentials.username.encode()).encode() - ).decode() pw = base64.b64encode(credentials.password.encode()).decode() return un, pw From 2e7613bcae9ffdc828529aa5a384c87b9f4b4517 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 12:58:25 -1000 Subject: [PATCH 04/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 32444e5a7..97a3ae3c4 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -221,6 +221,7 @@ async def try_login(self, login_params: Dict[str, Any]) -> None: resp_dict = await self.send_secure_passthrough(request) self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] + self._state = AesState.ESTABLISHED async def _generate_key_pair_payload(self) -> AsyncGenerator: """Generate the request body and return an ascyn_generator. From a68e8bacde0f24fa35c1bcb607359fe818da43fc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 13:00:06 -1000 Subject: [PATCH 05/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 97a3ae3c4..bc6f22e55 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -45,9 +45,9 @@ def _sha1(payload: bytes) -> str: class AesState(Enum): """Enum for AES state.""" - HANDSHAKE = auto() - LOGIN = auto() - ESTABLISHED = auto() + HANDSHAKE = auto() # Handshake needed + LOGIN = auto() # Login needed + ESTABLISHED = auto() # Ready to send requests class AesTransport(BaseTransport): From 1d541244e261d3518ce965d3f8cea97d8c497a27 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 13:05:36 -1000 Subject: [PATCH 06/12] Refactor aestransport to use a state enum --- .coveragerc | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..8209d716f --- /dev/null +++ b/.coveragerc @@ -0,0 +1,21 @@ +[run] +source = kasa +omit = + kasa/tests + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # TYPE_CHECKING and @overload blocks are never executed during pytest run + if TYPE_CHECKING: + @overload From ac2a8f37f709c51236346089a65201b7d26a4d66 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 13:06:33 -1000 Subject: [PATCH 07/12] Refactor aestransport to use a state enum --- kasa/aestransport.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index bc6f22e55..d449036be 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -175,8 +175,8 @@ async def send_secure_passthrough(self, request: str) -> Dict[str, Any]: ) if TYPE_CHECKING: - resp_dict = cast(Dict[str, Any], resp_dict) # pragma: no cover - assert self._encryption_session is not None # pragma: no cover + resp_dict = cast(Dict[str, Any], resp_dict) + assert self._encryption_session is not None raw_response: str = resp_dict["result"]["response"] response = self._encryption_session.decrypt(raw_response.encode()) @@ -274,7 +274,7 @@ async def perform_handshake(self) -> None: self._handle_response_error_code(resp_dict, "Unable to complete handshake") if TYPE_CHECKING: - resp_dict = cast(Dict[str, Any], resp_dict) # pragma: no cover + resp_dict = cast(Dict[str, Any], resp_dict) handshake_key = resp_dict["result"]["key"] @@ -291,7 +291,7 @@ async def perform_handshake(self) -> None: self._session_expire_at = time.time() + 86400 if TYPE_CHECKING: - assert self._key_pair is not None # pragma: no cover + assert self._key_pair is not None self._encryption_session = AesEncyptionSession.create_from_keypair( handshake_key, self._key_pair ) From ade408fe734fffaaf45fd7b6458cfff639e13e14 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 15:12:37 -1000 Subject: [PATCH 08/12] adjust naming --- kasa/aestransport.py | 19 +++++++++++-------- kasa/tests/test_aestransport.py | 8 ++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index d449036be..83288f599 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -45,8 +45,8 @@ def _sha1(payload: bytes) -> str: class AesState(Enum): """Enum for AES state.""" - HANDSHAKE = auto() # Handshake needed - LOGIN = auto() # Login needed + HANDSHAKE_REQUIRED = auto() # Handshake needed + LOGIN_REQUIRED = auto() # Login needed ESTABLISHED = auto() # Ready to send requests @@ -88,7 +88,7 @@ def __init__( self._default_credentials: Optional[Credentials] = None self._http_client: HttpClient = HttpClient(config) - self._state = AesState.HANDSHAKE + self._state = AesState.HANDSHAKE_REQUIRED self._encryption_session: Optional[AesEncyptionSession] = None self._session_expire_at: Optional[float] = None @@ -141,7 +141,7 @@ def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: if error_code in SMART_RETRYABLE_ERRORS: raise RetryableException(msg, error_code=error_code) if error_code in SMART_AUTHENTICATION_ERRORS: - self._state = AesState.HANDSHAKE + self._state = AesState.HANDSHAKE_REQUIRED raise AuthenticationException(msg, error_code=error_code) raise SmartDeviceException(msg, error_code=error_code) @@ -296,7 +296,7 @@ async def perform_handshake(self) -> None: handshake_key, self._key_pair ) - self._state = AesState.LOGIN + self._state = AesState.LOGIN_REQUIRED _LOGGER.debug("Handshake with %s complete", self._host) @@ -309,7 +309,10 @@ def _handshake_session_expired(self): async def send(self, request: str) -> Dict[str, Any]: """Send the request.""" - if self._state is AesState.HANDSHAKE or self._handshake_session_expired(): + if ( + self._state is AesState.HANDSHAKE_REQUIRED + or self._handshake_session_expired() + ): await self.perform_handshake() if self._state is not AesState.ESTABLISHED: try: @@ -317,7 +320,7 @@ async def send(self, request: str) -> Dict[str, Any]: # After a login failure handshake needs to # be redone or a 9999 error is received. except AuthenticationException as ex: - self._state = AesState.HANDSHAKE + self._state = AesState.HANDSHAKE_REQUIRED raise ex return await self.send_secure_passthrough(request) @@ -329,7 +332,7 @@ async def close(self) -> None: async def reset(self) -> None: """Reset internal handshake and login state.""" - self._state = AesState.HANDSHAKE + self._state = AesState.HANDSHAKE_REQUIRED class AesEncyptionSession: diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index d9f7cb29d..dcd9ac459 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -66,11 +66,11 @@ async def test_handshake( ) assert transport._encryption_session is None - assert transport._state is AesState.HANDSHAKE + assert transport._state is AesState.HANDSHAKE_REQUIRED with expectation: await transport.perform_handshake() assert transport._encryption_session is not None - assert transport._state is AesState.LOGIN + assert transport._state is AesState.LOGIN_REQUIRED @status_parameters @@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._state = AesState.LOGIN + transport._state = AesState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session @@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._state = AesState.LOGIN + transport._state = AesState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session From e88e7c3901fe799d189e2d4cb52af72380a69ea6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 15:13:12 -1000 Subject: [PATCH 09/12] revert --- kasa/aestransport.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 83288f599..07bec7b4c 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -130,9 +130,7 @@ def hash_credentials(login_v2: bool, credentials: Credentials) -> Tuple[str, str return un, pw def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: - error_code = SmartErrorCode( - (resp_dict or {}).get("error_code", SmartErrorCode.UNSPECIFIC_ERROR.value) - ) + error_code = SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type] if error_code == SmartErrorCode.SUCCESS: return msg = f"{msg}: {self._host}: {error_code.name}({error_code.value})" From ec7ec9c0ba45602abd8dc3eeae8a28aeb115992d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 15:13:51 -1000 Subject: [PATCH 10/12] state name --- kasa/aestransport.py | 18 +++++++++--------- kasa/tests/test_aestransport.py | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/kasa/aestransport.py b/kasa/aestransport.py index 07bec7b4c..5269d185c 100644 --- a/kasa/aestransport.py +++ b/kasa/aestransport.py @@ -42,7 +42,7 @@ def _sha1(payload: bytes) -> str: return sha1_algo.hexdigest() -class AesState(Enum): +class TransportState(Enum): """Enum for AES state.""" HANDSHAKE_REQUIRED = auto() # Handshake needed @@ -88,7 +88,7 @@ def __init__( self._default_credentials: Optional[Credentials] = None self._http_client: HttpClient = HttpClient(config) - self._state = AesState.HANDSHAKE_REQUIRED + self._state = TransportState.HANDSHAKE_REQUIRED self._encryption_session: Optional[AesEncyptionSession] = None self._session_expire_at: Optional[float] = None @@ -139,7 +139,7 @@ def _handle_response_error_code(self, resp_dict: Any, msg: str) -> None: if error_code in SMART_RETRYABLE_ERRORS: raise RetryableException(msg, error_code=error_code) if error_code in SMART_AUTHENTICATION_ERRORS: - self._state = AesState.HANDSHAKE_REQUIRED + self._state = TransportState.HANDSHAKE_REQUIRED raise AuthenticationException(msg, error_code=error_code) raise SmartDeviceException(msg, error_code=error_code) @@ -219,7 +219,7 @@ async def try_login(self, login_params: Dict[str, Any]) -> None: resp_dict = await self.send_secure_passthrough(request) self._handle_response_error_code(resp_dict, "Error logging in") self._login_token = resp_dict["result"]["token"] - self._state = AesState.ESTABLISHED + self._state = TransportState.ESTABLISHED async def _generate_key_pair_payload(self) -> AsyncGenerator: """Generate the request body and return an ascyn_generator. @@ -294,7 +294,7 @@ async def perform_handshake(self) -> None: handshake_key, self._key_pair ) - self._state = AesState.LOGIN_REQUIRED + self._state = TransportState.LOGIN_REQUIRED _LOGGER.debug("Handshake with %s complete", self._host) @@ -308,17 +308,17 @@ def _handshake_session_expired(self): async def send(self, request: str) -> Dict[str, Any]: """Send the request.""" if ( - self._state is AesState.HANDSHAKE_REQUIRED + self._state is TransportState.HANDSHAKE_REQUIRED or self._handshake_session_expired() ): await self.perform_handshake() - if self._state is not AesState.ESTABLISHED: + if self._state is not TransportState.ESTABLISHED: try: await self.perform_login() # After a login failure handshake needs to # be redone or a 9999 error is received. except AuthenticationException as ex: - self._state = AesState.HANDSHAKE_REQUIRED + self._state = TransportState.HANDSHAKE_REQUIRED raise ex return await self.send_secure_passthrough(request) @@ -330,7 +330,7 @@ async def close(self) -> None: async def reset(self) -> None: """Reset internal handshake and login state.""" - self._state = AesState.HANDSHAKE_REQUIRED + self._state = TransportState.HANDSHAKE_REQUIRED class AesEncyptionSession: diff --git a/kasa/tests/test_aestransport.py b/kasa/tests/test_aestransport.py index dcd9ac459..086f6ea60 100644 --- a/kasa/tests/test_aestransport.py +++ b/kasa/tests/test_aestransport.py @@ -10,7 +10,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding -from ..aestransport import AesEncyptionSession, AesState, AesTransport +from ..aestransport import AesEncyptionSession, AesTransport, TransportState from ..credentials import Credentials from ..deviceconfig import DeviceConfig from ..exceptions import ( @@ -66,11 +66,11 @@ async def test_handshake( ) assert transport._encryption_session is None - assert transport._state is AesState.HANDSHAKE_REQUIRED + assert transport._state is TransportState.HANDSHAKE_REQUIRED with expectation: await transport.perform_handshake() assert transport._encryption_session is not None - assert transport._state is AesState.LOGIN_REQUIRED + assert transport._state is TransportState.LOGIN_REQUIRED @status_parameters @@ -82,7 +82,7 @@ async def test_login(mocker, status_code, error_code, inner_error_code, expectat transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._state = AesState.LOGIN_REQUIRED + transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session @@ -129,7 +129,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count): transport = AesTransport( config=DeviceConfig(host, credentials=Credentials("foo", "bar")) ) - transport._state = AesState.LOGIN_REQUIRED + transport._state = TransportState.LOGIN_REQUIRED transport._session_expire_at = time.time() + 86400 transport._encryption_session = mock_aes_device.encryption_session From 5fa59d94b688e12748bc8faaa9ada99968a51bf4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 15:15:37 -1000 Subject: [PATCH 11/12] use pyproject --- .coveragerc | 21 --------------------- pyproject.toml | 9 ++++++++- 2 files changed, 8 insertions(+), 22 deletions(-) delete mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 8209d716f..000000000 --- a/.coveragerc +++ /dev/null @@ -1,21 +0,0 @@ -[run] -source = kasa -omit = - kasa/tests - -[report] -# Regexes for lines to exclude from consideration -exclude_lines = - # Have to re-enable the standard pragma - pragma: no cover - - # Don't complain about missing debug-only code: - def __repr__ - - # Don't complain if tests don't hit defensive assertion code: - raise AssertionError - raise NotImplementedError - - # TYPE_CHECKING and @overload blocks are never executed during pytest run - if TYPE_CHECKING: - @overload diff --git a/pyproject.toml b/pyproject.toml index 6bd81a900..015f68097 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,9 +65,16 @@ omit = ["kasa/tests/*"] [tool.coverage.report] exclude_lines = [ - # ignore abstract methods + # Don't complain if tests don't hit defensive assertion code: + "raise AssertionError", "raise NotImplementedError", + # Don't complain about missing debug-only code: "def __repr__" + # Have to re-enable the standard pragma + "pragma: no cover", + # TYPE_CHECKING and @overload blocks are never executed during pytest run + "if TYPE_CHECKING:" + "@overload" ] [tool.pytest.ini_options] From 36334cdbf0edd459122d7047ee54c10fc8bc1d3b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 23 Jan 2024 15:16:06 -1000 Subject: [PATCH 12/12] lint --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 015f68097..206565559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,11 +69,11 @@ exclude_lines = [ "raise AssertionError", "raise NotImplementedError", # Don't complain about missing debug-only code: - "def __repr__" + "def __repr__", # Have to re-enable the standard pragma "pragma: no cover", # TYPE_CHECKING and @overload blocks are never executed during pytest run - "if TYPE_CHECKING:" + "if TYPE_CHECKING:", "@overload" ]