diff --git a/agent/agent.go b/agent/agent.go index b03cdb11c6810..31b4b8959f8df 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -215,8 +215,16 @@ func (a *agent) run(ctx context.Context) error { return xerrors.Errorf("create tailnet: %w", err) } a.closeMutex.Lock() - a.network = network + // Re-check if agent was closed while initializing the network. + closed := a.isClosed() + if !closed { + a.network = network + } a.closeMutex.Unlock() + if closed { + _ = network.Close() + return xerrors.New("agent is closed") + } } else { // Update the DERP map! network.SetDERPMap(metadata.DERPMap) @@ -246,11 +254,6 @@ func (a *agent) trackConnGoroutine(fn func()) error { } func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ *tailnet.Conn, err error) { - a.closeMutex.Lock() - if a.isClosed() { - a.closeMutex.Unlock() - return nil, xerrors.New("closed") - } network, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)}, DERPMap: derpMap, @@ -258,7 +261,6 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ EnableTrafficStats: true, }) if err != nil { - a.closeMutex.Unlock() return nil, xerrors.Errorf("create tailnet: %w", err) } defer func() { @@ -266,7 +268,6 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (_ network.Close() } }() - a.closeMutex.Unlock() sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort)) if err != nil { diff --git a/agent/agent_test.go b/agent/agent_test.go index f54a699f654af..6d2b2e5a8e09c 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -806,12 +806,17 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo }) require.NoError(t, err) clientConn, serverConn := net.Pipe() + serveClientDone := make(chan struct{}) t.Cleanup(func() { _ = clientConn.Close() _ = serverConn.Close() _ = conn.Close() + <-serveClientDone }) - go coordinator.ServeClient(serverConn, uuid.New(), agentID) + go func() { + defer close(serveClientDone) + coordinator.ServeClient(serverConn, uuid.New(), agentID) + }() sendNode, _ := tailnet.ServeCoordinator(clientConn, func(node []*tailnet.Node) error { return conn.UpdateNodes(node) })