Skip to content

Commit 9eac0a1

Browse files
authored
feat: speed up the query handler (#1350)
1 parent 7ffbed8 commit 9eac0a1

File tree

10 files changed

+107
-73
lines changed

10 files changed

+107
-73
lines changed

src/zeroconf/_core.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@
3131
from ._dns import DNSQuestion, DNSQuestionType
3232
from ._engine import AsyncEngine
3333
from ._exceptions import NonUniqueNameException, NotRunningException
34-
from ._handlers.answers import (
35-
construct_outgoing_multicast_answers,
36-
construct_outgoing_unicast_answers,
37-
)
3834
from ._handlers.multicast_outgoing_queue import MulticastOutgoingQueue
3935
from ._handlers.query_handler import QueryHandler
4036
from ._handlers.record_manager import RecordManager
@@ -187,15 +183,15 @@ def __init__(
187183
self.registry = ServiceRegistry()
188184
self.cache = DNSCache()
189185
self.question_history = QuestionHistory()
190-
self.query_handler = QueryHandler(self.registry, self.cache, self.question_history)
186+
self.query_handler = QueryHandler(self)
191187
self.record_manager = RecordManager(self)
192188

193189
self._notify_futures: Set[asyncio.Future] = set()
194190
self.loop: Optional[asyncio.AbstractEventLoop] = None
195191
self._loop_thread: Optional[threading.Thread] = None
196192

197-
self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
198-
self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)
193+
self.out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY)
194+
self.out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY)
199195

200196
self.start()
201197

@@ -567,45 +563,6 @@ def handle_response(self, msg: DNSIncoming) -> None:
567563
self.log_warning_once("handle_response is deprecated, use record_manager.async_updates_from_response")
568564
self.record_manager.async_updates_from_response(msg)
569565

570-
def handle_assembled_query(
571-
self,
572-
packets: List[DNSIncoming],
573-
addr: str,
574-
port: int,
575-
transport: _WrappedTransport,
576-
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
577-
) -> None:
578-
"""Respond to a (re)assembled query.
579-
580-
If the protocol received packets with the TC bit set, it will
581-
wait a bit for the rest of the packets and only call
582-
handle_assembled_query once it has a complete set of packets
583-
or the timer expires. If the TC bit is not set, a single
584-
packet will be in packets.
585-
"""
586-
ucast_source = port != _MDNS_PORT
587-
question_answers = self.query_handler.async_response(packets, ucast_source)
588-
if not question_answers:
589-
return
590-
now = packets[0].now
591-
if question_answers.ucast:
592-
questions = packets[0].questions
593-
id_ = packets[0].id
594-
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
595-
# When sending unicast, only send back the reply
596-
# via the same socket that it was recieved from
597-
# as we know its reachable from that socket
598-
self.async_send(out, addr, port, v6_flow_scope, transport)
599-
if question_answers.mcast_now:
600-
self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
601-
if question_answers.mcast_aggregate:
602-
self._out_queue.async_add(now, question_answers.mcast_aggregate)
603-
if question_answers.mcast_aggregate_last_second:
604-
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
605-
# If we broadcast it in the last second, we have to delay
606-
# at least a second before we send it again
607-
self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)
608-
609566
def send(
610567
self,
611568
out: DNSOutgoing,

src/zeroconf/_handlers/query_handler.pxd

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ from .._history cimport QuestionHistory
77
from .._protocol.incoming cimport DNSIncoming
88
from .._services.info cimport ServiceInfo
99
from .._services.registry cimport ServiceRegistry
10-
from .answers cimport QuestionAnswers
10+
from .answers cimport (
11+
QuestionAnswers,
12+
construct_outgoing_multicast_answers,
13+
construct_outgoing_unicast_answers,
14+
)
15+
from .multicast_outgoing_queue cimport MulticastOutgoingQueue
1116

1217

1318
cdef bint TYPE_CHECKING
@@ -65,6 +70,7 @@ cdef class _QueryResponse:
6570

6671
cdef class QueryHandler:
6772

73+
cdef object zc
6874
cdef ServiceRegistry registry
6975
cdef DNSCache cache
7076
cdef QuestionHistory question_history
@@ -93,7 +99,22 @@ cdef class QueryHandler:
9399
is_probe=object,
94100
now=double
95101
)
96-
cpdef async_response(self, cython.list msgs, cython.bint unicast_source)
102+
cpdef QuestionAnswers async_response(self, cython.list msgs, cython.bint unicast_source)
97103

98104
@cython.locals(name=str, question_lower_name=str)
99105
cdef _get_answer_strategies(self, DNSQuestion question)
106+
107+
@cython.locals(
108+
first_packet=DNSIncoming,
109+
ucast_source=bint,
110+
out_queue=MulticastOutgoingQueue,
111+
out_delay_queue=MulticastOutgoingQueue
112+
)
113+
cpdef void handle_assembled_query(
114+
self,
115+
list packets,
116+
object addr,
117+
object port,
118+
object transport,
119+
tuple v6_flow_scope
120+
)

src/zeroconf/_handlers/query_handler.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@
2020
USA
2121
"""
2222

23-
from typing import TYPE_CHECKING, List, Optional, Set, cast
23+
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union, cast
2424

2525
from .._cache import DNSCache, _UniqueRecordsType
2626
from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
27-
from .._history import QuestionHistory
2827
from .._protocol.incoming import DNSIncoming
2928
from .._services.info import ServiceInfo
30-
from .._services.registry import ServiceRegistry
29+
from .._transport import _WrappedTransport
3130
from .._utils.net import IPVersion
3231
from ..const import (
3332
_ADDRESS_RECORD_TYPES,
3433
_CLASS_IN,
3534
_DNS_OTHER_TTL,
35+
_MDNS_PORT,
3636
_ONE_SECOND,
3737
_SERVICE_TYPE_ENUMERATION_NAME,
3838
_TYPE_A,
@@ -43,7 +43,12 @@
4343
_TYPE_SRV,
4444
_TYPE_TXT,
4545
)
46-
from .answers import QuestionAnswers, _AnswerWithAdditionalsType
46+
from .answers import (
47+
QuestionAnswers,
48+
_AnswerWithAdditionalsType,
49+
construct_outgoing_multicast_answers,
50+
construct_outgoing_unicast_answers,
51+
)
4752

4853
_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES}
4954

@@ -53,14 +58,17 @@
5358
_IPVersion_ALL = IPVersion.All
5459

5560
_int = int
56-
61+
_str = str
5762

5863
_ANSWER_STRATEGY_SERVICE_TYPE_ENUMERATION = 0
5964
_ANSWER_STRATEGY_POINTER = 1
6065
_ANSWER_STRATEGY_ADDRESS = 2
6166
_ANSWER_STRATEGY_SERVICE = 3
6267
_ANSWER_STRATEGY_TEXT = 4
6368

69+
if TYPE_CHECKING:
70+
from .._core import Zeroconf
71+
6472

6573
class _AnswerStrategy:
6674

@@ -183,13 +191,14 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
183191
class QueryHandler:
184192
"""Query the ServiceRegistry."""
185193

186-
__slots__ = ("registry", "cache", "question_history")
194+
__slots__ = ("zc", "registry", "cache", "question_history")
187195

188-
def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None:
196+
def __init__(self, zc: 'Zeroconf') -> None:
189197
"""Init the query handler."""
190-
self.registry = registry
191-
self.cache = cache
192-
self.question_history = question_history
198+
self.zc = zc
199+
self.registry = zc.registry
200+
self.cache = zc.cache
201+
self.question_history = zc.question_history
193202

194203
def _add_service_type_enumeration_query_answers(
195204
self, types: List[str], answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet
@@ -385,3 +394,45 @@ def _get_answer_strategies(
385394
)
386395

387396
return strategies
397+
398+
def handle_assembled_query(
399+
self,
400+
packets: List[DNSIncoming],
401+
addr: _str,
402+
port: _int,
403+
transport: _WrappedTransport,
404+
v6_flow_scope: Union[Tuple[()], Tuple[int, int]],
405+
) -> None:
406+
"""Respond to a (re)assembled query.
407+
408+
If the protocol recieved packets with the TC bit set, it will
409+
wait a bit for the rest of the packets and only call
410+
handle_assembled_query once it has a complete set of packets
411+
or the timer expires. If the TC bit is not set, a single
412+
packet will be in packets.
413+
"""
414+
first_packet = packets[0]
415+
now = first_packet.now
416+
ucast_source = port != _MDNS_PORT
417+
question_answers = self.async_response(packets, ucast_source)
418+
if not question_answers:
419+
return
420+
if question_answers.ucast:
421+
questions = first_packet.questions
422+
id_ = first_packet.id
423+
out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_)
424+
# When sending unicast, only send back the reply
425+
# via the same socket that it was recieved from
426+
# as we know its reachable from that socket
427+
self.zc.async_send(out, addr, port, v6_flow_scope, transport)
428+
if question_answers.mcast_now:
429+
self.zc.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now))
430+
if question_answers.mcast_aggregate:
431+
out_queue = self.zc.out_queue
432+
out_queue.async_add(now, question_answers.mcast_aggregate)
433+
if question_answers.mcast_aggregate_last_second:
434+
# https://datatracker.ietf.org/doc/html/rfc6762#section-14
435+
# If we broadcast it in the last second, we have to delay
436+
# at least a second before we send it again
437+
out_delay_queue = self.zc.out_delay_queue
438+
out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second)

src/zeroconf/_handlers/record_manager.pxd

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,21 @@ cdef class RecordManager:
2222
cdef public DNSCache cache
2323
cdef public cython.set listeners
2424

25-
cpdef async_updates(self, object now, object records)
25+
cpdef void async_updates(self, object now, object records)
2626

27-
cpdef async_updates_complete(self, object notify)
27+
cpdef void async_updates_complete(self, bint notify)
2828

2929
@cython.locals(
3030
cache=DNSCache,
3131
record=DNSRecord,
3232
answers=cython.list,
3333
maybe_entry=DNSRecord,
34-
now_double=double
3534
)
36-
cpdef async_updates_from_response(self, DNSIncoming msg)
35+
cpdef void async_updates_from_response(self, DNSIncoming msg)
3736

38-
cpdef async_add_listener(self, RecordUpdateListener listener, object question)
37+
cpdef void async_add_listener(self, RecordUpdateListener listener, object question)
3938

40-
cpdef async_remove_listener(self, RecordUpdateListener listener)
39+
cpdef void async_remove_listener(self, RecordUpdateListener listener)
4140

4241
@cython.locals(question=DNSQuestion, record=DNSRecord)
43-
cdef _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)
42+
cdef void _async_update_matching_records(self, RecordUpdateListener listener, cython.list questions)

src/zeroconf/_handlers/record_manager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
8484
other_adds: List[DNSRecord] = []
8585
removes: Set[DNSRecord] = set()
8686
now = msg.now
87-
now_double = now
8887
unique_types: Set[Tuple[str, int, int]] = set()
8988
cache = self.cache
9089
answers = msg.answers()
@@ -113,7 +112,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
113112
record = cast(_UniqueRecordsType, record)
114113

115114
maybe_entry = cache.async_get_unique(record)
116-
if not record.is_expired(now_double):
115+
if not record.is_expired(now):
117116
if maybe_entry is not None:
118117
maybe_entry.reset_ttl(record)
119118
else:
@@ -129,7 +128,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
129128
removes.add(record)
130129

131130
if unique_types:
132-
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now_double)
131+
cache.async_mark_unique_records_older_than_1s_to_expire(unique_types, answers, now)
133132

134133
if updates:
135134
self.async_updates(now, updates)

src/zeroconf/_listener.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11

22
import cython
33

4+
from ._handlers.query_handler cimport QueryHandler
45
from ._handlers.record_manager cimport RecordManager
56
from ._protocol.incoming cimport DNSIncoming
67
from ._services.registry cimport ServiceRegistry
@@ -21,6 +22,7 @@ cdef class AsyncListener:
2122
cdef public object zc
2223
cdef ServiceRegistry _registry
2324
cdef RecordManager _record_manager
25+
cdef QueryHandler _query_handler
2426
cdef public cython.bytes data
2527
cdef public double last_time
2628
cdef public DNSIncoming last_message

src/zeroconf/_listener.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class AsyncListener:
5959
'zc',
6060
'_registry',
6161
'_record_manager',
62+
"_query_handler",
6263
'data',
6364
'last_time',
6465
'last_message',
@@ -72,6 +73,7 @@ def __init__(self, zc: 'Zeroconf') -> None:
7273
self.zc = zc
7374
self._registry = zc.registry
7475
self._record_manager = zc.record_manager
76+
self._query_handler = zc.query_handler
7577
self.data: Optional[bytes] = None
7678
self.last_time: float = 0
7779
self.last_message: Optional[DNSIncoming] = None
@@ -228,7 +230,7 @@ def _respond_query(
228230
if msg:
229231
packets.append(msg)
230232

231-
self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
233+
self._query_handler.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
232234

233235
def error_received(self, exc: Exception) -> None:
234236
"""Likely socket closed or IPv6."""

src/zeroconf/_protocol/incoming.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ cdef class DNSIncoming:
5656
cdef cython.uint _num_authorities
5757
cdef cython.uint _num_additionals
5858
cdef public bint valid
59-
cdef public object now
59+
cdef public double now
6060
cdef public object scope_id
6161
cdef public object source
6262
cdef bint _has_qu_question

src/zeroconf/_transport.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import asyncio
2424
import socket
25-
from typing import Any
25+
from typing import Tuple
2626

2727

2828
class _WrappedTransport:
@@ -42,7 +42,7 @@ def __init__(
4242
is_ipv6: bool,
4343
sock: socket.socket,
4444
fileno: int,
45-
sock_name: Any,
45+
sock_name: Tuple,
4646
) -> None:
4747
"""Initialize the wrapped transport.
4848

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010

1111
from zeroconf import _core, const
12+
from zeroconf._handlers import query_handler
1213

1314

1415
@pytest.fixture(autouse=True)
@@ -23,7 +24,9 @@ def verify_threads_ended():
2324
@pytest.fixture
2425
def run_isolated():
2526
"""Change the mDNS port to run the test in isolation."""
26-
with patch.object(_core, "_MDNS_PORT", 5454), patch.object(const, "_MDNS_PORT", 5454):
27+
with patch.object(query_handler, "_MDNS_PORT", 5454), patch.object(
28+
_core, "_MDNS_PORT", 5454
29+
), patch.object(const, "_MDNS_PORT", 5454):
2730
yield
2831

2932

0 commit comments

Comments
 (0)