Skip to content

Commit 80b138f

Browse files
committed
wgengine/magicsock: keep advertising endpoints after we stop discovering them
Previously, when updating endpoints we would immediately stop advertising any endpoint that wasn't discovered during determineEndpoints. This could result in, for example, a case where we performed an incremental netcheck, didn't get any of our three STUN packets back, and then dropped our STUN endpoint from the set of advertised endpoints... which would result in clients falling back to a DERP connection until the next call to determineEndpoints. Instead, let's cache endpoints that we've discovered and continue reporting them to clients until a timeout expires. In the above case where we temporarily don't have a discovered STUN endpoint, we would continue reporting the old value, then re-discover the STUN endpoint again and continue reporting it as normal, so clients never see a withdrawal. Updates tailscale/coral#108 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I42de72e7418ab328a6c732bdefc74549708cf8b9
1 parent 4b49ca4 commit 80b138f

File tree

2 files changed

+212
-0
lines changed

2 files changed

+212
-0
lines changed

wgengine/magicsock/magicsock.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ import (
6464
"tailscale.com/util/clientmetric"
6565
"tailscale.com/util/mak"
6666
"tailscale.com/util/ringbuffer"
67+
"tailscale.com/util/set"
6768
"tailscale.com/util/sysresources"
6869
"tailscale.com/util/uniq"
6970
"tailscale.com/version"
@@ -419,6 +420,10 @@ type Conn struct {
419420
// when endpoints are refreshed.
420421
onEndpointRefreshed map[*endpoint]func()
421422

423+
// endpointTracker tracks the set of cached endpoints that we advertise
424+
// for a period of time before withdrawing them.
425+
endpointTracker endpointTracker
426+
422427
// peerSet is the set of peers that are currently configured in
423428
// WireGuard. These are not used to filter inbound or outbound
424429
// traffic at all, but only to track what state can be cleaned up
@@ -1196,6 +1201,22 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
11961201

11971202
c.ignoreSTUNPackets()
11981203

1204+
// Update our set of endpoints by adding any endpoints that we
1205+
// previously found but haven't expired yet. This also updates the
1206+
// cache with the set of endpoints discovered in this function.
1207+
//
1208+
// NOTE: we do this here and not below so that we don't cache local
1209+
// endpoints; we know that the local endpoints we discover are all
1210+
// possible local endpoints since we determine them by looking at the
1211+
// set of addresses on our local interfaces.
1212+
//
1213+
// TODO(andrew): If we pull in any cached endpoints, we should probably
1214+
// do something to ensure we're propagating the removal of those cached
1215+
// endpoints if they do actually time out without being rediscovered.
1216+
// For now, though, rely on a minor LinkChange event causing this to
1217+
// re-run.
1218+
eps = c.endpointTracker.update(time.Now(), eps)
1219+
11991220
if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() {
12001221
ips, loopback, err := interfaces.LocalAddresses()
12011222
if err != nil {
@@ -4148,6 +4169,11 @@ const (
41484169
// STUN-derived endpoint valid for. UDP NAT mappings typically
41494170
// expire at 30 seconds, so this is a few seconds shy of that.
41504171
endpointsFreshEnoughDuration = 27 * time.Second
4172+
4173+
// endpointTrackerLifetime is how long we continue advertising an
4174+
// endpoint after we last see it. This is intentionally chosen to be
4175+
// slightly longer than a full netcheck period.
4176+
endpointTrackerLifetime = 5*time.Minute + 10*time.Second
41514177
)
41524178

41534179
// Constants that are variable for testing.
@@ -5105,6 +5131,79 @@ func (s derpAddrFamSelector) PreferIPv6() bool {
51055131
return false
51065132
}
51075133

5134+
// endpointTrackerEntry is a single cached endpoint held by endpointTracker,
// together with the time after which it is considered expired.
type endpointTrackerEntry struct {
5135+
// endpoint is the endpoint value stored when the entry was first added;
// later sightings of the same address only move until forward.
endpoint tailcfg.Endpoint
5136+
// until is the expiry time: the entry is treated as expired once the
// current time is strictly after until.
until time.Time
5137+
}
5138+
5139+
// endpointTracker remembers recently seen endpoints, keyed by address, so
// that update can keep returning them for a while after they stop appearing
// in its input. The zero value is ready to use.
type endpointTracker struct {
5140+
// mu guards cache.
mu sync.Mutex
5141+
// cache maps an endpoint's address to its cached entry and expiry time.
cache map[netip.AddrPort]endpointTrackerEntry
5142+
}
5143+
5144+
func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
5145+
epsPlusCached = eps
5146+
5147+
var inputEps set.Slice[netip.AddrPort]
5148+
for _, ep := range eps {
5149+
inputEps.Add(ep.Addr)
5150+
}
5151+
5152+
et.mu.Lock()
5153+
defer et.mu.Unlock()
5154+
5155+
// Add entries to the return array that aren't already there.
5156+
for k, ep := range et.cache {
5157+
// If the endpoint was in the input list, or has expired, skip it.
5158+
if inputEps.Contains(k) {
5159+
continue
5160+
} else if now.After(ep.until) {
5161+
continue
5162+
}
5163+
5164+
// We haven't seen this endpoint; add to the return array
5165+
epsPlusCached = append(epsPlusCached, ep.endpoint)
5166+
}
5167+
5168+
// Add entries from the original input array into the cache, and/or
5169+
// extend the lifetime of entries that are already in the cache.
5170+
until := now.Add(endpointTrackerLifetime)
5171+
for _, ep := range eps {
5172+
et.addLocked(now, ep, until)
5173+
}
5174+
5175+
// Remove everything that has now expired.
5176+
et.removeExpiredLocked(now)
5177+
return epsPlusCached
5178+
}
5179+
5180+
// add will store the provided endpoint(s) in the cache for a fixed period of
5181+
// time, and remove any entries in the cache that have expired.
5182+
//
5183+
// et.mu must be held.
5184+
func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
5185+
// If we already have an entry for this endpoint, update the timeout on
5186+
// it; otherwise, add it.
5187+
entry, found := et.cache[ep.Addr]
5188+
if found {
5189+
entry.until = until
5190+
} else {
5191+
entry = endpointTrackerEntry{ep, until}
5192+
}
5193+
mak.Set(&et.cache, ep.Addr, entry)
5194+
}
5195+
5196+
// removeExpired will remove all expired entries from the cache
5197+
//
5198+
// et.mu must be held
5199+
func (et *endpointTracker) removeExpiredLocked(now time.Time) {
5200+
for k, ep := range et.cache {
5201+
if now.After(ep.until) {
5202+
delete(et.cache, k)
5203+
}
5204+
}
5205+
}
5206+
51085207
var (
51095208
metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers")
51105209
metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")

wgengine/magicsock/magicsock_test.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"net/http/httptest"
1919
"net/netip"
2020
"os"
21+
"reflect"
2122
"runtime"
2223
"strconv"
2324
"strings"
@@ -31,6 +32,7 @@ import (
3132
"github.com/tailscale/wireguard-go/tun/tuntest"
3233
"go4.org/mem"
3334
"golang.org/x/exp/maps"
35+
"golang.org/x/exp/slices"
3436
"golang.org/x/net/ipv6"
3537
"tailscale.com/cmd/testwrapper/flakytest"
3638
"tailscale.com/derp"
@@ -390,6 +392,7 @@ collectEndpoints:
390392
for {
391393
select {
392394
case ep := <-epCh:
395+
t.Logf("TestNewConn: got endpoint: %v", ep)
393396
endpoints = append(endpoints, ep)
394397
if strings.HasSuffix(ep, suffix) {
395398
break collectEndpoints
@@ -2280,3 +2283,113 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
22802283
t.Fatal("no packet after 1s")
22812284
}
22822285
}
2286+
2287+
func TestEndpointTracker(t *testing.T) {
2288+
local := tailcfg.Endpoint{
2289+
Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
2290+
Type: tailcfg.EndpointLocal,
2291+
}
2292+
2293+
stun4_1 := tailcfg.Endpoint{
2294+
Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
2295+
Type: tailcfg.EndpointSTUN,
2296+
}
2297+
stun4_2 := tailcfg.Endpoint{
2298+
Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
2299+
Type: tailcfg.EndpointSTUN,
2300+
}
2301+
2302+
stun6_1 := tailcfg.Endpoint{
2303+
Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
2304+
Type: tailcfg.EndpointSTUN,
2305+
}
2306+
stun6_2 := tailcfg.Endpoint{
2307+
Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
2308+
Type: tailcfg.EndpointSTUN,
2309+
}
2310+
2311+
start := time.Unix(1681503440, 0)
2312+
2313+
steps := []struct {
2314+
name string
2315+
now time.Time
2316+
eps []tailcfg.Endpoint
2317+
want []tailcfg.Endpoint
2318+
}{
2319+
{
2320+
name: "initial endpoints",
2321+
now: start,
2322+
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2323+
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2324+
},
2325+
{
2326+
name: "no change",
2327+
now: start.Add(1 * time.Minute),
2328+
eps: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2329+
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2330+
},
2331+
{
2332+
name: "missing stun4",
2333+
now: start.Add(2 * time.Minute),
2334+
eps: []tailcfg.Endpoint{local, stun6_1},
2335+
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2336+
},
2337+
{
2338+
name: "missing stun6",
2339+
now: start.Add(3 * time.Minute),
2340+
eps: []tailcfg.Endpoint{local, stun4_1},
2341+
want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
2342+
},
2343+
{
2344+
name: "multiple STUN addresses within timeout",
2345+
now: start.Add(4 * time.Minute),
2346+
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
2347+
want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
2348+
},
2349+
{
2350+
name: "endpoint extended",
2351+
now: start.Add(3*time.Minute + endpointTrackerLifetime - 1),
2352+
eps: []tailcfg.Endpoint{local},
2353+
want: []tailcfg.Endpoint{
2354+
local, stun4_2, stun6_2,
2355+
// stun4_1 had its lifetime extended by the
2356+
// "missing stun6" test above to that start
2357+
// time plus the lifetime, while stun6 should
2358+
// have expired a minute sooner. It should thus
2359+
// be in this returned list.
2360+
stun4_1,
2361+
},
2362+
},
2363+
{
2364+
name: "after timeout",
2365+
now: start.Add(4*time.Minute + endpointTrackerLifetime + 1),
2366+
eps: []tailcfg.Endpoint{local, stun4_2, stun6_2},
2367+
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
2368+
},
2369+
{
2370+
name: "after timeout still caches",
2371+
now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
2372+
eps: []tailcfg.Endpoint{local},
2373+
want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
2374+
},
2375+
}
2376+
2377+
var et endpointTracker
2378+
for _, tt := range steps {
2379+
t.Logf("STEP: %s", tt.name)
2380+
2381+
got := et.update(tt.now, tt.eps)
2382+
2383+
// Sort both arrays for comparison
2384+
slices.SortFunc(got, func(a, b tailcfg.Endpoint) bool {
2385+
return a.Addr.String() < b.Addr.String()
2386+
})
2387+
slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) bool {
2388+
return a.Addr.String() < b.Addr.String()
2389+
})
2390+
2391+
if !reflect.DeepEqual(got, tt.want) {
2392+
t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
2393+
}
2394+
}
2395+
}

0 commit comments

Comments
 (0)