Skip to content

Update transport close/reset behaviour #689

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions kasa/aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ async def send(self, request: str):
return await self.send_secure_passthrough(request)

async def close(self) -> None:
"""Mark the handshake and login as not done.
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()

Since we likely lost the connection.
"""
async def reset(self) -> None:
"""Reset internal handshake and login state."""
self._handshake_done = False
self._login_token = None

Expand Down
8 changes: 6 additions & 2 deletions kasa/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import aiohttp

from .deviceconfig import DeviceConfig
from .exceptions import ConnectionException, SmartDeviceException, TimeoutException
from .exceptions import (
ConnectionException,
SmartDeviceException,
TimeoutException,
)
from .json import loads as json_loads


Expand Down Expand Up @@ -78,7 +82,7 @@ async def post(

except (aiohttp.ServerDisconnectedError, aiohttp.ClientOSError) as ex:
raise ConnectionException(
f"Unable to connect to the device: {self._config.host}: {ex}", ex
f"Device connection error: {self._config.host}: {ex}", ex
) from ex
except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as ex:
raise TimeoutException(
Expand Down
16 changes: 5 additions & 11 deletions kasa/iotprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,31 @@ async def _query(self, request: str, retry_count: int = 3) -> Dict:
try:
return await self._execute_query(request, retry)
except ConnectionException as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
Expand All @@ -85,10 +84,5 @@ async def _execute_query(self, request: str, retry_count: int) -> Dict:
return await self._transport.send(request)

async def close(self) -> None:
"""Close the underlying transport.

Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()
7 changes: 6 additions & 1 deletion kasa/klaptransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,12 @@ async def send(self, request: str):
return json_payload

async def close(self) -> None:
"""Mark the handshake as not done since we likely lost the connection."""
"""Close the http client and reset internal state."""
await self.reset()
await self._http_client.close()

async def reset(self) -> None:
"""Reset internal handshake state."""
self._handshake_done = False

@staticmethod
Expand Down
26 changes: 16 additions & 10 deletions kasa/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ async def send(self, request: str) -> Dict:
async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""

@abstractmethod
async def reset(self) -> None:
"""Reset internal state."""


class BaseProtocol(ABC):
"""Base class for all TP-Link Smart Home communication."""
Expand Down Expand Up @@ -139,7 +143,10 @@ async def send(self, request: str) -> Dict:
return {}

async def close(self) -> None:
"""Close the transport. Abstract method to be overriden."""
"""Close the transport."""

async def reset(self) -> None:
"""Reset internal state.."""


class TPLinkSmartHomeProtocol(BaseProtocol):
Expand Down Expand Up @@ -233,9 +240,9 @@ def close_without_wait(self) -> None:
if writer:
writer.close()

def _reset(self) -> None:
"""Clear any varibles that should not survive between loops."""
self.reader = self.writer = None
async def reset(self) -> None:
"""Reset the transport."""
await self.close()

async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
"""Try to query a device."""
Expand All @@ -252,20 +259,20 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
try:
await self._connect(timeout)
except ConnectionRefusedError as ex:
await self.close()
await self.reset()
raise SmartDeviceException(
f"Unable to connect to the device: {self._host}:{self._port}: {ex}"
) from ex
except OSError as ex:
await self.close()
await self.reset()
if ex.errno in _NO_RETRY_ERRORS or retry >= retry_count:
raise SmartDeviceException(
f"Unable to connect to the device:"
f" {self._host}:{self._port}: {ex}"
) from ex
continue
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
Expand All @@ -290,7 +297,7 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
async with asyncio_timeout(timeout):
return await self._execute_query(request)
except Exception as ex:
await self.close()
await self.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise SmartDeviceException(
Expand All @@ -312,7 +319,7 @@ async def _query(self, request: str, retry_count: int, timeout: int) -> Dict:
raise

# make mypy happy, this should never be reached..
await self.close()
await self.reset()
raise SmartDeviceException("Query reached somehow to unreachable")

def __del__(self) -> None:
Expand All @@ -322,7 +329,6 @@ def __del__(self) -> None:
# or in another thread so we need to make sure the call to
# close is called safely with call_soon_threadsafe
self.loop.call_soon_threadsafe(self.writer.close)
self._reset()

@staticmethod
def _xor_payload(unencrypted: bytes) -> Generator[int, None, None]:
Expand Down
4 changes: 4 additions & 0 deletions kasa/smartdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,10 @@ def config(self) -> DeviceConfig:
"""Return the device configuration."""
return self.protocol.config

async def disconnect(self):
"""Disconnect and close any underlying connection resources."""
await self.protocol.close()

@staticmethod
async def connect(
*,
Expand Down
16 changes: 5 additions & 11 deletions kasa/smartprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,31 @@ async def _query(self, request: Union[str, Dict], retry_count: int = 3) -> Dict:
try:
return await self._execute_query(request, retry)
except ConnectionException as sdex:
await self.close()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise sdex
continue
except AuthenticationException as auex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to authenticate with %s, not retrying", self._host
)
raise auex
except RetryableException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
continue
except TimeoutException as ex:
await self.close()
await self._transport.reset()
if retry >= retry_count:
_LOGGER.debug("Giving up on %s after %s retries", self._host, retry)
raise ex
await asyncio.sleep(self.BACKOFF_SECONDS_AFTER_TIMEOUT)
continue
except SmartDeviceException as ex:
await self.close()
await self._transport.reset()
_LOGGER.debug(
"Unable to query the device: %s, not retrying: %s",
self._host,
Expand Down Expand Up @@ -167,12 +166,7 @@ def _handle_response_error_code(self, resp_dict: dict):
raise SmartDeviceException(msg, error_code=error_code)

async def close(self) -> None:
"""Close the underlying transport.

Some transports may close the connection, and some may
use this as a hint that they need to reconnect, or
reauthenticate.
"""
"""Close the underlying transport."""
await self._transport.close()


Expand Down
11 changes: 9 additions & 2 deletions kasa/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Credentials,
Discover,
SmartBulb,
SmartDevice,
SmartDimmer,
SmartLightStrip,
SmartPlug,
Expand Down Expand Up @@ -416,9 +417,15 @@ async def dev(request):
IP_MODEL_CACHE[ip] = model = d.model
if model not in file:
pytest.skip(f"skipping file {file}")
return d if d else await _discover_update_and_close(ip, username, password)
dev: SmartDevice = (
d if d else await _discover_update_and_close(ip, username, password)
)
else:
dev: SmartDevice = await get_device_for_file(file, protocol)

yield dev

return await get_device_for_file(file, protocol)
await dev.disconnect()


@pytest.fixture
Expand Down
3 changes: 3 additions & 0 deletions kasa/tests/newfakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def _send_request(self, request_dict: dict):
async def close(self) -> None:
pass

async def reset(self) -> None:
pass


class FakeTransportProtocol(TPLinkSmartHomeProtocol):
def __init__(self, info):
Expand Down
1 change: 1 addition & 0 deletions kasa/tests/test_aestransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def test_login_errors(mocker, inner_error_codes, expectation, call_count):
await transport.send(json_dumps(request))
assert transport._login_token == mock_aes_device.token
assert post_mock.call_count == call_count # Login, Handshake, Login
await transport.close()


@status_parameters
Expand Down
2 changes: 2 additions & 0 deletions kasa/tests/test_device_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ async def test_connect(

assert dev.config == config

await dev.disconnect()


@pytest.mark.parametrize("custom_port", [123, None])
async def test_connect_custom_port(all_fixture_data: dict, mocker, custom_port):
Expand Down
4 changes: 2 additions & 2 deletions kasa/tests/test_httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
(
aiohttp.ServerDisconnectedError(),
ConnectionException,
"Unable to connect to the device: ",
"Device connection error: ",
),
(
aiohttp.ClientOSError(),
ConnectionException,
"Unable to connect to the device: ",
"Device connection error: ",
),
(
aiohttp.ServerTimeoutError(),
Expand Down
3 changes: 2 additions & 1 deletion kasa/tests/test_klapprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ async def read(self):
[
(Exception("dummy exception"), False),
(aiohttp.ServerTimeoutError("dummy exception"), True),
(aiohttp.ServerDisconnectedError("dummy exception"), True),
(aiohttp.ClientOSError("dummy exception"), True),
],
ids=("Exception", "SmartDeviceException", "ConnectError"),
ids=("Exception", "ServerTimeoutError", "ServerDisconnectedError", "ClientOSError"),
)
@pytest.mark.parametrize("transport_class", [AesTransport, KlapTransport])
@pytest.mark.parametrize("protocol_class", [IotProtocol, SmartProtocol])
Expand Down