Skip to content

fix(agent): fix deadlock if closed while starting listeners #17329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,21 @@ type agent struct {
// we track 2 contexts and associated cancel functions: "graceful" which is Done when it is time
// to start gracefully shutting down and "hard" which is Done when it is time to close
// everything down (regardless of whether graceful shutdown completed).
gracefulCtx context.Context
gracefulCancel context.CancelFunc
hardCtx context.Context
hardCancel context.CancelFunc
closeWaitGroup sync.WaitGroup
gracefulCtx context.Context
gracefulCancel context.CancelFunc
hardCtx context.Context
hardCancel context.CancelFunc

// closeMutex protects the following:
closeMutex sync.Mutex
closeWaitGroup sync.WaitGroup
coordDisconnected chan struct{}
closing bool
// note that once the network is set to non-nil, it is never modified, as with the statsReporter. So, routines
// that run after createOrUpdateNetwork and check the networkOK checkpoint do not need to hold the lock to use them.
network *tailnet.Conn
statsReporter *statsReporter
// end fields protected by closeMutex

environmentVariables map[string]string

Expand All @@ -259,9 +267,7 @@ type agent struct {
reportConnectionsMu sync.Mutex
reportConnections []*proto.ReportConnectionRequest

network *tailnet.Conn
statsReporter *statsReporter
logSender *agentsdk.LogSender
logSender *agentsdk.LogSender

prometheusRegistry *prometheus.Registry
// metrics are prometheus registered metrics that will be collected and
Expand All @@ -274,6 +280,8 @@ type agent struct {
}

func (a *agent) TailnetConn() *tailnet.Conn {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
return a.network
}

Expand Down Expand Up @@ -1205,15 +1213,15 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
}
a.closeMutex.Lock()
// Re-check if agent was closed while initializing the network.
closed := a.isClosed()
if !closed {
closing := a.closing
if !closing {
a.network = network
a.statsReporter = newStatsReporter(a.logger, network, a)
}
a.closeMutex.Unlock()
if closed {
if closing {
_ = network.Close()
return xerrors.New("agent is closed")
return xerrors.New("agent is closing")
}
} else {
// Update the wireguard IPs if the agent ID changed.
Expand Down Expand Up @@ -1328,8 +1336,8 @@ func (*agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
func (a *agent) trackGoroutine(fn func()) error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
return xerrors.New("track conn goroutine: agent is closed")
if a.closing {
return xerrors.New("track conn goroutine: agent is closing")
}
a.closeWaitGroup.Add(1)
go func() {
Expand Down Expand Up @@ -1547,7 +1555,7 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
func (a *agent) setCoordDisconnected() chan struct{} {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
if a.isClosed() {
if a.closing {
return nil
}
disconnected := make(chan struct{})
Expand Down Expand Up @@ -1772,7 +1780,10 @@ func (a *agent) HTTPDebug() http.Handler {

func (a *agent) Close() error {
a.closeMutex.Lock()
defer a.closeMutex.Unlock()
network := a.network
coordDisconnected := a.coordDisconnected
a.closing = true
a.closeMutex.Unlock()
if a.isClosed() {
return nil
}
Expand Down Expand Up @@ -1849,7 +1860,7 @@ lifecycleWaitLoop:
select {
case <-a.hardCtx.Done():
a.logger.Warn(context.Background(), "timed out waiting for Coordinator RPC disconnect")
case <-a.coordDisconnected:
case <-coordDisconnected:
a.logger.Debug(context.Background(), "coordinator RPC disconnected")
}

Expand All @@ -1860,8 +1871,8 @@ lifecycleWaitLoop:
}

a.hardCancel()
if a.network != nil {
_ = a.network.Close()
if network != nil {
_ = network.Close()
}
a.closeWaitGroup.Wait()

Expand Down
48 changes: 48 additions & 0 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,54 @@ func TestMain(m *testing.M) {

var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}

// TestAgent_CloseWhileStarting is a regression test for https://github.com/coder/coder/issues/17328
func TestAgent_ImmediateClose(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()

logger := slogtest.Make(t, &slogtest.Options{
// Agent can drop errors when shutting down, and some, like the
// fasthttplistener connection closed error, are unexported.
IgnoreErrors: true,
}).Leveled(slog.LevelDebug)
manifest := agentsdk.Manifest{
AgentID: uuid.New(),
AgentName: "test-agent",
WorkspaceName: "test-workspace",
WorkspaceID: uuid.New(),
}

coordinator := tailnet.NewCoordinator(logger)
t.Cleanup(func() {
_ = coordinator.Close()
})
statsCh := make(chan *proto.Stats, 50)
fs := afero.NewMemMapFs()
client := agenttest.NewClient(t, logger.Named("agenttest"), manifest.AgentID, manifest, statsCh, coordinator)
t.Cleanup(client.Close)

options := agent.Options{
Client: client,
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: 0,
EnvironmentVariables: map[string]string{},
}

agentUnderTest := agent.New(options)
t.Cleanup(func() {
_ = agentUnderTest.Close()
})

// wait until the agent has connected and is starting to find races in the startup code
_ = testutil.RequireRecvCtx(ctx, t, client.GetStartup())
t.Log("Closing Agent")
err := agentUnderTest.Close()
require.NoError(t, err)
}

// NOTE: These tests only work when your default shell is bash for some reason.

func TestAgent_Stats_SSH(t *testing.T) {
Expand Down
Loading