diff --git a/tailnet/conn.go b/tailnet/conn.go index c6d8d3ef928a6..0b830a7913cfd 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -374,9 +374,13 @@ func (c *Conn) Status() *ipnstate.Status { // Ping sends a ping to the Wireguard engine. // The bool returned is true if the ping was performed P2P. func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, *ipnstate.PingResult, error) { + return c.pingWithType(ctx, ip, tailcfg.PingDisco) +} + +func (c *Conn) pingWithType(ctx context.Context, ip netip.Addr, pt tailcfg.PingType) (time.Duration, bool, *ipnstate.PingResult, error) { errCh := make(chan error, 1) prChan := make(chan *ipnstate.PingResult, 1) - go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { + go c.wireguardEngine.Ping(ip, pt, func(pr *ipnstate.PingResult) { if pr.Err != "" { errCh <- xerrors.New(pr.Err) return @@ -418,7 +422,13 @@ func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool { ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() - _, _, _, err := c.Ping(ctx, ip) + // For reachability, we use TSMP ping, which pings at the IP layer, and + // therefore requires that wireguard and the netstack are up. If we + // don't wait for wireguard to be up, we could miss a handshake, and it + // might take 5 seconds for the handshake to be retried. A 5s initial + // round trip can set us up for poor TCP performance, since the initial + // round-trip-time sets the initial retransmit timeout. + _, _, _, err := c.pingWithType(ctx, ip, tailcfg.PingTSMP) if err == nil { completed() } diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index f3bc96e242f9e..b904e98fe6173 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -88,7 +88,7 @@ func TestTailnet(t *testing.T) { } }) node := testutil.RequireRecvCtx(ctx, t, nodes) - // Ensure this connected over DERP! + // Ensure this connected over raw (not websocket) DERP! require.Len(t, node.DERPForcedWebsocket, 0) w1.Close() @@ -157,6 +157,94 @@ func TestTailnet(t *testing.T) { w1.Close() w2.Close() }) + + t.Run("PingDirect", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitLong) + w1IP := tailnet.IP() + w1, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, + Logger: logger.Named("w1"), + DERPMap: derpMap, + }) + require.NoError(t, err) + + w2, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + Logger: logger.Named("w2"), + DERPMap: derpMap, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = w1.Close() + _ = w2.Close() + }) + stitch(t, w2, w1) + stitch(t, w1, w2) + require.True(t, w2.AwaitReachable(context.Background(), w1IP)) + + require.Eventually(t, func() bool { + _, direct, pong, err := w2.Ping(ctx, w1IP) + if err != nil { + t.Logf("ping error: %s", err.Error()) + return false + } + if !direct { + t.Logf("got pong: %+v", pong) + return false + } + return true + }, testutil.WaitShort, testutil.IntervalFast) + + w1.Close() + w2.Close() + }) + + t.Run("PingDERPOnly", func(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + ctx := testutil.Context(t, testutil.WaitLong) + w1IP := tailnet.IP() + w1, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(w1IP, 128)}, + Logger: logger.Named("w1"), + DERPMap: derpMap, + BlockEndpoints: true, + }) + require.NoError(t, err) + + w2, err := tailnet.NewConn(&tailnet.Options{ + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + Logger: logger.Named("w2"), + DERPMap: derpMap, + BlockEndpoints: true, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = w1.Close() + _ = w2.Close() + }) + stitch(t, w2, w1) + stitch(t, w1, w2) + require.True(t, w2.AwaitReachable(context.Background(), w1IP)) + + require.Eventually(t, func() bool { + _, direct, pong, err := w2.Ping(ctx, w1IP) + if err != nil { + t.Logf("ping error: %s", err.Error()) + return false + } + if direct || pong.DERPRegionID != derpMap.RegionIDs()[0] { + t.Logf("got pong: %+v", pong) + return false + } + return true + }, testutil.WaitShort, testutil.IntervalFast) + + w1.Close() + w2.Close() + }) } // TestConn_PreferredDERP tests that we only trigger the NodeCallback when we have a preferred DERP server.