|
20 | 20 | USA
|
21 | 21 | """
|
22 | 22 |
|
| 23 | +import asyncio |
23 | 24 | import ipaddress
|
24 | 25 | import random
|
25 | 26 | from functools import lru_cache
|
|
37 | 38 | from .._logger import log
|
38 | 39 | from .._protocol.outgoing import DNSOutgoing
|
39 | 40 | from .._updates import RecordUpdate, RecordUpdateListener
|
40 |
| -from .._utils.asyncio import get_running_loop, run_coro_with_timeout |
| 41 | +from .._utils.asyncio import ( |
| 42 | + get_running_loop, |
| 43 | + run_coro_with_timeout, |
| 44 | + wait_event_or_timeout, |
| 45 | +) |
41 | 46 | from .._utils.name import service_type_name
|
42 | 47 | from .._utils.net import IPVersion, _encode_address
|
43 |
| -from .._utils.time import current_time_millis |
| 48 | +from .._utils.time import current_time_millis, millis_to_seconds |
44 | 49 | from ..const import (
|
45 | 50 | _CLASS_IN,
|
46 | 51 | _CLASS_UNIQUE,
|
@@ -166,6 +171,7 @@ def __init__(
|
166 | 171 | self.host_ttl = host_ttl
|
167 | 172 | self.other_ttl = other_ttl
|
168 | 173 | self.interface_index = interface_index
|
| 174 | + self._notify_event: Optional[asyncio.Event] = None |
169 | 175 |
|
170 | 176 | @property
|
171 | 177 | def name(self) -> str:
|
@@ -221,6 +227,12 @@ def properties(self) -> Dict:
|
221 | 227 | """
|
222 | 228 | return self._properties
|
223 | 229 |
|
| 230 | + async def async_wait(self, timeout: float) -> None: |
| 231 | + """Calling task waits for a given number of milliseconds or until notified.""" |
| 232 | + if self._notify_event is None: |
| 233 | + self._notify_event = asyncio.Event() |
| 234 | + await wait_event_or_timeout(self._notify_event, timeout=millis_to_seconds(timeout)) |
| 235 | + |
224 | 236 | def addresses_by_version(self, version: IPVersion) -> List[bytes]:
|
225 | 237 | """List addresses matching IP version.
|
226 | 238 |
|
@@ -384,7 +396,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordU
|
384 | 396 |
|
385 | 397 | This method will be run in the event loop.
|
386 | 398 | """
|
387 |
| - self._process_records_threadsafe(zc, now, records) |
| 399 | + if self._process_records_threadsafe(zc, now, records) and self._notify_event: |
| 400 | + self._notify_event.set() |
| 401 | + self._notify_event.clear() |
388 | 402 |
|
389 | 403 | def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> bool:
|
390 | 404 | """Thread safe record updating.
|
@@ -605,7 +619,7 @@ async def async_request(
|
605 | 619 | delay *= 2
|
606 | 620 | next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL)
|
607 | 621 |
|
608 |
| - await zc.async_wait(min(next_, last) - now) |
| 622 | + await self.async_wait(min(next_, last) - now) |
609 | 623 | now = current_time_millis()
|
610 | 624 | finally:
|
611 | 625 | zc.async_remove_listener(self)
|
|
0 commit comments