diff --git a/agent/agent.go b/agent/agent.go
index cf784a2702bfe..a7434b90d4854 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -229,13 +229,21 @@ type agent struct {
 	// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
 	// to start gracefully shutting down and "hard" which is Done when it is time to close
 	// everything down (regardless of whether graceful shutdown completed).
-	gracefulCtx       context.Context
-	gracefulCancel    context.CancelFunc
-	hardCtx           context.Context
-	hardCancel        context.CancelFunc
-	closeWaitGroup    sync.WaitGroup
+	gracefulCtx    context.Context
+	gracefulCancel context.CancelFunc
+	hardCtx        context.Context
+	hardCancel     context.CancelFunc
+
+	// closeMutex protects the following:
 	closeMutex        sync.Mutex
+	closeWaitGroup    sync.WaitGroup
 	coordDisconnected chan struct{}
+	closing           bool
+	// note that once the network is set to non-nil, it is never modified, as with the statsReporter. So, routines
+	// that run after createOrUpdateNetwork and check the networkOK checkpoint do not need to hold the lock to use them.
+	network       *tailnet.Conn
+	statsReporter *statsReporter
+	// end fields protected by closeMutex
 
 	environmentVariables map[string]string
 
@@ -259,9 +267,7 @@ type agent struct {
 	reportConnectionsMu sync.Mutex
 	reportConnections   []*proto.ReportConnectionRequest
 
-	network       *tailnet.Conn
-	statsReporter *statsReporter
-	logSender     *agentsdk.LogSender
+	logSender *agentsdk.LogSender
 
 	prometheusRegistry *prometheus.Registry
 	// metrics are prometheus registered metrics that will be collected and
@@ -274,6 +280,8 @@
 }
 
 func (a *agent) TailnetConn() *tailnet.Conn {
+	a.closeMutex.Lock()
+	defer a.closeMutex.Unlock()
 	return a.network
 }
 
@@ -1205,15 +1213,15 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
 		}
 		a.closeMutex.Lock()
 		// Re-check if agent was closed while initializing the network.
-		closed := a.isClosed()
-		if !closed {
+		closing := a.closing
+		if !closing {
 			a.network = network
 			a.statsReporter = newStatsReporter(a.logger, network, a)
 		}
 		a.closeMutex.Unlock()
-		if closed {
+		if closing {
 			_ = network.Close()
-			return xerrors.New("agent is closed")
+			return xerrors.New("agent is closing")
 		}
 	} else {
 		// Update the wireguard IPs if the agent ID changed.
@@ -1328,8 +1336,8 @@ func (*agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
 func (a *agent) trackGoroutine(fn func()) error {
 	a.closeMutex.Lock()
 	defer a.closeMutex.Unlock()
-	if a.isClosed() {
-		return xerrors.New("track conn goroutine: agent is closed")
+	if a.closing {
+		return xerrors.New("track conn goroutine: agent is closing")
 	}
 	a.closeWaitGroup.Add(1)
 	go func() {
@@ -1547,7 +1555,7 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
 func (a *agent) setCoordDisconnected() chan struct{} {
 	a.closeMutex.Lock()
 	defer a.closeMutex.Unlock()
-	if a.isClosed() {
+	if a.closing {
 		return nil
 	}
 	disconnected := make(chan struct{})
@@ -1772,7 +1780,10 @@ func (a *agent) HTTPDebug() http.Handler {
 
 func (a *agent) Close() error {
 	a.closeMutex.Lock()
-	defer a.closeMutex.Unlock()
+	network := a.network
+	coordDisconnected := a.coordDisconnected
+	a.closing = true
+	a.closeMutex.Unlock()
 	if a.isClosed() {
 		return nil
 	}
@@ -1849,7 +1860,7 @@ lifecycleWaitLoop:
 	select {
 	case <-a.hardCtx.Done():
 		a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect")
-	case <-a.coordDisconnected:
+	case <-coordDisconnected:
 		a.logger.Debug(context.Background(), "coordinator RPC disconnected")
 	}
 
@@ -1860,8 +1871,8 @@
 	}
 	a.hardCancel()
 
-	if a.network != nil {
-		_ = a.network.Close()
+	if network != nil {
+		_ = network.Close()
 	}
 	a.closeWaitGroup.Wait()
 
diff --git a/agent/agent_test.go b/agent/agent_test.go
index bbf0221ab5259..69423a2f83be7 100644
--- a/agent/agent_test.go
+++ b/agent/agent_test.go
@@ -68,6 +68,54 @@ func TestMain(m *testing.M) {
 
 var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
 
+// TestAgent_ImmediateClose is a regression test for https://github.com/coder/coder/issues/17328
+func TestAgent_ImmediateClose(t *testing.T) {
+	t.Parallel()
+
+	ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+	defer cancel()
+
+	logger := slogtest.Make(t, &slogtest.Options{
+		// Agent can drop errors when shutting down, and some, like the
+		// fasthttplistener connection closed error, are unexported.
+		IgnoreErrors: true,
+	}).Leveled(slog.LevelDebug)
+	manifest := agentsdk.Manifest{
+		AgentID:       uuid.New(),
+		AgentName:     "test-agent",
+		WorkspaceName: "test-workspace",
+		WorkspaceID:   uuid.New(),
+	}
+
+	coordinator := tailnet.NewCoordinator(logger)
+	t.Cleanup(func() {
+		_ = coordinator.Close()
+	})
+	statsCh := make(chan *proto.Stats, 50)
+	fs := afero.NewMemMapFs()
+	client := agenttest.NewClient(t, logger.Named("agenttest"), manifest.AgentID, manifest, statsCh, coordinator)
+	t.Cleanup(client.Close)
+
+	options := agent.Options{
+		Client:                 client,
+		Filesystem:             fs,
+		Logger:                 logger.Named("agent"),
+		ReconnectingPTYTimeout: 0,
+		EnvironmentVariables:   map[string]string{},
+	}
+
+	agentUnderTest := agent.New(options)
+	t.Cleanup(func() {
+		_ = agentUnderTest.Close()
+	})
+
+	// Wait until the agent has connected and is starting up, to catch races in the startup code.
+	_ = testutil.RequireRecvCtx(ctx, t, client.GetStartup())
+	t.Log("Closing Agent")
+	err := agentUnderTest.Close()
+	require.NoError(t, err)
+}
+
 // NOTE: These tests only work when your default shell is bash for some reason.
 func TestAgent_Stats_SSH(t *testing.T) {
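The core of this change is a publish/close handshake around closeMutex: createOrUpdateNetwork re-checks the new closing flag under the lock before publishing network and statsReporter, while Close sets the flag and snapshots the fields it needs, then releases the lock before performing the slow teardown. The sketch below distills that handshake into a minimal standalone program; the names (resource, start) are hypothetical stand-ins for tailnet.Conn and createOrUpdateNetwork, an illustration of the technique rather than the actual coder/coder API.

// race_sketch.go — a minimal sketch of the close/publish handshake in this
// diff. resource and start are hypothetical stand-ins for tailnet.Conn and
// createOrUpdateNetwork; they are not part of the real coder/coder code.
package main

import (
	"errors"
	"fmt"
	"sync"
)

type resource struct{}

func (r *resource) Close() error { return nil }

type agent struct {
	mu      sync.Mutex // protects closing and res
	closing bool
	res     *resource
}

// start mirrors createOrUpdateNetwork: do the slow construction without the
// lock held, then re-check closing under the lock before publishing.
func (a *agent) start() error {
	res := &resource{} // potentially slow setup happens unlocked

	a.mu.Lock()
	closing := a.closing
	if !closing {
		a.res = res // publish; once non-nil, never modified again
	}
	a.mu.Unlock()

	if closing {
		// Close() won the race and never saw res, so tear down our copy here.
		_ = res.Close()
		return errors.New("agent is closing")
	}
	return nil
}

// Close mirrors agent.Close in the diff: set the flag and snapshot res under
// the lock, but run the actual teardown after releasing it, so a concurrent
// start is never blocked on (or by) a held mutex during a slow shutdown.
func (a *agent) Close() error {
	a.mu.Lock()
	res := a.res
	a.closing = true
	a.mu.Unlock()

	if res != nil {
		return res.Close()
	}
	return nil
}

func main() {
	a := &agent{}
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); _ = a.start() }()
	go func() { defer wg.Done(); _ = a.Close() }()
	wg.Wait()
	fmt.Println("done: exactly one side closed the resource")
}

Either start publishes the resource before Close snapshots it, in which case Close tears it down, or start observes closing == true and tears down its own unpublished copy. Exactly one side closes the resource, and neither does so while holding the mutex, which is what lets an in-flight startup and an immediate Close (the scenario TestAgent_ImmediateClose exercises) proceed without racing or deadlocking.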