From 61f5d2687fd2baafd4e6aab5c63e273b78e9cc40 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 21 Dec 2023 11:32:25 +0400 Subject: [PATCH] chore: refactor agent connection updates --- coderd/workspaceagents.go | 217 +--------- coderd/workspaceagentsrpc.go | 351 ++++++++++------- coderd/workspaceagentsrpc_internal_test.go | 436 +++++++++++++++++++++ 3 files changed, 661 insertions(+), 343 deletions(-) create mode 100644 coderd/workspaceagentsrpc_internal_test.go diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index dd47275a4f6ac..e75f0ae28a9a0 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -12,11 +12,9 @@ import ( "net/http" "net/netip" "net/url" - "runtime/pprof" "sort" "strconv" "strings" - "sync/atomic" "time" "github.com/google/uuid" @@ -42,7 +40,6 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/prometheusmetrics" "github.com/coder/coder/v2/coderd/rbac" - "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/tailnet" @@ -1084,21 +1081,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request api.WebsocketWaitMutex.Unlock() defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) - resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to accept websocket.", - Detail: err.Error(), - }) - return - } - - build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Internal error fetching workspace build job.", - Detail: err.Error(), - }) + // Ensure the resource is still valid! + // We only accept agents for resources on the latest build. + build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent) + if !ok { return } @@ -1120,32 +1106,6 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request return } - // Ensure the resource is still valid! - // We only accept agents for resources on the latest build. - ensureLatestBuild := func() error { - latestBuild, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } - - err = ensureLatestBuild() - if err != nil { - api.Logger.Debug(ctx, "agent tried to connect from non-latest build", - slog.F("resource", resource), - slog.F("agent", workspaceAgent), - ) - httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ - Message: "Agent trying to connect from non-latest build.", - Detail: err.Error(), - }) - return - } - conn, err := websocket.Accept(rw, r, nil) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -1158,109 +1118,10 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request ctx, wsNetConn := websocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() - // We use a custom heartbeat routine here instead of `httpapi.Heartbeat` - // because we want to log the agent's last ping time. - var lastPing atomic.Pointer[time.Time] - lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive. 
- - go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) { - // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? - t := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer t.Stop() - - for { - select { - case <-t.C: - case <-ctx.Done(): - return - } - - // We don't need a context that times out here because the ping will - // eventually go through. If the context times out, then other - // websocket read operations will receive an error, obfuscating the - // actual problem. - err := conn.Ping(ctx) - if err != nil { - return - } - lastPing.Store(ptr.Ref(time.Now())) - } - }) - - firstConnectedAt := workspaceAgent.FirstConnectedAt - if !firstConnectedAt.Valid { - firstConnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - lastConnectedAt := sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - disconnectedAt := workspaceAgent.DisconnectedAt - updateConnectionTimes := func(ctx context.Context) error { - //nolint:gocritic // We only update ourself. - err = api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: workspaceAgent.ID, - FirstConnectedAt: firstConnectedAt, - LastConnectedAt: lastConnectedAt, - DisconnectedAt: disconnectedAt, - UpdatedAt: dbtime.Now(), - LastConnectedReplicaID: uuid.NullUUID{ - UUID: api.ID, - Valid: true, - }, - }) - if err != nil { - return err - } - return nil - } - - defer func() { - // If connection closed then context will be canceled, try to - // ensure our final update is sent. By waiting at most the agent - // inactive disconnect timeout we ensure that we don't block but - // also guarantee that the agent will be considered disconnected - // by normal status check. - // - // Use a system context as the agent has disconnected and that token - // may no longer be valid. - //nolint:gocritic - ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout) - defer cancel() - - // Only update timestamp if the disconnect is new. - if !disconnectedAt.Valid { - disconnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - err := updateConnectionTimes(ctx) - if err != nil { - // This is a bug with unit tests that cancel the app context and - // cause this error log to be generated. We should fix the unit tests - // as this is a valid log. - // - // The pq error occurs when the server is shutting down. 
- if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) { - api.Logger.Error(ctx, "failed to update agent disconnect time", - slog.Error(err), - slog.F("workspace_id", build.WorkspaceID), - ) - } - } - api.publishWorkspaceUpdate(ctx, build.WorkspaceID) - }() - - err = updateConnectionTimes(ctx) - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, err.Error()) - return - } - api.publishWorkspaceUpdate(ctx, build.WorkspaceID) + closeCtx, closeCtxCancel := context.WithCancel(ctx) + defer closeCtxCancel() + monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn) + defer monitor.close() api.Logger.Debug(ctx, "accepting agent", slog.F("owner", owner.Username), @@ -1271,61 +1132,13 @@ func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request defer conn.Close(websocket.StatusNormalClosure, "") - closeChan := make(chan struct{}) - go func() { - defer close(closeChan) - err := (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID, - fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name), - ) - if err != nil { - api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err)) - _ = conn.Close(websocket.StatusInternalError, err.Error()) - return - } - }() - ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer ticker.Stop() - for { - select { - case <-closeChan: - return - case <-ticker.C: - } - - lastPing := *lastPing.Load() - - var connectionStatusChanged bool - if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout { - if !disconnectedAt.Valid { - connectionStatusChanged = true - disconnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - } else { - connectionStatusChanged = disconnectedAt.Valid - // TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it? - disconnectedAt = sql.NullTime{} - lastConnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - err = updateConnectionTimes(ctx) - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, err.Error()) - return - } - if connectionStatusChanged { - api.publishWorkspaceUpdate(ctx, build.WorkspaceID) - } - err := ensureLatestBuild() - if err != nil { - // Disconnect agents that are no longer valid. 
- _ = conn.Close(websocket.StatusGoingAway, "") - return - } + err = (*api.TailnetCoordinator.Load()).ServeAgent(wsNetConn, workspaceAgent.ID, + fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name), + ) + if err != nil { + api.Logger.Warn(ctx, "tailnet coordinator agent error", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return } } diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 66cde3876f95d..6b9438a8b8c9f 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "runtime/pprof" + "sync" "sync/atomic" "time" @@ -42,7 +43,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) - ensureLatestBuildFn, build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent) + build, ok := ensureLatestBuild(ctx, api.Database, api.Logger, rw, workspaceAgent) if !ok { return } @@ -96,10 +97,10 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { defer conn.Close(websocket.StatusNormalClosure, "") - pingFn, ok := api.agentConnectionUpdate(ctx, workspaceAgent, build.WorkspaceID, conn) - if !ok { - return - } + closeCtx, closeCtxCancel := context.WithCancel(ctx) + defer closeCtxCancel() + monitor := api.startAgentWebsocketMonitor(closeCtx, workspaceAgent, build, conn) + defer monitor.close() agentAPI := agentapi.New(agentapi.Options{ AgentID: workspaceAgent.ID, @@ -136,29 +137,22 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Auth: tailnet.AgentTunnelAuth{}, } ctx = tailnet.WithStreamID(ctx, streamID) - - closeCtx, closeCtxCancel := context.WithCancel(ctx) - go func() { - defer closeCtxCancel() - err := agentAPI.Serve(ctx, mux) - if err != nil { - api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err)) - _ = conn.Close(websocket.StatusInternalError, err.Error()) - return - } - }() - - pingFn(closeCtx, ensureLatestBuildFn) + err = agentAPI.Serve(ctx, mux) + if err != nil { + api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, err.Error()) + return + } } -func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (func() error, database.WorkspaceBuild, bool) { +func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logger, rw http.ResponseWriter, workspaceAgent database.WorkspaceAgent) (database.WorkspaceBuild, bool) { resource, err := db.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Internal error fetching workspace agent resource.", Detail: err.Error(), }) - return nil, database.WorkspaceBuild{}, false + return database.WorkspaceBuild{}, false } build, err := db.GetWorkspaceBuildByJobID(ctx, resource.JobID) @@ -167,23 +161,12 @@ func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logge Message: "Internal error fetching workspace build job.", Detail: err.Error(), }) - return nil, database.WorkspaceBuild{}, false + return database.WorkspaceBuild{}, false } // Ensure the resource is still valid! // We only accept agents for resources on the latest build. 
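 // (An agent connecting from a superseded build may reference resources
 // that no longer exist, so such connections are rejected below.)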
- ensureLatestBuild := func() error { - latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID) - if err != nil { - return err - } - if build.ID != latestBuild.ID { - return xerrors.New("build is outdated") - } - return nil - } - - err = ensureLatestBuild() + err = checkBuildIsLatest(ctx, db, build) if err != nil { logger.Debug(ctx, "agent tried to connect from non-latest build", slog.F("resource", resource), @@ -193,73 +176,159 @@ func ensureLatestBuild(ctx context.Context, db database.Store, logger slog.Logge Message: "Agent trying to connect from non-latest build.", Detail: err.Error(), }) - return nil, database.WorkspaceBuild{}, false + return database.WorkspaceBuild{}, false } - return ensureLatestBuild, build, true + return build, true } -func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent database.WorkspaceAgent, workspaceID uuid.UUID, conn *websocket.Conn) (func(closeCtx context.Context, ensureLatestBuildFn func() error), bool) { - // We use a custom heartbeat routine here instead of `httpapi.Heartbeat` - // because we want to log the agent's last ping time. - var lastPing atomic.Pointer[time.Time] - lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive. - - go pprof.Do(ctx, pprof.Labels("agent", workspaceAgent.ID.String()), func(ctx context.Context) { - // TODO(mafredri): Is this too frequent? Use separate ping disconnect timeout? - t := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer t.Stop() - - for { - select { - case <-t.C: - case <-ctx.Done(): - return - } +func checkBuildIsLatest(ctx context.Context, db database.Store, build database.WorkspaceBuild) error { + latestBuild, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, build.WorkspaceID) + if err != nil { + return err + } + if build.ID != latestBuild.ID { + return xerrors.New("build is outdated") + } + return nil +} - // We don't need a context that times out here because the ping will - // eventually go through. If the context times out, then other - // websocket read operations will receive an error, obfuscating the - // actual problem. 
- err := conn.Ping(ctx) - if err != nil { - return - } - lastPing.Store(ptr.Ref(time.Now())) +func (api *API) startAgentWebsocketMonitor(ctx context.Context, + workspaceAgent database.WorkspaceAgent, workspaceBuild database.WorkspaceBuild, + conn *websocket.Conn, +) *agentWebsocketMonitor { + monitor := &agentWebsocketMonitor{ + apiCtx: api.ctx, + workspaceAgent: workspaceAgent, + workspaceBuild: workspaceBuild, + conn: conn, + pingPeriod: api.AgentConnectionUpdateFrequency, + db: api.Database, + replicaID: api.ID, + updater: api, + disconnectTimeout: api.AgentInactiveDisconnectTimeout, + logger: api.Logger.With( + slog.F("workspace_id", workspaceBuild.WorkspaceID), + slog.F("agent_id", workspaceAgent.ID), + ), + } + monitor.init() + monitor.start(ctx) + + return monitor +} + +type workspaceUpdater interface { + publishWorkspaceUpdate(ctx context.Context, workspaceID uuid.UUID) +} + +type pingerCloser interface { + Ping(ctx context.Context) error + Close(code websocket.StatusCode, reason string) error +} + +type agentWebsocketMonitor struct { + apiCtx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + workspaceAgent database.WorkspaceAgent + workspaceBuild database.WorkspaceBuild + conn pingerCloser + db database.Store + replicaID uuid.UUID + updater workspaceUpdater + logger slog.Logger + pingPeriod time.Duration + + // state manipulated by both sendPings() and monitor() goroutines: needs to be threadsafe + lastPing atomic.Pointer[time.Time] + + // state manipulated only by monitor() goroutine: does not need to be threadsafe + firstConnectedAt sql.NullTime + lastConnectedAt sql.NullTime + disconnectedAt sql.NullTime + disconnectTimeout time.Duration +} + +// sendPings sends websocket pings. +// +// We use a custom heartbeat routine here instead of `httpapi.Heartbeat` +// because we want to log the agent's last ping time. +func (m *agentWebsocketMonitor) sendPings(ctx context.Context) { + t := time.NewTicker(m.pingPeriod) + defer t.Stop() + + for { + select { + case <-t.C: + case <-ctx.Done(): + return } + + // We don't need a context that times out here because the ping will + // eventually go through. If the context times out, then other + // websocket read operations will receive an error, obfuscating the + // actual problem. + err := m.conn.Ping(ctx) + if err != nil { + return + } + m.lastPing.Store(ptr.Ref(time.Now())) + } +} + +func (m *agentWebsocketMonitor) updateConnectionTimes(ctx context.Context) error { + //nolint:gocritic // We only update the agent we are minding. 
+ err := m.db.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{ + ID: m.workspaceAgent.ID, + FirstConnectedAt: m.firstConnectedAt, + LastConnectedAt: m.lastConnectedAt, + DisconnectedAt: m.disconnectedAt, + UpdatedAt: dbtime.Now(), + LastConnectedReplicaID: uuid.NullUUID{ + UUID: m.replicaID, + Valid: true, + }, }) + if err != nil { + return xerrors.Errorf("failed to update workspace agent connection times: %w", err) + } + return nil +} - firstConnectedAt := workspaceAgent.FirstConnectedAt - if !firstConnectedAt.Valid { - firstConnectedAt = sql.NullTime{ - Time: dbtime.Now(), +func (m *agentWebsocketMonitor) init() { + now := dbtime.Now() + m.firstConnectedAt = m.workspaceAgent.FirstConnectedAt + if !m.firstConnectedAt.Valid { + m.firstConnectedAt = sql.NullTime{ + Time: now, Valid: true, } } - lastConnectedAt := sql.NullTime{ - Time: dbtime.Now(), + m.lastConnectedAt = sql.NullTime{ + Time: now, Valid: true, } - disconnectedAt := workspaceAgent.DisconnectedAt - updateConnectionTimes := func(ctx context.Context) error { - //nolint:gocritic // We only update ourself. - err := api.Database.UpdateWorkspaceAgentConnectionByID(dbauthz.AsSystemRestricted(ctx), database.UpdateWorkspaceAgentConnectionByIDParams{ - ID: workspaceAgent.ID, - FirstConnectedAt: firstConnectedAt, - LastConnectedAt: lastConnectedAt, - DisconnectedAt: disconnectedAt, - UpdatedAt: dbtime.Now(), - LastConnectedReplicaID: uuid.NullUUID{ - UUID: api.ID, - Valid: true, - }, + m.disconnectedAt = m.workspaceAgent.DisconnectedAt + m.lastPing.Store(ptr.Ref(time.Now())) // Since the agent initiated the request, assume it's alive. +} + +func (m *agentWebsocketMonitor) start(ctx context.Context) { + ctx, m.cancel = context.WithCancel(ctx) + m.wg.Add(2) + go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()), + func(ctx context.Context) { + defer m.wg.Done() + m.sendPings(ctx) }) - if err != nil { - return err - } - return nil - } + go pprof.Do(ctx, pprof.Labels("agent", m.workspaceAgent.ID.String()), + func(ctx context.Context) { + defer m.wg.Done() + m.monitor(ctx) + }) +} +func (m *agentWebsocketMonitor) monitor(ctx context.Context) { defer func() { // If connection closed then context will be canceled, try to // ensure our final update is sent. By waiting at most the agent @@ -270,17 +339,17 @@ func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent databa // Use a system context as the agent has disconnected and that token // may no longer be valid. //nolint:gocritic - ctx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(api.ctx), api.AgentInactiveDisconnectTimeout) + finalCtx, cancel := context.WithTimeout(dbauthz.AsSystemRestricted(m.apiCtx), m.disconnectTimeout) defer cancel() // Only update timestamp if the disconnect is new. - if !disconnectedAt.Valid { - disconnectedAt = sql.NullTime{ + if !m.disconnectedAt.Valid { + m.disconnectedAt = sql.NullTime{ Time: dbtime.Now(), Valid: true, } } - err := updateConnectionTimes(ctx) + err := m.updateConnectionTimes(finalCtx) if err != nil { // This is a bug with unit tests that cancel the app context and // cause this error log to be generated. We should fix the unit tests @@ -288,66 +357,66 @@ func (api *API) agentConnectionUpdate(ctx context.Context, workspaceAgent databa // // The pq error occurs when the server is shutting down. 
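 // Both cases are expected shutdown noise rather than real failures,
 // which is why they are filtered from the error log below.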
if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) { - api.Logger.Error(ctx, "failed to update agent disconnect time", + m.logger.Error(finalCtx, "failed to update agent disconnect time", slog.Error(err), - slog.F("workspace_id", workspaceID), ) } } - api.publishWorkspaceUpdate(ctx, workspaceID) + m.updater.publishWorkspaceUpdate(finalCtx, m.workspaceBuild.WorkspaceID) + }() + reason := "disconnect" + defer func() { + m.logger.Debug(ctx, "agent websocket monitor is closing connection", + slog.F("reason", reason)) + _ = m.conn.Close(websocket.StatusGoingAway, reason) }() - err := updateConnectionTimes(ctx) + err := m.updateConnectionTimes(ctx) if err != nil { - _ = conn.Close(websocket.StatusGoingAway, err.Error()) - return nil, false + reason = err.Error() + return } - api.publishWorkspaceUpdate(ctx, workspaceID) - - return func(closeCtx context.Context, ensureLatestBuildFn func() error) { - ticker := time.NewTicker(api.AgentConnectionUpdateFrequency) - defer ticker.Stop() - for { - select { - case <-closeCtx.Done(): - return - case <-ticker.C: - } + m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID) + + ticker := time.NewTicker(m.pingPeriod) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + reason = "canceled" + return + case <-ticker.C: + } - lastPing := *lastPing.Load() - - var connectionStatusChanged bool - if time.Since(lastPing) > api.AgentInactiveDisconnectTimeout { - if !disconnectedAt.Valid { - connectionStatusChanged = true - disconnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - } else { - connectionStatusChanged = disconnectedAt.Valid - // TODO(mafredri): Should we update it here or allow lastConnectedAt to shadow it? - disconnectedAt = sql.NullTime{} - lastConnectedAt = sql.NullTime{ - Time: dbtime.Now(), - Valid: true, - } - } - err = updateConnectionTimes(ctx) - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, err.Error()) - return - } - if connectionStatusChanged { - api.publishWorkspaceUpdate(ctx, workspaceID) - } - err := ensureLatestBuildFn() - if err != nil { - // Disconnect agents that are no longer valid. 
- _ = conn.Close(websocket.StatusGoingAway, "") - return - } + lastPing := *m.lastPing.Load() + if time.Since(lastPing) > m.disconnectTimeout { + reason = "ping timeout" + return + } + connectionStatusChanged := m.disconnectedAt.Valid + m.disconnectedAt = sql.NullTime{} + m.lastConnectedAt = sql.NullTime{ + Time: dbtime.Now(), + Valid: true, + } + + err = m.updateConnectionTimes(ctx) + if err != nil { + reason = err.Error() + return + } + if connectionStatusChanged { + m.updater.publishWorkspaceUpdate(ctx, m.workspaceBuild.WorkspaceID) + } + err = checkBuildIsLatest(ctx, m.db, m.workspaceBuild) + if err != nil { + reason = err.Error() + return } - }, true + } +} + +func (m *agentWebsocketMonitor) close() { + m.cancel() + m.wg.Wait() } diff --git a/coderd/workspaceagentsrpc_internal_test.go b/coderd/workspaceagentsrpc_internal_test.go new file mode 100644 index 0000000000000..b748048b203a2 --- /dev/null +++ b/coderd/workspaceagentsrpc_internal_test.go @@ -0,0 +1,436 @@ +package coderd + +import ( + "context" + "database/sql" + "fmt" + "sync" + "testing" + "time" + + "github.com/coder/coder/v2/coderd/util/ptr" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "nhooyr.io/websocket" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/testutil" +) + +func TestAgentWebsocketMonitor_ContextCancel(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + now := dbtime.Now() + fConn := &fakePingerCloser{} + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + fUpdater := &fakeUpdater{} + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agent := database.WorkspaceAgent{ + ID: uuid.New(), + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + } + build := database.WorkspaceBuild{ + ID: uuid.New(), + WorkspaceID: uuid.New(), + } + replicaID := uuid.New() + + uut := &agentWebsocketMonitor{ + apiCtx: ctx, + workspaceAgent: agent, + workspaceBuild: build, + conn: fConn, + db: mDB, + replicaID: replicaID, + updater: fUpdater, + logger: logger, + pingPeriod: testutil.IntervalFast, + disconnectTimeout: testutil.WaitShort, + } + uut.init() + + connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agent.ID, replicaID), + ). + AnyTimes(). + Return(nil) + mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agent.ID, replicaID, withDisconnected()), + ). + After(connected). + Times(1). + Return(nil) + mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). + AnyTimes(). 
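+ // The latest build matches, so checkBuildIsLatest keeps passing and
+ // the monitor runs until the test cancels the context.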
+ Return(database.WorkspaceBuild{ID: build.ID}, nil)
+
+ closeCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ done := make(chan struct{})
+ go func() {
+ uut.monitor(closeCtx)
+ close(done)
+ }()
+ // wait a couple intervals, but not long enough for a disconnect
+ time.Sleep(3 * testutil.IntervalFast)
+ fConn.requireNotClosed(t)
+ fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID)
+ n := fUpdater.getUpdates()
+ cancel()
+ fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "canceled")
+
+ // make sure we got at least one additional update on close
+ _ = testutil.RequireRecvCtx(ctx, t, done)
+ m := fUpdater.getUpdates()
+ require.Greater(t, m, n)
+}
+
+func TestAgentWebsocketMonitor_PingTimeout(t *testing.T) {
+ t.Parallel()
+ ctx := testutil.Context(t, testutil.WaitShort)
+ now := dbtime.Now()
+ fConn := &fakePingerCloser{}
+ ctrl := gomock.NewController(t)
+ mDB := dbmock.NewMockStore(ctrl)
+ fUpdater := &fakeUpdater{}
+ logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ agent := database.WorkspaceAgent{
+ ID: uuid.New(),
+ FirstConnectedAt: sql.NullTime{
+ Time: now.Add(-time.Minute),
+ Valid: true,
+ },
+ }
+ build := database.WorkspaceBuild{
+ ID: uuid.New(),
+ WorkspaceID: uuid.New(),
+ }
+ replicaID := uuid.New()
+
+ uut := &agentWebsocketMonitor{
+ apiCtx: ctx,
+ workspaceAgent: agent,
+ workspaceBuild: build,
+ conn: fConn,
+ db: mDB,
+ replicaID: replicaID,
+ updater: fUpdater,
+ logger: logger,
+ pingPeriod: testutil.IntervalFast,
+ disconnectTimeout: testutil.WaitShort,
+ }
+ uut.init()
+ // set the last ping to the past, so we go through the timeout
+ uut.lastPing.Store(ptr.Ref(now.Add(-time.Hour)))
+
+ connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
+ gomock.Any(),
+ connectionUpdate(agent.ID, replicaID),
+ ).
+ AnyTimes().
+ Return(nil)
+ mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
+ gomock.Any(),
+ connectionUpdate(agent.ID, replicaID, withDisconnected()),
+ ).
+ After(connected).
+ Times(1).
+ Return(nil)
+ mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID).
+ AnyTimes().
+ Return(database.WorkspaceBuild{ID: build.ID}, nil)
+
+ go uut.monitor(ctx)
+ fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "ping timeout")
+ fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID)
+}
+
+func TestAgentWebsocketMonitor_BuildOutdated(t *testing.T) {
+ t.Parallel()
+ ctx := testutil.Context(t, testutil.WaitShort)
+ now := dbtime.Now()
+ fConn := &fakePingerCloser{}
+ ctrl := gomock.NewController(t)
+ mDB := dbmock.NewMockStore(ctrl)
+ fUpdater := &fakeUpdater{}
+ logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
+ agent := database.WorkspaceAgent{
+ ID: uuid.New(),
+ FirstConnectedAt: sql.NullTime{
+ Time: now.Add(-time.Minute),
+ Valid: true,
+ },
+ }
+ build := database.WorkspaceBuild{
+ ID: uuid.New(),
+ WorkspaceID: uuid.New(),
+ }
+ replicaID := uuid.New()
+
+ uut := &agentWebsocketMonitor{
+ apiCtx: ctx,
+ workspaceAgent: agent,
+ workspaceBuild: build,
+ conn: fConn,
+ db: mDB,
+ replicaID: replicaID,
+ updater: fUpdater,
+ logger: logger,
+ pingPeriod: testutil.IntervalFast,
+ disconnectTimeout: testutil.WaitShort,
+ }
+ uut.init()
+
+ connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
+ gomock.Any(),
+ connectionUpdate(agent.ID, replicaID),
+ ).
+ AnyTimes().
+ Return(nil)
+ mDB.EXPECT().UpdateWorkspaceAgentConnectionByID(
+ gomock.Any(),
+ connectionUpdate(agent.ID, replicaID, withDisconnected()),
+ ).
+ After(connected).
+ Times(1).
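+ // Exactly one disconnect write is allowed, and only after the
+ // connected writes (enforced by After(connected)).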
+ Return(nil) + + // return a new buildID each time, meaning the connection is outdated + mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). + AnyTimes(). + Return(database.WorkspaceBuild{ID: uuid.New()}, nil) + + go uut.monitor(ctx) + fConn.requireEventuallyClosed(t, websocket.StatusGoingAway, "build is outdated") + fUpdater.requireEventuallySomeUpdates(t, build.WorkspaceID) +} + +func TestAgentWebsocketMonitor_SendPings(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + fConn := &fakePingerCloser{} + uut := &agentWebsocketMonitor{ + pingPeriod: testutil.IntervalFast, + conn: fConn, + } + go uut.sendPings(ctx) + fConn.requireEventuallyHasPing(t) + lastPing := uut.lastPing.Load() + require.NotNil(t, lastPing) +} + +func TestAgentWebsocketMonitor_StartClose(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + fConn := &fakePingerCloser{} + now := dbtime.Now() + ctrl := gomock.NewController(t) + mDB := dbmock.NewMockStore(ctrl) + fUpdater := &fakeUpdater{} + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agent := database.WorkspaceAgent{ + ID: uuid.New(), + FirstConnectedAt: sql.NullTime{ + Time: now.Add(-time.Minute), + Valid: true, + }, + } + build := database.WorkspaceBuild{ + ID: uuid.New(), + WorkspaceID: uuid.New(), + } + replicaID := uuid.New() + uut := &agentWebsocketMonitor{ + apiCtx: ctx, + workspaceAgent: agent, + workspaceBuild: build, + conn: fConn, + db: mDB, + replicaID: replicaID, + updater: fUpdater, + logger: logger, + pingPeriod: testutil.IntervalFast, + disconnectTimeout: testutil.WaitShort, + } + + connected := mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agent.ID, replicaID), + ). + AnyTimes(). + Return(nil) + mDB.EXPECT().UpdateWorkspaceAgentConnectionByID( + gomock.Any(), + connectionUpdate(agent.ID, replicaID, withDisconnected()), + ). + After(connected). + Times(1). + Return(nil) + mDB.EXPECT().GetLatestWorkspaceBuildByWorkspaceID(gomock.Any(), build.WorkspaceID). + AnyTimes(). 
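+ // The build stays current here, so the monitor exits only when
+ // close() cancels it.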
+ Return(database.WorkspaceBuild{ID: build.ID}, nil) + + uut.start(ctx) + closed := make(chan struct{}) + go func() { + uut.close() + close(closed) + }() + _ = testutil.RequireRecvCtx(ctx, t, closed) +} + +type fakePingerCloser struct { + sync.Mutex + pings []time.Time + code websocket.StatusCode + reason string + closed bool +} + +func (f *fakePingerCloser) Ping(context.Context) error { + f.Lock() + defer f.Unlock() + f.pings = append(f.pings, time.Now()) + return nil +} + +func (f *fakePingerCloser) Close(code websocket.StatusCode, reason string) error { + f.Lock() + defer f.Unlock() + if f.closed { + return nil + } + f.closed = true + f.code = code + f.reason = reason + return nil +} + +func (f *fakePingerCloser) requireNotClosed(t *testing.T) { + f.Lock() + defer f.Unlock() + require.False(t, f.closed) +} + +func (f *fakePingerCloser) requireEventuallyClosed(t *testing.T, code websocket.StatusCode, reason string) { + require.Eventually(t, func() bool { + f.Lock() + defer f.Unlock() + return f.closed + }, testutil.WaitShort, testutil.IntervalFast) + f.Lock() + defer f.Unlock() + require.Equal(t, code, f.code) + require.Equal(t, reason, f.reason) +} + +func (f *fakePingerCloser) requireEventuallyHasPing(t *testing.T) { + require.Eventually(t, func() bool { + f.Lock() + defer f.Unlock() + return len(f.pings) > 0 + }, testutil.WaitShort, testutil.IntervalFast) +} + +type fakeUpdater struct { + sync.Mutex + updates []uuid.UUID +} + +func (f *fakeUpdater) publishWorkspaceUpdate(_ context.Context, workspaceID uuid.UUID) { + f.Lock() + defer f.Unlock() + f.updates = append(f.updates, workspaceID) +} + +func (f *fakeUpdater) requireEventuallySomeUpdates(t *testing.T, workspaceID uuid.UUID) { + require.Eventually(t, func() bool { + f.Lock() + defer f.Unlock() + return len(f.updates) >= 1 + }, testutil.WaitShort, testutil.IntervalFast) + + f.Lock() + defer f.Unlock() + for _, u := range f.updates { + require.Equal(t, workspaceID, u) + } +} + +func (f *fakeUpdater) getUpdates() int { + f.Lock() + defer f.Unlock() + return len(f.updates) +} + +type connectionUpdateMatcher struct { + agentID uuid.UUID + replicaID uuid.UUID + disconnected bool +} + +type connectionUpdateMatcherOption func(m connectionUpdateMatcher) connectionUpdateMatcher + +func connectionUpdate(id, replica uuid.UUID, opts ...connectionUpdateMatcherOption) connectionUpdateMatcher { + m := connectionUpdateMatcher{ + agentID: id, + replicaID: replica, + } + for _, opt := range opts { + m = opt(m) + } + return m +} + +func withDisconnected() connectionUpdateMatcherOption { + return func(m connectionUpdateMatcher) connectionUpdateMatcher { + m.disconnected = true + return m + } +} + +func (m connectionUpdateMatcher) Matches(x interface{}) bool { + args, ok := x.(database.UpdateWorkspaceAgentConnectionByIDParams) + if !ok { + return false + } + if args.ID != m.agentID { + return false + } + if !args.LastConnectedReplicaID.Valid { + return false + } + if args.LastConnectedReplicaID.UUID != m.replicaID { + return false + } + if args.DisconnectedAt.Valid != m.disconnected { + return false + } + return true +} + +func (m connectionUpdateMatcher) String() string { + return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", + m.agentID.String(), m.replicaID.String(), m.disconnected) +} + +func (connectionUpdateMatcher) Got(x interface{}) string { + args, ok := x.(database.UpdateWorkspaceAgentConnectionByIDParams) + if !ok { + return fmt.Sprintf("type=%T", x) + } + return fmt.Sprintf("{agent=%s, replica=%s, disconnected=%t}", + args.ID, 
args.LastConnectedReplicaID.UUID, args.DisconnectedAt.Valid) +}
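
For readers unfamiliar with gomock's custom matchers: connectionUpdateMatcher works because any value with Matches(x interface{}) bool and String() string satisfies gomock.Matcher, and the extra Got method is picked up through the optional gomock.GotFormatter interface to pretty-print the actual argument on a mismatch. A minimal standalone sketch of the same pattern (the package and evenMatcher names are illustrative only, not part of this change):

package example

import (
	"fmt"

	"github.com/golang/mock/gomock"
)

// evenMatcher mirrors the shape of connectionUpdateMatcher above:
// Matches decides acceptance, String renders the expectation, and Got
// renders the actual value in failure output.
type evenMatcher struct{}

func (evenMatcher) Matches(x interface{}) bool {
	n, ok := x.(int)
	return ok && n%2 == 0
}

func (evenMatcher) String() string { return "is an even int" }

func (evenMatcher) Got(x interface{}) string { return fmt.Sprintf("%v (%T)", x, x) }

// Compile-time assertions that both interfaces are satisfied.
var (
	_ gomock.Matcher      = evenMatcher{}
	_ gomock.GotFormatter = evenMatcher{}
)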