Skip to content

Commit 7032d0e

Browse files
committed
Introduce InterfaceChoice.AllWithLoopback
1 parent dfc9b8d commit 7032d0e

File tree

10 files changed

+61
-18
lines changed

10 files changed

+61
-18
lines changed

src/zeroconf/_core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,12 @@ def __init__(
159159
* `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
160160
on Python versions before 3.8.
161161
162+
* `InterfaceChoice.AllWithLoopback` is the same as `InterfaceChoice.All`
163+
on POSIX systems, but includes the loopback interfaces. This likely
164+
only works on macOS/BSD.
165+
162166
Also listening on loopback (``::1``) doesn't work, use a real address.
167+
163168
:param ip_version: IP versions to support. If `choice` is a list, the default is detected
164169
from it. Otherwise defaults to V4 only for backward compatibility.
165170
:param apple_p2p: use AWDL interface (only macOS)

src/zeroconf/_handlers/query_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171

7272

7373
class _AnswerStrategy:
74-
7574
__slots__ = ("question", "strategy_type", "types", "services")
7675

7776
def __init__(

src/zeroconf/_services/browser.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@
105105

106106

107107
class _ScheduledPTRQuery:
108-
109108
__slots__ = ('alias', 'name', 'ttl', 'cancelled', 'expire_time_millis', 'when_millis')
110109

111110
def __init__(

src/zeroconf/_utils/ipaddress.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434

3535
class ZeroconfIPv4Address(IPv4Address):
36-
3736
__slots__ = ("_str", "_is_link_local", "_is_unspecified")
3837

3938
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -59,7 +58,6 @@ def is_unspecified(self) -> bool:
5958

6059

6160
class ZeroconfIPv6Address(IPv6Address):
62-
6361
__slots__ = ("_str", "_is_link_local", "_is_unspecified")
6462

6563
def __init__(self, *args: Any, **kwargs: Any) -> None:

src/zeroconf/_utils/net.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import enum
2424
import errno
25-
import ipaddress
2625
import socket
2726
import struct
2827
import sys
@@ -32,12 +31,14 @@
3231

3332
from .._logger import log
3433
from ..const import _IPPROTO_IPV6, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT
34+
from .ipaddress import _cached_ip_addresses
3535

3636

3737
@enum.unique
3838
class InterfaceChoice(enum.Enum):
3939
Default = 1
4040
All = 2
41+
AllWithLoopback = 3
4142

4243

4344
InterfacesType = Union[Sequence[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice]
@@ -85,11 +86,11 @@ def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]:
8586
def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]:
8687
if '%' in ip:
8788
ip = ip[: ip.index('%')] # Strip scope_id.
88-
ipaddr = ipaddress.ip_address(ip)
89+
ipaddr = _cached_ip_addresses(ip)
8990
for adapter in adapters:
9091
for adapter_ip in adapter.ips:
9192
# IPv6 addresses are represented as tuples
92-
if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr:
93+
if isinstance(adapter_ip.ip, tuple) and _cached_ip_addresses(adapter_ip.ip[0]) == ipaddr:
9394
return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index))
9495

9596
raise RuntimeError('No adapter found for IP address %s' % ip)
@@ -122,7 +123,9 @@ def ip6_addresses_to_indexes(
122123
for iface in interfaces:
123124
if isinstance(iface, int):
124125
result.append((interface_index_to_ip6_address(adapters, iface), iface))
125-
elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6:
126+
elif (
127+
isinstance(iface, str) and (ip_address := _cached_ip_addresses(iface)) and ip_address.version == 6
128+
):
126129
result.append(ip6_to_address_and_index(adapters, iface))
127130

128131
return result
@@ -145,6 +148,23 @@ def normalize_interface_choice(
145148
if ip_version != IPVersion.V6Only:
146149
result.append('0.0.0.0')
147150
elif choice is InterfaceChoice.All:
151+
if ip_version != IPVersion.V4Only:
152+
result.extend(
153+
ip
154+
for ip in get_all_addresses_v6()
155+
if (ip_address := _cached_ip_addresses(ip[0])) and not ip_address.is_loopback
156+
)
157+
if ip_version != IPVersion.V6Only:
158+
result.extend(
159+
ip
160+
for ip in get_all_addresses()
161+
if (ip_address := _cached_ip_addresses(ip[0])) and not ip_address.is_loopback
162+
)
163+
if not result:
164+
raise RuntimeError(
165+
'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version
166+
)
167+
elif choice is InterfaceChoice.AllWithLoopback:
148168
if ip_version != IPVersion.V4Only:
149169
result.extend(get_all_addresses_v6())
150170
if ip_version != IPVersion.V6Only:
@@ -155,7 +175,11 @@ def normalize_interface_choice(
155175
)
156176
elif isinstance(choice, list):
157177
# First, take IPv4 addresses.
158-
result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4]
178+
result = [
179+
i
180+
for i in choice
181+
if isinstance(i, str) and (ip_address := _cached_ip_addresses(i)) and ip_address.version == 4
182+
]
159183
# Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes.
160184
result += ip6_addresses_to_indexes(choice)
161185
else:
@@ -406,10 +430,14 @@ def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion:
406430
"""Auto detect the IP version when it is not provided."""
407431
if isinstance(interfaces, list):
408432
has_v6 = any(
409-
isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6)
433+
isinstance(i, int)
434+
or (isinstance(i, str) and (ip_address := _cached_ip_addresses(i)) and ip_address.version == 6)
435+
for i in interfaces
436+
)
437+
has_v4 = any(
438+
isinstance(i, str) and (ip_address := _cached_ip_addresses(i)) and ip_address.version == 4
410439
for i in interfaces
411440
)
412-
has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces)
413441
if has_v4 and has_v6:
414442
return IPVersion.All
415443
if has_v6:

src/zeroconf/asyncio.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ def __init__(
162162
* `InterfaceChoice.All` is an alias for `InterfaceChoice.Default`
163163
on Python versions before 3.8.
164164
165+
* `InterfaceChoice.AllWithLoopback` is the same as `InterfaceChoice.All`
166+
on POSIX systems, but includes the loopback interfaces. This likely
167+
only works on macOS/BSD.
168+
165169
Also listening on loopback (``::1``) doesn't work, use a real address.
166170
:param ip_version: IP versions to support. If `choice` is a list, the default is detected
167171
from it. Otherwise defaults to V4 only for backward compatibility.

tests/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ def time_changed_millis(millis: Optional[float] = None) -> None:
9999
mock_seconds_into_future = loop_time
100100

101101
with mock.patch("time.monotonic", return_value=mock_seconds_into_future):
102-
103102
for task in list(loop._scheduled): # type: ignore[attr-defined]
104103
if not isinstance(task, asyncio.TimerHandle):
105104
continue

tests/services/test_browser.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
11611161

11621162
# patch the zeroconf send so we can capture what is being sent
11631163
with patch.object(zc, "async_send", send):
1164-
11651164
query_scheduler.start(loop)
11661165

11671166
original_now = loop.time()
@@ -1251,7 +1250,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
12511250

12521251
# patch the zeroconf send so we can capture what is being sent
12531252
with patch.object(zc, "async_send", send):
1254-
12551253
query_scheduler.start(loop)
12561254

12571255
original_now = loop.time()

tests/test_core.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,29 @@ async def make_query():
6363

6464
class Framework(unittest.TestCase):
6565
def test_launch_and_close(self):
66+
rv = r.Zeroconf(interfaces=r.InterfaceChoice.AllWithLoopback)
67+
rv.close()
6668
rv = r.Zeroconf(interfaces=r.InterfaceChoice.All)
6769
rv.close()
6870
rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default)
6971
rv.close()
7072

7173
def test_launch_and_close_context_manager(self):
72-
with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv:
74+
with r.Zeroconf(interfaces=r.InterfaceChoice.AllWithLoopback) as rv:
75+
assert rv.done is False
76+
assert rv.done is True
77+
78+
with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv: # type: ignore[unreachable]
7379
assert rv.done is False
7480
assert rv.done is True
7581

76-
with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv: # type: ignore[unreachable]
82+
with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv:
7783
assert rv.done is False
7884
assert rv.done is True
7985

8086
def test_launch_and_close_unicast(self):
87+
rv = r.Zeroconf(interfaces=r.InterfaceChoice.AllWithLoopback, unicast=True)
88+
rv.close()
8189
rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True)
8290
rv.close()
8391
rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True)
@@ -91,6 +99,8 @@ def test_close_multiple_times(self):
9199
@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
92100
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
93101
def test_launch_and_close_v4_v6(self):
102+
rv = r.Zeroconf(interfaces=r.InterfaceChoice.AllWithLoopback, ip_version=r.IPVersion.All)
103+
rv.close()
94104
rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All)
95105
rv.close()
96106
rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All)
@@ -99,6 +109,8 @@ def test_launch_and_close_v4_v6(self):
99109
@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6')
100110
@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled')
101111
def test_launch_and_close_v6_only(self):
112+
rv = r.Zeroconf(interfaces=r.InterfaceChoice.AllWithLoopback, ip_version=r.IPVersion.V6Only)
113+
rv.close()
102114
rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only)
103115
rv.close()
104116
rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only)

tests/utils/test_net.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,13 @@ def test_ip6_addresses_to_indexes():
6868
assert netutils.ip6_addresses_to_indexes(interfaces_2) == [(('2001:db8::', 1, 1), 1)]
6969

7070

71-
def test_normalize_interface_choice_errors():
71+
@pytest.mark.parametrize("interface_choice", (r.InterfaceChoice.All, r.InterfaceChoice.AllWithLoopback))
72+
def test_normalize_interface_choice_errors(interface_choice: r.InterfaceChoice) -> None:
7273
"""Test we generate exception on invalid input."""
7374
with patch("zeroconf._utils.net.get_all_addresses", return_value=[]), patch(
7475
"zeroconf._utils.net.get_all_addresses_v6", return_value=[]
7576
), pytest.raises(RuntimeError):
76-
netutils.normalize_interface_choice(r.InterfaceChoice.All)
77+
netutils.normalize_interface_choice(interface_choice)
7778

7879
with pytest.raises(TypeError):
7980
netutils.normalize_interface_choice("1.2.3.4")

0 commit comments

Comments
 (0)