Skip to content

Improve smartprotocol error handling and retries #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 63 additions & 36 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
TimeoutException,
)
from .json import dumps as json_dumps
from .json import loads as json_loads
from .protocol import BaseTransport
Expand Down Expand Up @@ -110,6 +119,21 @@ async def client_post(self, url, params=None, data=None, json=None, headers=None

return resp.status_code, response_data

def _handle_response_error_code(self, resp_dict: dict, msg: str):
if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = f"{msg}: {self.host}: {error_code.name}({error_code.value})"
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
self._handshake_done = False
self._login_token = None
raise AuthenticationException(msg)
raise SmartDeviceException(msg)

async def send_secure_passthrough(self, request: str):
"""Send encrypted message as passthrough."""
url = f"http://{self.host}/app"
Expand All @@ -123,17 +147,22 @@ async def send_secure_passthrough(self, request: str):
}
status_code, resp_dict = await self.client_post(url, json=passthrough_request)
# _LOGGER.debug(f"secure_passthrough response is {status_code}: {resp_dict}")
if status_code == 200 and resp_dict["error_code"] == 0:
response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()

if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
+ f"status code {status_code} to passthrough"
)
_LOGGER.debug(f"decrypted secure_passthrough response is {response}")
resp_dict = json_loads(response)
return resp_dict
else:
self._handshake_done = False
self._login_token = None
raise AuthenticationException("Could not complete send")

self._handle_response_error_code(
resp_dict, "Error sending secure_passthrough message"
)

response = self._encryption_session.decrypt( # type: ignore
resp_dict["result"]["response"].encode()
)
resp_dict = json_loads(response)
return resp_dict

async def perform_login(self, login_request: Union[str, dict], *, login_v2: bool):
"""Login to the device."""
Expand Down Expand Up @@ -207,29 +236,32 @@ async def perform_handshake(self):

_LOGGER.debug(f"Device responded with: {resp_dict}")

if status_code == 200 and resp_dict["error_code"] == 0:
_LOGGER.debug("Decoding handshake key...")
handshake_key = resp_dict["result"]["key"]

self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
if status_code != 200:
raise SmartDeviceException(
f"{self.host} responded with an unexpected "
+ f"status code {status_code} to handshake"
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
"SESSIONID"
)

self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
self._handle_response_error_code(resp_dict, "Unable to complete handshake")

handshake_key = resp_dict["result"]["key"]

self._session_cookie = self._http_client.cookies.get( # type: ignore
self.SESSION_COOKIE_NAME
)
if not self._session_cookie:
self._session_cookie = self._http_client.cookies.get( # type: ignore
"SESSIONID"
)

self._handshake_done = True
self._session_expire_at = time.time() + 86400
self._encryption_session = AesEncyptionSession.create_from_keypair(
handshake_key, key_pair
)

_LOGGER.debug("Handshake with %s complete", self.host)
self._handshake_done = True

else:
raise AuthenticationException("Could not complete handshake")
_LOGGER.debug("Handshake with %s complete", self.host)

def _handshake_session_expired(self):
"""Return true if session has expired."""
Expand All @@ -247,19 +279,14 @@ async def send(self, request: str):
if self.needs_login:
raise SmartDeviceException("Login must be complete before trying to send")

resp_dict = await self.send_secure_passthrough(request)
if resp_dict["error_code"] != 0:
self._handshake_done = False
self._login_token = None
raise SmartDeviceException(
f"Could not complete send, response was {resp_dict}",
)
return resp_dict
return await self.send_secure_passthrough(request)

async def close(self) -> None:
"""Close the protocol."""
client = self._http_client
self._http_client = None
self._handshake_done = False
self._login_token = None
if client:
await client.aclose()

Expand Down
85 changes: 85 additions & 0 deletions kasa/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""python-kasa exceptions."""
from enum import IntEnum


class SmartDeviceException(Exception):
Expand All @@ -11,3 +12,87 @@ class UnsupportedDeviceException(SmartDeviceException):

class AuthenticationException(SmartDeviceException):
"""Base exception for device authentication errors."""


class RetryableException(SmartDeviceException):
"""Retryable exception for device errors."""


class TimeoutException(SmartDeviceException):
"""Timeout exception for device errors."""


class SmartErrorCode(IntEnum):
"""Enum for SMART Error Codes."""

SUCCESS = 0

# Transport Errors
SESSION_TIMEOUT_ERROR = 9999
MULTI_REQUEST_FAILED_ERROR = 1200
HTTP_TRANSPORT_FAILED_ERROR = 1112
LOGIN_FAILED_ERROR = 1111
HAND_SHAKE_FAILED_ERROR = 1100
TRANSPORT_NOT_AVAILABLE_ERROR = 1002
CMD_COMMAND_CANCEL_ERROR = 1001
NULL_TRANSPORT_ERROR = 1000

# Common Method Errors
COMMON_FAILED_ERROR = -1
UNSPECIFIC_ERROR = -1001
UNKNOWN_METHOD_ERROR = -1002
JSON_DECODE_FAIL_ERROR = -1003
JSON_ENCODE_FAIL_ERROR = -1004
AES_DECODE_FAIL_ERROR = -1005
REQUEST_LEN_ERROR_ERROR = -1006
CLOUD_FAILED_ERROR = -1007
PARAMS_ERROR = -1008
INVALID_PUBLIC_KEY_ERROR = -1010 # Unverified
SESSION_PARAM_ERROR = -1101

# Method Specific Errors
QUICK_SETUP_ERROR = -1201
DEVICE_ERROR = -1301
DEVICE_NEXT_EVENT_ERROR = -1302
FIRMWARE_ERROR = -1401
FIRMWARE_VER_ERROR_ERROR = -1402
LOGIN_ERROR = -1501
TIME_ERROR = -1601
TIME_SYS_ERROR = -1602
TIME_SAVE_ERROR = -1603
WIRELESS_ERROR = -1701
WIRELESS_UNSUPPORTED_ERROR = -1702
SCHEDULE_ERROR = -1801
SCHEDULE_FULL_ERROR = -1802
SCHEDULE_CONFLICT_ERROR = -1803
SCHEDULE_SAVE_ERROR = -1804
SCHEDULE_INDEX_ERROR = -1805
COUNTDOWN_ERROR = -1901
COUNTDOWN_CONFLICT_ERROR = -1902
COUNTDOWN_SAVE_ERROR = -1903
ANTITHEFT_ERROR = -2001
ANTITHEFT_CONFLICT_ERROR = -2002
ANTITHEFT_SAVE_ERROR = -2003
ACCOUNT_ERROR = -2101
STAT_ERROR = -2201
STAT_SAVE_ERROR = -2202
DST_ERROR = -2301
DST_SAVE_ERROR = -2302


SMART_RETRYABLE_ERRORS = [
SmartErrorCode.TRANSPORT_NOT_AVAILABLE_ERROR,
SmartErrorCode.HTTP_TRANSPORT_FAILED_ERROR,
SmartErrorCode.UNSPECIFIC_ERROR,
]

SMART_AUTHENTICATION_ERRORS = [
SmartErrorCode.LOGIN_ERROR,
SmartErrorCode.LOGIN_FAILED_ERROR,
SmartErrorCode.AES_DECODE_FAIL_ERROR,
SmartErrorCode.HAND_SHAKE_FAILED_ERROR,
]

SMART_TIMEOUT_ERRORS = [
SmartErrorCode.SESSION_TIMEOUT_ERROR,
]
1 change: 1 addition & 0 deletions kasa/klaptransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ async def close(self) -> None:
"""Close the transport."""
client = self._http_client
self._http_client = None
self._handshake_done = False
if client:
await client.aclose()

Expand Down
61 changes: 54 additions & 7 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@

from .aestransport import AesTransport
from .credentials import Credentials
from .exceptions import AuthenticationException, SmartDeviceException
from .exceptions import (
SMART_AUTHENTICATION_ERRORS,
SMART_RETRYABLE_ERRORS,
SMART_TIMEOUT_ERRORS,
AuthenticationException,
RetryableException,
SmartDeviceException,
SmartErrorCode,
TimeoutException,
)
from .json import dumps as json_dumps
from .protocol import BaseTransport, TPLinkProtocol, md5

Expand All @@ -28,6 +37,7 @@ class SmartProtocol(TPLinkProtocol):
"""Class for the new TPLink SMART protocol."""

DEFAULT_PORT = 80
SLEEP_SECONDS_AFTER_TIMEOUT = 1

def __init__(
self,
Expand Down Expand Up @@ -64,6 +74,22 @@ async def query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
"""Query the device retrying for retry_count on failure."""
async with self._query_lock:
resp_dict = await self._query(request, retry_count)

if (
error_code := SmartErrorCode(resp_dict.get("error_code")) # type: ignore[arg-type]
) != SmartErrorCode.SUCCESS:
msg = (
f"Error querying device: {self.host}: "
+ f"{error_code.name}({error_code.value})"
)
if error_code in SMART_TIMEOUT_ERRORS:
raise TimeoutException(msg)
if error_code in SMART_RETRYABLE_ERRORS:
raise RetryableException(msg)
if error_code in SMART_AUTHENTICATION_ERRORS:
raise AuthenticationException(msg)
raise SmartDeviceException(msg)

if "result" in resp_dict:
return resp_dict["result"]
return {}
Expand All @@ -86,20 +112,41 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
f"Unable to connect to the device: {self.host}: {cex}"
) from cex
except TimeoutError as tex:
await self.close()
raise SmartDeviceException(
f"Unable to connect to the device, timed out: {self.host}: {tex}"
) from tex
if retry >= retry_count:
await self.close()
raise SmartDeviceException(
"Unable to connect to the device, "
+ f"timed out: {self.host}: {tex}"
) from tex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except AuthenticationException as auex:
await self.close()
_LOGGER.debug("Unable to authenticate with %s, not retrying", self.host)
raise auex
except RetryableException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise ex
continue
except TimeoutException as ex:
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise ex
await asyncio.sleep(self.SLEEP_SECONDS_AFTER_TIMEOUT)
continue
except Exception as ex:
await self.close()
if retry >= retry_count:
await self.close()
_LOGGER.debug("Giving up on %s after %s retries", self.host, retry)
raise SmartDeviceException(
f"Unable to connect to the device: {self.host}: {ex}"
f"Unable to query the device {self.host}:{self.port}: {ex}"
) from ex
_LOGGER.debug(
"Unable to query the device %s, retrying: %s", self.host, ex
)
continue

# make mypy happy, this should never be reached..
Expand Down
8 changes: 1 addition & 7 deletions kasa/tapo/tapobulb.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,7 @@ async def set_hsv(
if value is not None:
request_payload["brightness"] = value

return await self.protocol.query(
{
"set_device_info": {
**request_payload
}
}
)
return await self.protocol.query({"set_device_info": {**request_payload}})

async def set_color_temp(
self, temp: int, *, brightness=None, transition: Optional[int] = None
Expand Down
4 changes: 2 additions & 2 deletions kasa/tests/newfakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ async def send(self, request: str):
method = request_dict["method"]
params = request_dict["params"]
if method == "component_nego" or method[:4] == "get_":
return {"result": self.info[method]}
return {"result": self.info[method], "error_code": 0}
elif method[:4] == "set_":
target_method = f"get_{method[4:]}"
self.info[target_method].update(params)
return {"result": ""}
return {"error_code": 0}

async def close(self) -> None:
pass
Expand Down
2 changes: 1 addition & 1 deletion kasa/tests/test_klapprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def test_protocol_retry_recoverable_error(
async def test_protocol_reconnect(mocker, retry_count, protocol_class, transport_class):
host = "127.0.0.1"
remaining = retry_count
mock_response = {"result": {"great": "success"}}
mock_response = {"result": {"great": "success"}, "error_code": 0}

def _fail_one_less_than_retry_count(*_, **__):
nonlocal remaining
Expand Down