diff --git a/agent/agent_test.go b/agent/agent_test.go index c53438404a2fb..133167e45e952 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -498,12 +498,7 @@ func TestAgent(t *testing.T) { }() conn, _ := setupAgent(t, codersdk.WorkspaceAgentMetadata{}, 0) - require.Eventually(t, func() bool { - ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalFast) - defer cancelFunc() - _, err := conn.Ping(ctx) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) + require.True(t, conn.AwaitReachable(context.Background())) conn1, err := conn.DialContext(context.Background(), l.Addr().Network(), l.Addr().String()) require.NoError(t, err) defer conn1.Close() diff --git a/cli/agent_test.go b/cli/agent_test.go index a2c79ceae2753..edcab3ac4a0e5 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -14,7 +14,6 @@ import ( "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/coder/coder/testutil" ) func TestWorkspaceAgent(t *testing.T) { @@ -71,12 +70,7 @@ func TestWorkspaceAgent(t *testing.T) { dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() - require.Eventually(t, func() bool { - ctx, cancelFunc := context.WithTimeout(ctx, testutil.IntervalFast) - defer cancelFunc() - _, err := dialer.Ping(ctx) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) + require.True(t, dialer.AwaitReachable(context.Background())) cancelFunc() err = <-errC require.NoError(t, err) @@ -134,12 +128,7 @@ func TestWorkspaceAgent(t *testing.T) { dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() - require.Eventually(t, func() bool { - ctx, cancelFunc := context.WithTimeout(ctx, testutil.IntervalFast) - defer cancelFunc() - _, err := dialer.Ping(ctx) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) + require.True(t, dialer.AwaitReachable(context.Background())) cancelFunc() err = <-errC require.NoError(t, err) @@ -197,13 +186,7 @@ func TestWorkspaceAgent(t *testing.T) { dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() - require.Eventually(t, func() bool { - ctx, cancelFunc := context.WithTimeout(ctx, testutil.IntervalFast) - defer cancelFunc() - _, err := dialer.Ping(ctx) - return err == nil - }, testutil.WaitMedium, testutil.IntervalFast) - + require.True(t, dialer.AwaitReachable(context.Background())) sshClient, err := dialer.SSHClient(ctx) require.NoError(t, err) defer sshClient.Close() diff --git a/cli/portforward.go b/cli/portforward.go index ca7cb51f14719..ea6edb2c9d89e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -10,7 +10,6 @@ import ( "strings" "sync" "syscall" - "time" "github.com/pion/udp" "github.com/spf13/cobra" @@ -145,22 +144,7 @@ func portForward() *cobra.Command { closeAllListeners() }() - ticker := time.NewTicker(250 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - } - - _, err = conn.Ping(ctx) - if err != nil { - continue - } - break - } - ticker.Stop() + conn.AwaitReachable(ctx) _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") wg.Wait() return closeErr diff --git a/cli/speedtest.go b/cli/speedtest.go index 873e5e2794963..0761b558ef39f 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -62,33 +62,37 @@ func speedtest() *cobra.Command { return err } defer conn.Close() - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: + if direct { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + dur, err := conn.Ping(ctx) + if err != nil { + continue + } + status := conn.Status() + if len(status.Peers()) != 1 { + continue + } + peer := status.Peer[status.Peers()[0]] + if peer.CurAddr == "" && direct { + cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) + continue + } + via := peer.Relay + if via == "" { + via = "direct" + } + cmd.Printf("%dms via %s\n", dur.Milliseconds(), via) + break } - dur, err := conn.Ping(ctx) - if err != nil { - continue - } - status := conn.Status() - if len(status.Peers()) != 1 { - continue - } - peer := status.Peer[status.Peers()[0]] - if peer.CurAddr == "" && direct { - cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) - continue - } - via := peer.Relay - if via == "" { - via = "direct" - } - cmd.Printf("%dms via %s\n", dur.Milliseconds(), via) - break + } else { + conn.AwaitReachable(ctx) } dir := tsspeedtest.Download if reverse { diff --git a/cli/ssh.go b/cli/ssh.go index 811d87af18ff1..57a8c4aab4ac4 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -90,12 +90,12 @@ func ssh() *cobra.Command { return xerrors.Errorf("await agent: %w", err) } - conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, nil) + conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{}) if err != nil { return err } defer conn.Close() - + conn.AwaitReachable(ctx) stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace) defer stopPolling() diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 9a3088711fe24..e89b913f1bf17 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -178,12 +178,7 @@ func TestWorkspaceAgentListen(t *testing.T) { defer func() { _ = conn.Close() }() - require.Eventually(t, func() bool { - ctx, cancelFunc := context.WithTimeout(ctx, testutil.IntervalFast) - defer cancelFunc() - _, err := conn.Ping(ctx) - return err == nil - }, testutil.WaitLong, testutil.IntervalFast) + conn.AwaitReachable(ctx) }) t.Run("FailNonLatestBuild", func(t *testing.T) { diff --git a/codersdk/agentconn.go b/codersdk/agentconn.go index a68ab0672ad6b..f980767336daa 100644 --- a/codersdk/agentconn.go +++ b/codersdk/agentconn.go @@ -16,9 +16,7 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/xerrors" - "tailscale.com/ipn/ipnstate" "tailscale.com/net/speedtest" - "tailscale.com/tailcfg" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/tailnet" @@ -133,27 +131,18 @@ type AgentConn struct { CloseFunc func() } +func (c *AgentConn) AwaitReachable(ctx context.Context) bool { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + return c.Conn.AwaitReachable(ctx, TailnetIP) +} + func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() - errCh := make(chan error, 1) - durCh := make(chan time.Duration, 1) - go c.Conn.Ping(TailnetIP, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { - if pr.Err != "" { - errCh <- xerrors.New(pr.Err) - return - } - durCh <- time.Duration(pr.LatencySeconds * float64(time.Second)) - }) - select { - case err := <-errCh: - return 0, err - case <-ctx.Done(): - return 0, ctx.Err() - case dur := <-durCh: - return dur, nil - } + return c.Conn.Ping(ctx, TailnetIP) } func (c *AgentConn) CloseWithError(_ error) error { diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 8f6d43f6eaf4e..a5e1b0ce2e3a0 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -447,13 +447,14 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti _ = conn.Close() return nil, err } + return &AgentConn{ Conn: conn, CloseFunc: func() { cancelFunc() <-closed }, - }, err + }, nil } // WorkspaceAgent returns an agent by ID. diff --git a/go.mod b/go.mod index 517f5bedb3f47..f38504b38d9dc 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221104170440-ef53dca69a41 +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc // Switch to our fork that imports fixes from http://github.com/tailscale/ssh. // See: https://github.com/coder/coder/issues/3371 diff --git a/go.sum b/go.sum index e6525b0464a99..29b60a8021290 100644 --- a/go.sum +++ b/go.sum @@ -355,8 +355,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= -github.com/coder/tailscale v1.1.1-0.20221104170440-ef53dca69a41 h1:/mjNjfUarvH8BdmvDVLvtIIENoe3PevqCyZQmAlILuw= -github.com/coder/tailscale v1.1.1-0.20221104170440-ef53dca69a41/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc= +github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc h1:qozpteSLz0ifMasetJ+/Qac5Ud/NRNIlgTubGf6TAaQ= +github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= diff --git a/tailnet/conn.go b/tailnet/conn.go index 108a7cc5dda47..5a5c8e50f394d 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -2,6 +2,7 @@ package tailnet import ( "context" + "errors" "fmt" "io" "net" @@ -198,7 +199,7 @@ func NewConn(options *Options) (*Conn, error) { wireguardEngine: wireguardEngine, } wireguardEngine.SetStatusCallback(func(s *wgengine.Status, err error) { - server.logger.Info(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err)) + server.logger.Debug(context.Background(), "wireguard status", slog.F("status", s), slog.F("err", err)) if err != nil { return } @@ -217,7 +218,7 @@ func NewConn(options *Options) (*Conn, error) { server.sendNode() }) wireguardEngine.SetNetInfoCallback(func(ni *tailcfg.NetInfo) { - server.logger.Info(context.Background(), "netinfo callback", slog.F("netinfo", ni)) + server.logger.Debug(context.Background(), "netinfo callback", slog.F("netinfo", ni)) // If the lastMutex is blocked, it's possible that // multiple NetInfo callbacks occur at the same time. // @@ -383,6 +384,9 @@ func (c *Conn) UpdateNodes(nodes []*Node) error { if c.isClosed() { return nil } + if errors.Is(err, wgengine.ErrNoChanges) { + return nil + } return xerrors.Errorf("reconfig: %w", err) } return nil @@ -395,9 +399,56 @@ func (c *Conn) Status() *ipnstate.Status { return sb.Status() } -// Ping sends a ping to the Wireguard engine. -func (c *Conn) Ping(ip netip.Addr, pingType tailcfg.PingType, cb func(*ipnstate.PingResult)) { - c.wireguardEngine.Ping(ip, pingType, cb) +// Ping sends a Disco ping to the Wireguard engine. +func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, error) { + errCh := make(chan error, 1) + durCh := make(chan time.Duration, 1) + go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { + if pr.Err != "" { + errCh <- xerrors.New(pr.Err) + return + } + durCh <- time.Duration(pr.LatencySeconds * float64(time.Second)) + }) + select { + case err := <-errCh: + return 0, err + case <-ctx.Done(): + return 0, ctx.Err() + case dur := <-durCh: + return dur, nil + } +} + +// AwaitReachable pings the provided IP continually until the +// address is reachable. It's the callers responsibility to provide +// a timeout, otherwise this function will block forever. +func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool { + ticker := time.NewTicker(time.Millisecond * 100) + defer ticker.Stop() + completedCtx, completed := context.WithCancel(ctx) + run := func() { + ctx, cancelFunc := context.WithTimeout(completedCtx, time.Second) + defer cancelFunc() + _, err := c.Ping(ctx, ip) + if err == nil { + completed() + } + } + go run() + defer completed() + for { + select { + case <-completedCtx.Done(): + return true + case <-ticker.C: + // Pings can take a while, so we can run multiple + // in parallel to return ASAP. + go run() + case <-ctx.Done(): + return false + } + } } // Closed is a channel that ends when the connection has @@ -466,7 +517,7 @@ func (c *Conn) sendNode() { } c.nodeSending = true go func() { - c.logger.Info(context.Background(), "sending node", slog.F("node", node)) + c.logger.Debug(context.Background(), "sending node", slog.F("node", node)) nodeCallback(node) c.lastMutex.Lock() c.nodeSending = false diff --git a/tailnet/conn_test.go b/tailnet/conn_test.go index bc3d5aec284af..a967a0772cdd8 100644 --- a/tailnet/conn_test.go +++ b/tailnet/conn_test.go @@ -62,14 +62,16 @@ func TestTailnet(t *testing.T) { err := w1.UpdateNodes([]*tailnet.Node{node}) require.NoError(t, err) }) - + require.True(t, w2.AwaitReachable(context.Background(), w1IP)) conn := make(chan struct{}) go func() { listener, err := w1.Listen("tcp", ":35565") assert.NoError(t, err) defer listener.Close() nc, err := listener.Accept() - assert.NoError(t, err) + if !assert.NoError(t, err) { + return + } _ = nc.Close() conn <- struct{}{} }()